Unverified commit da3890e8 authored by SijiaYang, committed by GitHub

[1/n]: add cutlass W4A8 moe kernel for hopper architecture (#7772)


Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com>
Co-authored-by: yicwang <yichen.wang@bytedance.com>
parent cb432f17
@@ -249,6 +249,9 @@ set(SOURCES
"csrc/speculative/speculative_sampling.cu" "csrc/speculative/speculative_sampling.cu"
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu" "csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
"csrc/kvcacheio/transfer.cu" "csrc/kvcacheio/transfer.cu"
"csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
"csrc/common_extension.cc" "csrc/common_extension.cc"
"csrc/moe/marlin_moe_wna16/ops.cu" "csrc/moe/marlin_moe_wna16/ops.cu"
"csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu" "csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu"
......
@@ -277,6 +277,25 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"int num_layers) -> ()");
m.impl("transfer_kv_all_layer_mla_direct", torch::kCUDA, &transfer_kv_all_layer_mla_direct);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
m.def(
"get_cutlass_w4a8_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k) -> ()");
m.impl("get_cutlass_w4a8_moe_mm_data", torch::kCUDA, &get_cutlass_w4a8_moe_mm_data);
m.def(
"cutlass_w4a8_moe_mm(Tensor! d, Tensor a, Tensor b, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor d_strides, Tensor s_strides,"
" int chunk_size, int topk) -> ()");
m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);
/*
* From FlashInfer
*/
......
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cute/arch/copy_sm90.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cute/util/type_traits.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective::detail {
template <class Collective>
struct MixedGroupedGemmInputUtils {
private:
using KernelSchedule = typename Collective::KernelSchedule;
using ConversionMode = typename Collective::ConversionMode;
using SmemLayoutA = typename Collective::SmemLayoutA;
using SmemLayoutB = typename Collective::SmemLayoutB;
using SmemLayoutScale = typename Collective::SmemLayoutScale;
using SwappedElementA = typename Collective::SwappedElementA;
using SwappedElementB = typename Collective::SwappedElementB;
using RealSwappedElementA = typename Collective::RealSwappedElementA;
using RealSwappedElementB = typename Collective::RealSwappedElementB;
using ElementScale = typename Collective::ElementScale;
using ElementZero = typename Collective::ElementZero;
using SmemCopyAtomScale = typename Collective::SmemCopyAtomScale;
static constexpr auto KernelConversionMode = Collective::KernelConversionMode;
static constexpr auto ModeHasScales = Collective::ModeHasScales;
static constexpr auto UseScaleLookupTable = Collective::UseScaleLookupTable;
public:
static constexpr auto elements_per_smem_scale() {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return 0;
} else if constexpr (ModeHasScales) {
return cute::cosize_v<SmemLayoutScale>;
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in scale smem allocation.");
}
}
static constexpr auto elements_per_smem_zero() {
if constexpr (
KernelConversionMode == ConversionMode::DirectConvert ||
KernelConversionMode == ConversionMode::ConvertAndScale) {
return 0;
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
return cute::cosize_v<SmemLayoutScale>;
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in scale smem allocation.");
}
}
// These methods use some of the public members of the class. For that reason, we define them after the public section.
static constexpr uint32_t compute_tma_transaction_bytes_mk() {
return cutlass::bits_to_bytes(
size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(cute::sizeof_bits_v<SwappedElementA>));
}
static constexpr uint32_t compute_tma_transaction_bytes_nk() {
return cutlass::bits_to_bytes(
size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(cute::sizeof_bits_v<SwappedElementB>));
}
static constexpr uint32_t compute_tma_transaction_bytes_extra() {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return 0;
} else if constexpr (ModeHasScales) {
constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(
size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) *
static_cast<uint32_t>(cute::sizeof_bits_v<ElementScale>));
static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return scale_tx_bytes;
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
// Scale and zero share smem layout
constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(
size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) *
static_cast<uint32_t>(cute::sizeof_bits_v<ElementZero>));
static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA
return scale_tx_bytes + zero_tx_bytes;
} else {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in tma transaction bytes computation.");
}
} else {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in tma transaction bytes computation.");
}
}
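// Worked example (illustrative numbers only, not this kernel's actual tile shapes): for a
// hypothetical 128x64 A stage of 4-bit weights and a 128x1 stage of 64-bit packed scales,
//   bits_to_bytes(128 * 64 * 4) = 4096 B per A stage, and
//   bits_to_bytes(128 * 1 * 64) = 1024 B per scale stage (a multiple of 128 B, as asserted above).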
/// Utilities to copy A and extra inputs from smem to RF
template <class SmemTiledCopyA, class TensorASmemView, class TensorACopyView, class... Ts, class... Us>
CUTLASS_DEVICE static void copy_tensors_MK(
SmemTiledCopyA const& smem_tiled_copy_A,
TensorASmemView const& tCsA,
TensorACopyView& tCrA_copy_view,
cute::tuple<Ts...> const& partitioned_mma_extra_info,
cute::tuple<Us...> const& tiled_copy_and_views,
int k_block,
int read_stage) {
copy(smem_tiled_copy_A, tCsA(_, _, k_block, read_stage), tCrA_copy_view(_, _, k_block));
if (k_block == 0) {
// We are starting a new k-tile so copy the scale
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// nothing to do
} else if constexpr (ModeHasScales) {
auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views);
auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views);
auto tCsS = cute::get<0>(partitioned_mma_extra_info);
copy(smem_tiled_copy_S, tCsS(_, _, k_block, read_stage), tCrS_copy_view(_, _, k_block));
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
// Nothing extra to do
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
auto tCsZ = cute::get<2>(partitioned_mma_extra_info);
auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views);
copy(smem_tiled_copy_S, tCsZ(_, _, k_block, read_stage), tCrZ_copy_view(_, _, k_block));
} else {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
}
}
// The core converter uses a lookup table to convert i4 values to 8-bit values.
template <
class EngineIn,
class LayoutIn,
class EngineOut,
class LayoutOut,
class EngineScale,
class LayoutScale>
CUTLASS_DEVICE static void lookup_table_convert( // Accept mutable temporaries
Tensor<EngineIn, LayoutIn> const& src,
Tensor<EngineOut, LayoutOut>&& dst,
Tensor<EngineScale, LayoutScale> const& scales_neg,
Tensor<EngineScale, LayoutScale> const& scales_pos) {
lookup_table_convert(src, dst, scales_neg, scales_pos);
}
template <class EngineIn, class LayoutIn, class EngineOut, class LayoutOut, class EngineScale, class LayoutScale>
CUTLASS_DEVICE static void lookup_table_convert(
Tensor<EngineIn, LayoutIn> const& src,
Tensor<EngineOut, LayoutOut>& dst,
Tensor<EngineScale, LayoutScale> const& scales_neg,
Tensor<EngineScale, LayoutScale> const& scales_pos) {
constexpr int N = cute::cosize(LayoutIn{});
static_assert(N == 4 || N == 8);
static_assert(cosize(LayoutScale{}) <= N / 4, "at least 4 consecutive weights must share the same scale.");
using SrcArray = cutlass::Array<cutlass::int4b_t, 8>;
using DstArray = cutlass::Array<RealSwappedElementB, 8>;
using RegArray = cutlass::AlignedArray<uint32_t, N / 4, sizeof(DstArray)>;
// View the input as reg
auto&& src_reg = cute::recast<uint32_t>(src)(0);
auto&& r = cute::recast<RegArray>(dst)(0);
// Determines whether to select from the signed (negative) or unsigned (positive) candidates
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
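// immLut follows the LOP3 convention (a = 0xf0, b = 0xcc, c = 0xaa), so this encodes (a & b) | c.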
uint32_t sign; // ((reg & 0x88888888) | 0x64206420) >> 1
asm volatile(
"{\n"
" lop3.b32 %0, %1, %2, %3, %4;\n"
"}\n"
: "=r"(sign)
: "r"(src_reg), "n"(0x88888888), "n"(0x64206420), "n"(immLut));
sign = sign >> 1;
// Ignore sign bit when indexing into LUT
uint32_t lut_idx = src_reg & 0x77777777;
Tensor scales_neg_ = cute::filter(scales_neg);
Tensor scales_pos_ = cute::filter(scales_pos);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 4; ++i, lut_idx >>= 16, sign >>= 16) {
auto&& scale_neg_ = reinterpret_cast<cutlass::Array<uint32_t, 2> const&>(scales_neg_(i));
auto&& scale_pos_ = reinterpret_cast<cutlass::Array<uint32_t, 2> const&>(scales_pos_(i));
asm volatile(
"{\n"
" .reg .b32 pos, neg ;\n"
" prmt .b32 neg, %3, %4, %1 ;\n"
" prmt .b32 pos, %5, %6, %1 ;\n"
" prmt .b32 %0, pos, neg, %2 ;\n"
"}\n"
: "=r"(r[i])
: "r"(lut_idx), "r"(sign), "r"(scale_neg_[0]), "r"(scale_neg_[1]), "r"(scale_pos_[0]), "r"(scale_pos_[1]));
}
}
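// Scalar model of the conversion above (an illustrative assumption, not the PTX path itself):
// each 4-bit source value keeps its low 3 bits as an index into an 8-entry table of pre-scaled
// 8-bit candidates, and its sign bit selects the "negative" or "positive" candidate table,
// mirroring the two prmt gathers plus the final sign-based select.
CUTLASS_DEVICE static uint8_t lookup_table_convert_scalar_model(
    uint8_t nibble, uint8_t const (&candidates_neg)[8], uint8_t const (&candidates_pos)[8]) {
  uint8_t const magnitude = nibble & 0x7;        // lut_idx = src & 0x77777777 keeps the low 3 bits
  bool const is_negative = (nibble & 0x8) != 0;  // the extracted sign mask comes from bit 3 of each nibble
  return is_negative ? candidates_neg[magnitude] : candidates_pos[magnitude];
}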
/// Utilities to dequantize A.
template <class Layout>
CUTLASS_DEVICE static void static_check_scale(Layout const& tensor) {
static_assert(
shape<0>(Layout{}) >= 4 && stride<0>(Layout{}) == 0,
"At least 4 adjacent weights in a thread must share the same scale.");
}
template <class Engine, class Layout>
CUTLASS_DEVICE static void static_check_scale(Tensor<Engine, Layout> const& tensor) {
static_check_scale(flatten(Layout{}));
}
template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut, class... Ts>
CUTLASS_DEVICE static void dequantize_A_kblock(
Tensor<EngineIn, LayoutIn> const& tCrA_load,
Tensor<EngineOut, LayoutOut>& tCrA_mma,
cute::tuple<Ts...>& partitioned_extra_info,
int const k_block) {
static_assert(is_rmem<EngineIn>::value, "Input tensor for A conversion must come from registers");
static_assert(is_rmem<EngineOut>::value, "Output tensor for A conversion must come from registers");
static_assert(cosize_v<LayoutIn> == cosize_v<LayoutOut>);
static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
using SrcType = typename EngineIn::value_type;
using DstType = typename EngineOut::value_type;
Tensor src = tCrA_load(_, _, k_block);
Tensor dst = tCrA_mma(_, _, k_block);
CUTE_STATIC_ASSERT_V(
size(src(_, 0)) == cosize(src(_, 0).layout()), "The first mode of tensor src must be contiguous in memory");
// try to make the size of the first mode equal to 32 bits
int constexpr NumValPerSrcReg = cute::min(decltype(size(src(_, 0)))::value, ceil_div(32, sizeof_bits_v<SrcType>));
Tensor src_vm = cute::group_modes<1, -1>(cute::zipped_divide(src, Int<NumValPerSrcReg>{}));
Tensor dst_vm = cute::group_modes<1, -1>(cute::zipped_divide(dst, Int<NumValPerSrcReg>{}));
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<1>(dst_vm); ++i) {
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
}
} else if constexpr (UseScaleLookupTable) {
constexpr int num_elements = decltype(size(src))::value;
static_assert(
is_same_v<RealSwappedElementA, cutlass::int4b_t>,
"Lookup table only supports int4 being the quant type now.");
static_assert(sizeof_bits_v<ElementScale> == 64, "Lookup table only supports 8 8bit scale values now.");
static_assert(
num_elements % 4 == 0 && num_elements >= 4, "Lookup table requires a vector size of 4x when converting.");
Tensor tCrS_neg = cute::get<1>(partitioned_extra_info);
auto&& tCrS_pos = cute::get<2>(partitioned_extra_info); // modification to its value is needed
Tensor scales_neg = tCrS_neg(_, _, k_block);
Tensor scales_pos = tCrS_pos(_, _, k_block);
CUTE_STATIC_ASSERT_V(cute::size(src) == cute::size(scales_neg));
static_check_scale(scales_neg);
static_check_scale(scales_pos);
Tensor scales_neg_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales_neg, Int<NumValPerSrcReg>{}));
Tensor scales_pos_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales_pos, Int<NumValPerSrcReg>{}));
if (k_block == 0) {
Tensor scales_neg_vm_ = filter(scales_neg_vm);
Tensor scales_pos_vm_ = filter(scales_pos_vm);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(scales_neg_vm_.layout()); ++i) {
auto&& scale_neg_ = reinterpret_cast<cutlass::Array<uint32_t, 2> const&>(scales_neg_vm_(i));
auto&& scale_pos_ = reinterpret_cast<cutlass::Array<uint32_t, 2>&>(scales_pos_vm_(i));
constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
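// immLut follows the LOP3 convention (a = 0xf0, b = 0xcc, c = 0xaa), so this encodes (a & b) ^ c.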
asm volatile(
"{\n"
" lop3 .b32 %0, %2, %4, %5, %6;\n"
" xor .b32 %1, %3, %5; \n"
"}\n"
: "=r"(scale_pos_[0]), "=r"(scale_pos_[1])
: "r"(scale_neg_[0]), "r"(scale_neg_[1]), "n"(0xFFFFFF00), "n"(0x80808080), "n"(immLut));
}
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<1>(dst_vm); ++i) {
lookup_table_convert(src_vm(_, i), dst_vm(_, i), scales_neg_vm(_, i), scales_pos_vm(_, i));
}
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
Tensor scales = cute::get<1>(partitioned_extra_info)(_, _, k_block);
CUTE_STATIC_ASSERT_V(size(src) == size(scales));
Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, Int<NumValPerSrcReg>{}));
if constexpr (is_same_v<DstType, ElementScale>) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<1>(dst_vm); ++i) {
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<0>(dst_vm); ++j) {
dst_vm(j, i) *= scales_vm(j, i);
}
}
} else {
auto stage = make_tensor_like<ElementScale>(src_vm(_, 0));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<1>(dst_vm); ++i) {
LayoutAwareConvert(src_vm(_, i), stage);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<0>(dst_vm); ++j) {
stage(j) *= scales_vm(j, i);
}
LayoutAwareConvert(stage, dst_vm(_, i));
}
}
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
static_assert(is_same_v<ElementScale, ElementZero>, "ElementScale and ElementZero must be the same.");
Tensor scales = cute::get<1>(partitioned_extra_info)(_, _, k_block);
Tensor zeros = cute::get<3>(partitioned_extra_info)(_, _, k_block);
CUTE_STATIC_ASSERT_V(size(src) == size(scales));
CUTE_STATIC_ASSERT_V(size(src) == size(zeros));
Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, Int<NumValPerSrcReg>{}));
Tensor zeros_vm = cute::group_modes<1, -1>(cute::zipped_divide(zeros, Int<NumValPerSrcReg>{}));
if constexpr (is_same_v<DstType, ElementScale>) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<1>(dst_vm); ++i) {
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<0>(dst_vm); ++j) {
dst_vm(j, i) = dst_vm(j, i) * scales_vm(j, i) + zeros_vm(j, i);
}
}
} else {
auto stage = make_tensor_like<ElementScale>(src_vm(_, 0));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<1>(dst_vm); ++i) {
LayoutAwareConvert(src_vm(_, i), stage);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<0>(dst_vm); ++j) {
stage(j) = stage(j) * scales_vm(j, i) + zeros_vm(j, i);
}
LayoutAwareConvert(stage, dst_vm(_, i));
}
}
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "No A data is loaded.");
}
}
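// Scalar sketch of the scale/zero paths above (illustrative values only):
//   ConvertAndScale:         w = int4(-3), scale = 0.5f            -> dst = (-3) * 0.5f         = -1.5f
//   ConvertAndScaleWithZero: w = int4(-3), scale = 0.5f, z = 0.25f -> dst = (-3) * 0.5f + 0.25f = -1.25f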
template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut, class... Ts>
CUTLASS_DEVICE static void convert_A_kblock(
Tensor<EngineIn, LayoutIn> const& tCrA_load, Tensor<EngineOut, LayoutOut>& tCrA_mma, int const k_block) {
static_assert(is_rmem<EngineIn>::value, "Input tensor for A conversion must come from registers");
static_assert(is_rmem<EngineOut>::value, "Output tensor for A conversion must come from registers");
static_assert(cosize_v<LayoutIn> == cosize_v<LayoutOut>);
static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
using SrcType = typename EngineIn::value_type;
Tensor src = tCrA_load(_, _, k_block);
Tensor dst = tCrA_mma(_, _, k_block);
CUTE_STATIC_ASSERT_V(
size(src(_, 0)) == cosize(src(_, 0).layout()), "The first mode of tensor src must be contiguous in memory");
// try to make the size of the first mode equal to 32 bits
int constexpr NumValPerSrcReg = cute::min(decltype(size(src(_, 0)))::value, ceil_div(32, sizeof_bits_v<SrcType>));
Tensor src_vm = cute::group_modes<1, -1>(cute::zipped_divide(src, Int<NumValPerSrcReg>{}));
Tensor dst_vm = cute::group_modes<1, -1>(cute::zipped_divide(dst, Int<NumValPerSrcReg>{}));
// KernelConversionMode == ConversionMode::DirectConvert
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<1>(dst_vm); ++i) {
LayoutAwareConvert(src_vm(_, i), dst_vm(_, i));
}
}
/// Utilities for any additional inputs inside of the TMA load
template <class Params, class TensorStorage, class... Ts>
CUTLASS_DEVICE static auto partition_extra_tma_inputs(
Params const& mainloop_params,
cute::tuple<Ts...> const& load_inputs,
TensorStorage& shared_tensors,
uint2 const& cluster_local_block_id,
int const m_coord,
int const l_coord) {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return cute::make_tuple();
} else if constexpr (ModeHasScales) {
Tensor sS =
make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE)
Tensor gS_mkl = get<2>(load_inputs);
auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y);
Tensor gS = gS_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k)
Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE)
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return cute::make_tuple(tSgS, tSsS);
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
Tensor sZ =
make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE)
Tensor gZ_mkl = get<3>(load_inputs);
auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y);
Tensor gZ = gZ_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k)
Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE)
return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ);
} else {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");
}
} else {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");
}
}
/// Utilities for partitioning extra inputs for loading from smem in the mainloop.
template <class ThreadMma, class TensorStorage>
CUTLASS_DEVICE static auto
partition_extra_mma_info(ThreadMma const& mma_thread_slice, TensorStorage& shared_tensors) {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// nothing to do
return cute::make_tuple();
} else if constexpr (UseScaleLookupTable) {
Tensor sS =
make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE)
Tensor tCsS = mma_thread_slice.partition_A(sS);
Tensor tCrS = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_, _, Int<0>{})).layout());
return cute::make_tuple(tCsS, tCrS);
} else if constexpr (ModeHasScales) {
Tensor sS =
make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE)
Tensor tCsS = mma_thread_slice.partition_A(sS);
Tensor tCrS = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_, _, Int<0>{})).layout());
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return cute::make_tuple(tCsS, tCrS);
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
Tensor sZ = make_tensor(
make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE)
Tensor tCsZ = mma_thread_slice.partition_A(sZ);
Tensor tCrZ = make_tensor<ElementZero>(mma_thread_slice.partition_fragment_A(sZ(_, _, Int<0>{})).layout());
return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ);
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
}
/// Returns the tiled copy and copy views for the extra inputs.
template <class TiledMma, class... Ts>
CUTLASS_DEVICE static auto retile_extra_mma_info(
TiledMma const& tiled_mma, cute::tuple<Ts...>& partitioned_extra_info, int const warp_group_thread_idx) {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// nothing to do
return cute::make_tuple();
} else if constexpr (ModeHasScales) {
auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma);
auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx);
Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K)
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view);
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K)
return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view);
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
}
};
} // namespace cutlass::gemm::collective::detail
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/builders/sm90_common.inl"
#include "cutlass/gemm/collective/collective_builder_decl.hpp"
#include "cutlass/gemm/collective/collective_mma_decl.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/pipeline/sm90_pipeline.hpp"
// SM90 Collective Builders should be used only starting with CUDA 12.0
#if (__CUDACC_VER_MAJOR__ >= 12)
#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA_TMA_WS_RS
template <
class ElementA_,
class GmemLayoutATag_,
int AlignmentA,
class ElementB_,
class GmemLayoutBTag_,
int AlignmentB,
class ElementAccumulator,
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
class KernelScheduleType>
struct CollectiveBuilderMixedInput<
arch::Sm90,
arch::OpClassTensorOp,
ElementA_,
GmemLayoutATag_,
AlignmentA,
ElementB_,
GmemLayoutBTag_,
AlignmentB,
ElementAccumulator,
TileShape_MNK,
ClusterShape_MNK,
StageCountType,
KernelScheduleType,
cute::enable_if_t<
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedPingpong>) &&
(detail::is_use_rmem_A<ElementA_, GmemLayoutATag_, ElementB_, GmemLayoutBTag_>() ||
// ConvertAndScale and ConvertAndScaleWithZero
cute::is_tuple<ElementA_>::value || cute::is_tuple<ElementB_>::value ||
// DirectConvert
sizeof_bits<ElementA_>::value != sizeof_bits<ElementB_>::value)>> {
private:
using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementA_>;
using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementB_>;
using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementA_>;
using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementB_>;
static constexpr bool NeitherIsTuple = !cute::is_tuple<ElementA_>::value && !cute::is_tuple<ElementB_>::value;
// Determine if mixed input types.
static constexpr bool IsMixedInput = cute::sizeof_bits_v<detail::deduce_mixed_width_dtype_t<0, ElementA_>> !=
cute::sizeof_bits_v<detail::deduce_mixed_width_dtype_t<0, ElementB_>>;
static constexpr bool IsArrayOfPointersGemm = cute::is_any_of_v<
KernelScheduleType,
KernelPtrArrayTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedPingpong>;
static_assert(IsMixedInput || !IsArrayOfPointersGemm, "Only mixed input grouped RS GEMM is supported.");
public:
using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementA_>;
using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementB_>;
static_assert(
!IsMixedInput || (cute::is_tuple<ElementA_>::value ^ cute::is_tuple<ElementB_>::value ||
(NeitherIsTuple && (sizeof_bits<ElementA>::value != sizeof_bits<ElementB>::value))),
"Either A OR B must be a tuple or the widths of A and B must be different.");
static constexpr bool IsANarrow = sizeof_bits<ElementA>::value < sizeof_bits<ElementB>::value;
template <class T>
static auto get_stride(T const& t) {
if constexpr (not cute::is_layout<cute::remove_pointer_t<T>>::value) {
return t;
} else {
if constexpr (cute::is_pointer_v<T>) {
return &cute::stride(*t);
} else {
return cute::stride(t);
}
}
}
using GmemLayoutATag = decltype(get_stride(GmemLayoutATag_{}));
using GmemLayoutBTag = decltype(get_stride(GmemLayoutBTag_{}));
using ElementPairA =
cute::conditional_t<IsMixedInput && IsANarrow && NeitherIsTuple, cute::tuple<ElementA>, ElementA_>;
using ElementPairB =
cute::conditional_t<IsMixedInput && !IsANarrow && NeitherIsTuple, cute::tuple<ElementB>, ElementB_>;
static constexpr bool IsATransformed = cute::is_tuple<ElementPairA>::value;
using ElementScale = cute::conditional_t<IsATransformed, ScaleA, ScaleB>;
using ElementZero = cute::conditional_t<IsATransformed, ZeroA, ZeroB>;
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
static_assert(
detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A<GmemLayoutATag>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B<GmemLayoutBTag>();
// If A is scaled, then we don't need to swap. Otherwise, we must ensure B goes to rmem and we must swap the
// operands.
static constexpr bool SwapAB =
IsMixedInput ? !IsATransformed : detail::is_swapAB<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>();
static constexpr bool IsWarpSpecializedTransposeB =
detail::is_warpspecialized_transpose_B<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag, KernelScheduleType>();
static_assert(!IsMixedInput || !IsWarpSpecializedTransposeB, "Mixed input GEMM does not support WS transpose B.");
// When we relax the above assertion, we must handle setting the tile mma GmmaMajorB correctly.
static constexpr cute::GMMA::Major TiledMmaGmmaMajorB = SwapAB ? GmmaMajorA : GmmaMajorB;
// For fp32 types, map to tf32 MMA value type.
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
// Handle mixed dtypes and MMA.
using RealElementA = cute::conditional_t<SwapAB, ElementBMma, ElementAMma>;
using RealElementB = cute::conditional_t<SwapAB, ElementAMma, ElementBMma>;
using RealElementAMma = cute::conditional_t<IsMixedInput, RealElementB, RealElementA>;
// Always the same for element B.
using RealElementBMma = RealElementB;
static_assert(
!IsMixedInput || TiledMmaGmmaMajorB == GMMA::Major::K || sizeof_bits<RealElementB>::value == 16,
"Mixed input GEMM does not support MN major layout except for 16bit");
using AtomLayoutMNK = cute::conditional_t<
cute::is_any_of_v<
KernelScheduleType,
KernelTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedCooperative>,
Layout<Shape<_2, _1, _1>>,
Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<
RealElementAMma,
RealElementBMma,
ElementAccumulator,
TileShape_MNK,
GMMA::Major::K,
GMMA::Major::K>(),
AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA = decltype(detail::rs_smem_selector<
GmmaMajorA,
ElementAMma,
decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{})),
IsWarpSpecializedTransposeB>());
using SmemLayoutAtomB = decltype(detail::rs_smem_selector<
GmmaMajorB,
ElementBMma,
decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{})),
IsWarpSpecializedTransposeB>());
static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutAtomA{});
static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutAtomB{});
static constexpr int SmemAlignment = static_cast<int>(cute::max(SmemAlignmentA, SmemAlignmentB));
// Handle mixed dtype array GEMM's size of tensor map storage.
static constexpr size_t TensorMapStorage = sizeof(cute::TmaDescriptor) * size_t(IsMixedInput) * 4;
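// The factor of 4 covers one TMA descriptor each for A, B, scale and zero.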
static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);
static constexpr int Sm90ReducedSmemCapacityBytes = detail::sm90_smem_capacity_bytes - KernelSmemCarveout;
static constexpr int PipelineStages =
IsMixedInput ? (IsArrayOfPointersGemm ? detail::compute_stage_count_or_override_single_affine_transformed_input<
Sm90ReducedSmemCapacityBytes,
RealElementA,
RealElementB,
ElementScale,
ElementZero,
TileShape_MNK,
StageCountType::bytes,
SmemAlignment>(StageCountType{})
: detail::compute_stage_count_or_override_single_affine_transformed_input<
detail::sm90_smem_capacity_bytes,
RealElementA,
RealElementB,
ElementScale,
ElementZero,
TileShape_MNK,
StageCountType::bytes,
SmemAlignment>(StageCountType{}))
: detail::compute_stage_count_or_override<
detail::sm90_smem_capacity_bytes,
ElementAMma,
ElementBMma,
TileShape_MNK,
StageCountType::bytes,
SmemAlignment>(StageCountType{});
using DispatchPolicy = cute::conditional_t<
IsMixedInput,
cute::conditional_t<
IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>>,
MainloopSm90TmaGmmaRmemAWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;
using SmemCopyAtomA = cute::conditional_t<SwapAB, void, Copy_Atom<cute::AutoVectorizingCopy, ElementA>>;
using SmemCopyAtomB = cute::conditional_t<SwapAB, Copy_Atom<cute::AutoVectorizingCopy, ElementB>, void>;
// We pack the scale data with the operand that will be optionally scaled and converted before MMA.
using StrideA = cute::conditional_t<
cute::is_layout<cute::remove_pointer_t<GmemLayoutATag_>>::value,
GmemLayoutATag_,
TagToStrideA_t<GmemLayoutATag>>;
using StrideB = cute::conditional_t<
cute::is_layout<cute::remove_pointer_t<GmemLayoutBTag_>>::value,
GmemLayoutBTag_,
TagToStrideB_t<GmemLayoutBTag>>;
using CollectiveOp = CollectiveMmaArrayMixedInput<
DispatchPolicy,
TileShape_MNK,
ElementPairA,
StrideA,
ElementPairB,
StrideB,
TiledMma,
GmemTiledCopyA,
SmemLayoutAtomA,
SmemCopyAtomA,
cute::identity,
GmemTiledCopyB,
SmemLayoutAtomB,
SmemCopyAtomB,
cute::identity>;
static_assert(
SmemAlignment == static_cast<int>(cute::max(CollectiveOp::SmemAlignmentA, CollectiveOp::SmemAlignmentB)));
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/collective/collective_mma_array_mixed_input.hpp"
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
class ArchTag,
class OpClass,
class ElementA,
class GmemLayoutA,
int AlignmentA,
class ElementB,
class GmemLayoutB,
int AlignmentB,
class ElementAccumulator,
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
class KernelScheduleType,
class Enable = void>
struct CollectiveBuilderMixedInput {
static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_mixed_input.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/detail/dependent_false.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
class DispatchPolicy,
class TileShape,
class ElementA,
class StrideA,
class ElementB,
class StrideB,
class TiledMma,
class GmemTiledCopyA,
class SmemLayoutAtomA,
class SmemCopyAtomA,
class TransformA,
class GmemTiledCopyB,
class SmemLayoutAtomB,
class SmemCopyAtomB,
class TransformB>
struct CollectiveMmaArrayMixedInput {
static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/algorithm/functional.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cute/tensor_predicate.hpp"
#include "cutlass/cuda_host_adapter.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
#include "cutlass_extensions/detail/collective/mixed_input_utils.hpp"
#define GROUP_SIZE 128
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
// WarpSpecialized Mainloop
template <
int Stages,
class ClusterShape,
class KernelSchedule_,
class TileShape_,
class ElementAOptionalTuple,
class StrideA_,
class ElementBOptionalTuple,
class StrideB_,
class TiledMma_,
class GmemTiledCopyA_,
class SmemLayoutAtomA_,
class SmemCopyAtomA_,
class TransformA_,
class GmemTiledCopyB_,
class SmemLayoutAtomB_,
class SmemCopyAtomB_,
class TransformB_>
struct CollectiveMmaArrayMixedInput<
MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput<Stages, ClusterShape, KernelSchedule_>,
TileShape_,
ElementAOptionalTuple,
StrideA_,
ElementBOptionalTuple,
StrideB_,
TiledMma_,
GmemTiledCopyA_,
SmemLayoutAtomA_,
SmemCopyAtomA_,
TransformA_,
GmemTiledCopyB_,
SmemLayoutAtomB_,
SmemCopyAtomB_,
TransformB_> {
public:
enum class ConversionMode { DirectConvert, ConvertAndScale, ConvertAndScaleWithZero };
//
// Type Aliases
//
using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput<Stages, ClusterShape, KernelSchedule_>;
using TileShape = TileShape_;
using KernelSchedule = KernelSchedule_;
private:
template <class T>
friend struct detail::MixedGroupedGemmInputUtils;
using CollectiveType = CollectiveMma<
DispatchPolicy,
TileShape_,
ElementAOptionalTuple,
StrideA_,
ElementBOptionalTuple,
StrideB_,
TiledMma_,
GmemTiledCopyA_,
SmemLayoutAtomA_,
SmemCopyAtomA_,
TransformA_,
GmemTiledCopyB_,
SmemLayoutAtomB_,
SmemCopyAtomB_,
TransformB_>;
using Utils = detail::MixedGroupedGemmInputUtils<CollectiveType>;
//
// Type Aliases
//
using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple>;
using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementBOptionalTuple>;
using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>;
using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementBOptionalTuple>;
public:
static_assert(
cute::is_tuple<ElementAOptionalTuple>::value ^ cute::is_tuple<ElementBOptionalTuple>::value,
"Either A OR B must be a tuple. It must take the from {ElementOperand, [ElementScale], [ElementZero]}. Inputs in "
"[] are optional.");
using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>;
using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>;
static constexpr bool IsATransformed = cute::is_tuple<ElementAOptionalTuple>::value;
using ElementScale = cute::conditional_t<IsATransformed, ScaleA, ScaleB>;
using ElementZero = cute::conditional_t<IsATransformed, ZeroA, ZeroB>;
// For cases where we can't have a void type, we can use this to allow the code to compile when the scale / zero is
// void.
using NonVoidElementScale = cute::conditional_t<cute::is_void_v<ElementScale>, float, ElementScale>;
using NonVoidElementZero = cute::conditional_t<cute::is_void_v<ElementZero>, float, ElementZero>;
using StrideA = StrideA_;
using InternalStrideA = cute::remove_pointer_t<StrideA>;
using StrideB = StrideB_;
using InternalStrideB = cute::remove_pointer_t<StrideB>;
using StrideScale = cute::Stride<cute::Int<1>, int64_t, int64_t>;
using NonVoidStrideScale =
cute::conditional_t<cute::is_void_v<StrideScale>, cute::Stride<_1, int64_t, int64_t>, StrideScale>;
static_assert(
(IsATransformed && (cutlass::gemm::detail::is_k_major<StrideA>() || is_layout<StrideA>::value ||
is_layout<InternalStrideA>::value)) ||
(!IsATransformed && (cutlass::gemm::detail::is_k_major<StrideB>() || is_layout<StrideB>::value ||
is_layout<InternalStrideB>::value)),
"The transformed type must be K-major.");
static_assert(
(IsATransformed && (sizeof(ElementB) == 2)) || (!IsATransformed && (sizeof(ElementA) == 2)) ||
((cutlass::gemm::detail::is_k_major<StrideA>() || is_layout<StrideA>::value ||
is_layout<InternalStrideA>::value) &&
(cutlass::gemm::detail::is_k_major<StrideB>() || is_layout<StrideB>::value ||
is_layout<InternalStrideB>::value)),
"The unscaled element must be 2 bytes OR both inputs must be K-major");
static_assert(
cutlass::gemm::detail::is_mn_major<NonVoidStrideScale>(),
"Scale must be MN major [Col Major if A is scaled, Row Major if B is scaled].");
using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = GmemTiledCopyA_;
using GmemTiledCopyB = GmemTiledCopyB_;
using GmemTiledCopyScale = cute::SM90_TMA_LOAD;
using SmemLayoutAtomA = SmemLayoutAtomA_;
using SmemLayoutAtomB = SmemLayoutAtomB_;
using SmemCopyAtomA = SmemCopyAtomA_;
using SmemCopyAtomB = SmemCopyAtomB_;
using SmemCopyAtomScale = Copy_Atom<cute::AutoVectorizingCopy, NonVoidElementScale>;
// We must ensure the type to be scaled goes to RF
static constexpr bool SwapAB = !IsATransformed;
using SwappedStrideA = cute::conditional_t<!SwapAB, StrideA, StrideB>;
using SwappedStrideB = cute::conditional_t<!SwapAB, StrideB, StrideA>;
using InternalSwappedStrideA = cute::conditional_t<!SwapAB, InternalStrideA, InternalStrideB>;
using InternalSwappedStrideB = cute::conditional_t<!SwapAB, InternalStrideB, InternalStrideA>;
using SwappedSmemLayoutAtomA = cute::conditional_t<!SwapAB, SmemLayoutAtomA, SmemLayoutAtomB>;
using SwappedSmemLayoutAtomB = cute::conditional_t<!SwapAB, SmemLayoutAtomB, SmemLayoutAtomA>;
using SwappedSmemCopyAtomA = cute::conditional_t<!SwapAB, SmemCopyAtomA, SmemCopyAtomB>;
using SwappedSmemCopyAtomB = cute::conditional_t<!SwapAB, SmemCopyAtomB, SmemCopyAtomA>;
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
// For all other types, cast to size equivalent uint type to avoid any rounding by TMA.
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
using ConvertedElementA = cute::conditional_t<ConvertF32toTF32A, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementA>>>;
using ConvertedElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementB>>>;
using RealSwappedElementA = cute::conditional_t<!SwapAB, ElementA, ElementB>;
using RealSwappedElementB = cute::conditional_t<!SwapAB, ElementB, ElementA>;
using SwappedElementA = cute::conditional_t<!SwapAB, ConvertedElementA, ConvertedElementB>;
using SwappedElementB = cute::conditional_t<!SwapAB, ConvertedElementB, ConvertedElementA>;
using TransformA = TransformA_;
using TransformB = TransformB_;
using SwappedTransformA = cute::conditional_t<!SwapAB, TransformA, TransformB>;
using SwappedTransformB = cute::conditional_t<!SwapAB, TransformB, TransformA>;
using ArchTag = typename DispatchPolicy::ArchTag;
static constexpr int IsSubbyteA = cute::sizeof_bits_v<SwappedElementA> < 8;
using TmaElementA = cute::conditional_t<IsSubbyteA, uint8_t, SwappedElementA>;
using TmaElementScale = uint_bit_t<sizeof_bits_v<NonVoidElementScale>>; // In case the scale is an Array type, translate
// it to a uint of the same width to satisfy the TMA descriptor's specialization.
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
static constexpr int NumProducerThreadEvents = 1;
using SmemLayoutAtomScale = Layout<Shape<decltype(cute::shape<0>(SwappedSmemLayoutAtomA{})), cute::Int<1>>>;
using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), shape<1>(SmemLayoutAtomScale{})));
static_assert(cute::rank(SwappedSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(
(size<0>(TileShape{}) % size<0>(SwappedSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(
(size<2>(TileShape{}) % size<1>(SwappedSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(cute::rank(SwappedSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(
(size<1>(TileShape{}) % size<0>(SwappedSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(
(size<2>(TileShape{}) % size<1>(SwappedSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(SmemLayoutAtomScale{}) == 2, "SmemLayoutAtomScale must be rank 2");
static_assert(
(size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must equal the tile shape.");
static_assert(
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0,
"SmemLayoutAtomScale must evenly divide tile k shape.");
/// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutA = decltype(detail::get_smem_layout<DispatchPolicy::Stages>(
SwappedSmemLayoutAtomA{}, select<0, 2>(TileShape{}), InternalSwappedStrideA{}));
using SmemLayoutB = decltype(detail::get_smem_layout<DispatchPolicy::Stages>(
SwappedSmemLayoutAtomB{}, select<1, 2>(TileShape{}), InternalSwappedStrideB{}));
// It is assumed that the scales and zero-points share the same smem layout
using SmemLayoutScale = decltype(tile_to_shape(
SmemLayoutAtomScale{},
make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}), Int<Stages>{}),
cute::conditional_t<
::cutlass::gemm::detail::is_major<0, NonVoidStrideScale>(),
Step<_2, _1, _3>,
Step<_1, _2, _3>>{}));
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
static_assert(
not cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
"MMA atom must source A from rmem and B operand from smem_desc for this mainloop.");
static_assert(
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
static_assert(
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
// To relax them, we need to handle loading more than 1 row of scales for every main loop iteration.
// We must also handle updating the pipeline transaction bytes on the fly.
static_assert(size<1>(SmemLayoutAtomScale{}) == 1, "size<1>(SmemLayoutAtomScale) must be 1.");
private:
static constexpr ConversionMode get_conversion_mode() {
if constexpr (cute::is_void_v<ElementScale>) {
return ConversionMode::DirectConvert;
} else if constexpr (cute::is_void_v<ElementZero>) {
return ConversionMode::ConvertAndScale;
} else {
return ConversionMode::ConvertAndScaleWithZero;
}
}
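// Illustrative (hypothetical) quantized-operand tuples and the mode each selects, taking B as
// the transformed operand (ScaleT / ZeroT are placeholder names, not types from this PR):
//   ElementBOptionalTuple = cute::tuple<cutlass::int4b_t>                 -> DirectConvert
//   ElementBOptionalTuple = cute::tuple<cutlass::int4b_t, ScaleT>         -> ConvertAndScale
//   ElementBOptionalTuple = cute::tuple<cutlass::int4b_t, ScaleT, ZeroT>  -> ConvertAndScaleWithZero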
public:
static constexpr ConversionMode KernelConversionMode = get_conversion_mode();
static constexpr bool ModeHasScales = KernelConversionMode == ConversionMode::ConvertAndScale ||
KernelConversionMode == ConversionMode::ConvertAndScaleWithZero;
static constexpr bool UseScaleLookupTable =
KernelConversionMode == ConversionMode::ConvertAndScale && cutlass::detail::is_Array_v<ElementScale>;
static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{});
static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{});
static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB);
static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment");
struct SharedStorage {
static constexpr int scale_elements = Utils::elements_per_smem_scale();
static constexpr int zero_elements = Utils::elements_per_smem_zero();
struct TensorStorage {
CUTE_ALIGNAS(SmemAlignmentA) cute::ArrayEngine<RealSwappedElementA, cute::cosize_v<SmemLayoutA>> smem_A;
CUTE_ALIGNAS(SmemAlignmentB) cute::ArrayEngine<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
cute::ArrayEngine<NonVoidElementScale, scale_elements> smem_scale;
cute::ArrayEngine<NonVoidElementZero, zero_elements> smem_zero;
} tensors;
struct TensorMapStorage {
cute::TmaDescriptor smem_tensormap_A;
cute::TmaDescriptor smem_tensormap_B;
cute::TmaDescriptor smem_tensormap_scale;
cute::TmaDescriptor smem_tensormap_zero;
};
using PipelineStorage = typename MainloopPipeline::SharedStorage;
PipelineStorage pipeline;
};
using TensorStorage = typename SharedStorage::TensorStorage;
using TensorMapStorage = typename SharedStorage::TensorMapStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
static constexpr bool IsGroupedGemmKernel = !cute::is_same_v<InternalStrideA, StrideA>;
// Host side kernel arguments
struct Arguments {
ElementA const** ptr_A;
StrideA dA;
ElementB const** ptr_B;
StrideB dB;
ElementScale const** ptr_S = nullptr;
NonVoidStrideScale const* dS{};
int chunk_size = 0;
ElementZero const** ptr_Z = nullptr;
};
// Device side kernel params
struct Params {
// Assumption: StrideA is congruent with Problem_MK
using LayoutA =
decltype(detail::get_gmem_layout(repeat_like(InternalSwappedStrideA{}, int32_t(0)), InternalSwappedStrideA{}));
using LayoutB =
decltype(detail::get_gmem_layout(repeat_like(InternalSwappedStrideB{}, int32_t(0)), InternalSwappedStrideB{}));
using TMA_A = decltype(make_tma_copy<TmaElementA>(
GmemTiledCopyA{},
make_tensor(detail::get_logical_ptr(static_cast<SwappedElementA const*>(nullptr)), LayoutA{}),
SmemLayoutA{}(_, _, cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy(
GmemTiledCopyB{},
make_tensor(detail::get_logical_ptr(static_cast<SwappedElementB const*>(nullptr)), LayoutB{}),
SmemLayoutB{}(_, _, cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
using TMA_Scale = decltype(make_tma_copy<TmaElementScale>(
GmemTiledCopyScale{},
make_tensor(
detail::get_logical_ptr(static_cast<NonVoidElementScale const*>(nullptr)),
repeat_like(NonVoidStrideScale{}, int32_t(0)),
NonVoidStrideScale{}),
SmemLayoutScale{}(_, _, cute::Int<0>{}),
ScaleTileShape{},
_1{})); // mcast along N mode for this M load, if any. Scale is ALWAYS loaded with A for RF kernel
using TMA_Zero = decltype(make_tma_copy(
GmemTiledCopyScale{},
make_tensor(
detail::get_logical_ptr(static_cast<NonVoidElementZero const*>(nullptr)),
repeat_like(NonVoidStrideScale{}, int32_t(0)),
NonVoidStrideScale{}),
SmemLayoutScale{}(_, _, cute::Int<0>{}),
ScaleTileShape{},
_1{})); // mcast along N mode for this M load, if any. Scale is ALWAYS loaded with A for RF kernel
TMA_A tma_load_a;
TMA_B tma_load_b;
uint32_t tma_transaction_bytes = TmaTransactionBytes;
TMA_Scale tma_load_scale;
TMA_Zero tma_load_zero;
void* tensormaps;
SwappedElementA const** ptr_A;
SwappedStrideA ptr_dA;
SwappedElementB const** ptr_B;
SwappedStrideB ptr_dB;
NonVoidElementScale const** ptr_S;
NonVoidStrideScale const* dS;
NonVoidElementZero const** ptr_Z;
int64_t scale_k;
int chunk_size;
int reload_factor = (chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{});
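// reload_factor = ceil(chunk_size / K-tile depth): the number of K tiles that share one group of
// scales before the next group is needed. E.g. (illustrative numbers, assuming chunk_size is the
// quantization group size) chunk_size = 128 with a 64-deep K tile gives reload_factor = 2.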
InternalSwappedStrideA dA;
InternalSwappedStrideB dB;
};
//
// Methods
//
template <class ProblemShape>
static constexpr Params to_underlying_arguments(ProblemShape problem_shapes, Arguments const& args, void* workspace) {
// These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc.
// These will be replaced with correct values before the initial tma load.
auto init_shape = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1));
auto init_M = get<0>(init_shape);
auto init_N = get<1>(init_shape);
auto init_K = get<2>(init_shape);
if constexpr (SwapAB) {
init_M = get<1>(init_shape);
init_N = get<0>(init_shape);
}
// Batches/Groups are managed by using appropriate pointers to input matrices
const uint32_t mock_L = 1;
SwappedElementA const* ptr_A_first_batch;
SwappedElementB const* ptr_B_first_batch;
SwappedStrideA ptr_dA;
SwappedStrideB ptr_dB;
InternalSwappedStrideA dA;
InternalSwappedStrideB dB;
if constexpr (not SwapAB) {
ptr_A_first_batch = reinterpret_cast<SwappedElementA const*>(args.ptr_A);
ptr_B_first_batch = reinterpret_cast<SwappedElementB const*>(args.ptr_B);
} else {
ptr_A_first_batch = reinterpret_cast<SwappedElementA const*>(args.ptr_B);
ptr_B_first_batch = reinterpret_cast<SwappedElementB const*>(args.ptr_A);
}
if constexpr (IsGroupedGemmKernel) {
// Strides for Grouped Gemm will be replaced prior to the first access regardless.
if constexpr (not SwapAB) {
ptr_dA = args.dA;
ptr_dB = args.dB;
} else {
ptr_dA = args.dB;
ptr_dB = args.dA;
}
dA = InternalSwappedStrideA{};
if constexpr (is_layout<InternalSwappedStrideA>::value) {
dA = make_layout(
transform_leaf(
dA.shape(),
[](auto x) {
if constexpr (not is_static_v<decltype(x)>) {
return static_cast<decltype(x)>(1);
} else {
return x;
}
}),
dA.stride());
}
dB = InternalSwappedStrideB{};
} else {
// Tensor shapes for Ptr-Array are initialized correctly only here.
auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0);
init_M = get<0>(problem_shape_MNK);
init_N = get<1>(problem_shape_MNK);
init_K = get<2>(problem_shape_MNK);
if constexpr (not SwapAB) {
dA = args.dA;
dB = args.dB;
} else {
dA = args.dB;
dB = args.dA;
}
ptr_dA = SwappedStrideA{};
ptr_dB = SwappedStrideB{};
}
Tensor tensor_a = make_tensor(ptr_A_first_batch, detail::get_gmem_layout(make_shape(init_M, init_K, mock_L), dA));
Tensor tensor_b = make_tensor(ptr_B_first_batch, detail::get_gmem_layout(make_shape(init_N, init_K, mock_L), dB));
typename Params::TMA_A tma_load_a = make_tma_copy<TmaElementA>(
GmemTiledCopyA{},
tensor_a,
SmemLayoutA{}(_, _, cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
typename Params::TMA_B tma_load_b = make_tma_copy(
GmemTiledCopyB{},
tensor_b,
SmemLayoutB{}(_, _, cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
typename Params::TMA_Scale tma_load_scale{};
typename Params::TMA_Zero tma_load_zero{};
void* tensormaps = workspace;
auto args_setup =
[&](auto ptr_A, auto ptr_B, int64_t scale_k = 0, int chunk_size = 0, int reload_factor = 1) -> Params {
return {
tma_load_a,
tma_load_b,
TmaTransactionBytes,
tma_load_scale,
tma_load_zero,
tensormaps,
reinterpret_cast<SwappedElementA const**>(ptr_A),
ptr_dA,
reinterpret_cast<SwappedElementB const**>(ptr_B),
ptr_dB,
reinterpret_cast<NonVoidElementScale const**>(args.ptr_S),
args.dS,
reinterpret_cast<NonVoidElementZero const**>(args.ptr_Z),
scale_k,
chunk_size,
reload_factor,
dA,
dB};
};
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return SwapAB ? args_setup(args.ptr_B, args.ptr_A) : args_setup(args.ptr_A, args.ptr_B);
} else if constexpr (ModeHasScales) {
auto fake_scale_k = 1;
ElementScale const* ptr_S = reinterpret_cast<ElementScale const*>(args.ptr_S);
StrideScale dS{};
Tensor tensor_scale =
make_tensor(detail::get_logical_ptr(ptr_S), make_layout(make_shape(init_M, fake_scale_k, mock_L), dS));
tma_load_scale = make_tma_copy<TmaElementScale>(
GmemTiledCopyScale{},
tensor_scale,
SmemLayoutScale{}(_, _, cute::Int<0>{}),
ScaleTileShape{},
_1{}); // mcast along N mode for this M load, if any
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return SwapAB ? args_setup(
args.ptr_B,
args.ptr_A,
fake_scale_k,
args.chunk_size,
(args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}))
: args_setup(
args.ptr_A,
args.ptr_B,
fake_scale_k,
args.chunk_size,
(args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}));
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
ElementZero const* ptr_Z = reinterpret_cast<ElementZero const*>(args.ptr_Z);
Tensor tensor_zero =
make_tensor(detail::get_logical_ptr(ptr_Z), make_layout(make_shape(init_M, fake_scale_k, mock_L), dS));
tma_load_zero = make_tma_copy(
GmemTiledCopyScale{},
tensor_zero,
SmemLayoutScale{}(_, _, cute::Int<0>{}),
ScaleTileShape{},
_1{}); // mcast along N mode for this M load, if any
return SwapAB ? args_setup(
args.ptr_B,
args.ptr_A,
fake_scale_k,
args.chunk_size,
(args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}))
: args_setup(
args.ptr_A,
args.ptr_B,
fake_scale_k,
args.chunk_size,
(args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}));
} else {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled in to_underlying_arguments.");
}
} else {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in to_underlying_arguments.");
}
}
template <class ProblemShape>
static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) {
constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor);
// Calculating workspace size
auto calculate_workspace_size = [SizeOfCuTensorMap, sm_count](uint32_t num_input_tensors) {
return num_input_tensors * SizeOfCuTensorMap * sm_count;
};
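// Worked example (figures for illustration only): cute::TmaDescriptor wraps a 128-byte CUtensorMap, so for
// ConvertAndScaleWithZero on a 132-SM device this amounts to 4 * 128 B * 132 = 67,584 bytes of workspace.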
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// Allocate gmem space for input tensormaps per SM: A tensormap copies followed by B tensormap copies
return calculate_workspace_size(2);
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
// Allocate gmem space for input tensormaps per SM: A tensormap copies followed by B tensormap copies,
// followed by scale tensormap copies
return calculate_workspace_size(3);
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
// Allocate gmem space for input tensormaps per SM: A tensormap copies followed by B tensormap copies,
// followed by scale and zero tensormap copies
return calculate_workspace_size(4);
} else {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in get_workspace_size.");
}
}
template <class ProblemShape>
static cutlass::Status initialize_workspace(
ProblemShape const& problem_shape,
Arguments const& args,
void* workspace,
cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
template <class ProblemShape>
CUTLASS_HOST_DEVICE static bool can_implement(ProblemShape problem_shapes, Arguments const& args) {
constexpr int tma_alignment_bits = 128;
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
bool implementable = true;
if (problem_shapes.is_host_problem_shape_available()) {
// Check alignment for all problem sizes
for (int i = 0; i < problem_shapes.groups(); i++) {
auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1);
auto [M, N, K, L] = problem_shape_MNKL;
auto get_stride = [](auto stride) {
if constexpr (cute::is_pointer_v<cute::decay_t<decltype(stride)>>) {
return *stride;
} else {
return stride;
}
};
auto dA = get_stride(args.dA);
auto dB = get_stride(args.dB);
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(
detail::get_gmem_layout(cute::make_shape(M, K, L), dA));
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(
detail::get_gmem_layout(cute::make_shape(N, K, L), dB));
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
implementable = implementable && (args.ptr_S == nullptr);
implementable = implementable && (args.ptr_Z == nullptr);
} else if constexpr (ModeHasScales) {
const int scale_mn = SwapAB ? N : M;
const int scale_k = (K + args.chunk_size - 1) / args.chunk_size;
constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits<ElementScale>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_scale>(
cute::make_shape(scale_mn, scale_k, L), StrideScale{});
implementable = implementable && (args.chunk_size == K || ((args.chunk_size % size<2>(TileShape{})) == 0));
implementable = implementable && args.chunk_size != 0;
implementable = implementable && (args.ptr_S != nullptr);
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
implementable = implementable && (args.ptr_Z == nullptr);
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits<ElementZero>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_zero>(
cute::make_shape(scale_mn, scale_k, L), StrideScale{});
implementable = implementable && (args.ptr_Z != nullptr);
} else {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in can_implement.");
}
} else {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in can_implement.");
}
}
}
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
}
return implementable;
}
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
static constexpr int K_PIPE_MMAS = 1;
static constexpr uint32_t TmaTransactionBytesMK = Utils::compute_tma_transaction_bytes_mk();
static constexpr uint32_t TmaTransactionBytesNK = Utils::compute_tma_transaction_bytes_nk();
static constexpr uint32_t TmaTransactionBytesExtra = Utils::compute_tma_transaction_bytes_extra();
static constexpr uint32_t TmaTransactionBytes =
TmaTransactionBytesMK + TmaTransactionBytesNK + TmaTransactionBytesExtra;
// Set up the data needed by this collective for load and mma.
// Returns a tuple of tensors. The collective and the kernel layer have the contract that the
// returned tuple must contain at least two elements, with the first two elements being:
// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
// The rest of the tensors can be specified as needed by this collective.
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
using X = Underscore;
// Separate out problem shape for convenience
auto [M, N, K, L] = problem_shape_MNKL;
const int32_t mock_L = 1;
// TMA requires special handling of strides to deal with coord codomain mapping
// Represent the full tensors -- get these from TMA
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(
shape(detail::get_gmem_layout(make_shape(M, K, mock_L), mainloop_params.dA))); // (m,k,l)
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(
shape(detail::get_gmem_layout(make_shape(N, K, mock_L), mainloop_params.dB))); // (n,k,l)
// Make tiled views, defer the slice
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return cute::make_tuple(gA_mkl, gB_nkl);
} else if constexpr (ModeHasScales) {
// scale_k is derived from the compile-time GROUP_SIZE rather than the runtime chunk_size; the
// chunk_size-based form is kept below for reference.
// auto scale_k = K / mainloop_params.chunk_size;
auto scale_k = K / GROUP_SIZE;
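// e.g. with K = 7168 and GROUP_SIZE = 128 (values for illustration only) this yields 56 scale groups along K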
Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M, scale_k, L)); // (m,scale_k,l)
Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_, _)); // (BLK_M,BLK_Scale_K,m,scale_k,l)
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl);
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(make_shape(M, scale_k, L)); // (m,scale_k,l)
Tensor gZ_mkl = local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_, _)); // (BLK_M,BLK_Scale_K,m,scale_k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl, gZ_mkl);
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in load_init.");
}
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in load_init.");
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Perform a collective-scoped matrix multiply-accumulate
// Producer Perspective
template <class... Ts, class... TMs, class KTileIterator, class BlockCoord>
CUTLASS_DEVICE void load(
Params const& mainloop_params,
MainloopPipeline pipeline,
PipelineState smem_pipe_write,
cute::tuple<Ts...> const& load_inputs,
cute::tuple<TMs...> const& input_tensormaps,
BlockCoord const& blk_coord,
KTileIterator k_tile_iter,
int k_tile_count,
int thread_idx,
uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors) {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
static_assert(sizeof...(Ts) == 2, "Direct convert needs two inputs");
static_assert(sizeof...(TMs) == 2, "Direct convert needs two tensormaps");
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
static_assert(sizeof...(Ts) == 3, "Scaled convert needs three inputs");
static_assert(sizeof...(TMs) == 3, "Scaled convert needs three tensormaps");
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
static_assert(sizeof...(Ts) == 4, "Scaled and zero convert needs four inputs");
static_assert(sizeof...(TMs) == 4, "Scaled and zero convert needs four tensormaps");
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in TMA load.");
}
Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE)
Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE)
//
// Prepare the TMA loads for A and B
//
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
uint16_t mcast_mask_s = 0;
// Issue TmaLoads
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n) {
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{}));
}
}
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{}));
}
}
auto extra_input_partitions = Utils::partition_extra_tma_inputs(
mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord);
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count) {
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Copy gmem to smem for *k_tile_iter
//
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
int write_stage = smem_pipe_write.index();
if (cute::elect_one_sync()) {
copy(
mainloop_params.tma_load_a.with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a),
tAgA(_, _, _, *k_tile_iter),
tAsA(_, _, _, write_stage));
copy(
mainloop_params.tma_load_b.with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b),
tBgB(_, _, _, *k_tile_iter),
tBsB(_, _, _, write_stage));
}
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// Nothing extra to do.
} else if constexpr (ModeHasScales) {
// scale copy
auto tSgS = get<0>(extra_input_partitions);
auto tSsS = get<1>(extra_input_partitions);
// Temporary factor that determines which k tile of scales to reload from gmem. Kept separate so we don't
// modify the TMA transaction bytes on the fly. A ceiling divide is needed to handle chunk_size == K
// correctly; in that case K is not required to be a multiple of the threadblock tile K.
const int scale_load_k = *k_tile_iter / 1;  // divisor fixed to 1 here, so scales are reloaded for every k tile
// const int scale_load_k = *k_tile_iter / mainloop_params.reload_factor; // This would always be 0 when
// chunk_size == K.
if (cute::elect_one_sync()) {
copy(
mainloop_params.tma_load_scale.with(get<2>(input_tensormaps), *tma_barrier, mcast_mask_s),
tSgS(_, _, _, scale_load_k),
tSsS(_, _, _, write_stage));
}
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
// Nothing extra to do
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
// zero copy
auto tZgZ = get<2>(extra_input_partitions);
auto tZsZ = get<3>(extra_input_partitions);
if (cute::elect_one_sync()) {
copy(
mainloop_params.tma_load_zero.with(get<3>(input_tensormaps), *tma_barrier, mcast_mask_s),
tZgZ(_, _, _, scale_load_k),
tZsZ(_, _, _, write_stage));
}
} else {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for TMA copy op.");
}
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for TMA copy op.");
}
++k_tile_iter;
// Advance smem_pipe_write
++smem_pipe_write;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (lane_predicate) {
// This helps avoid early exit of blocks in Cluster.
// Waits for all stages to either be released (all
// Consumer UNLOCKs), or if the stage was never used
// then it would just be acquired since the phase was
// still inverted from make_producer_start_state.
pipeline.producer_tail(smem_pipe_write);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
template <class FrgTensorC>
CUTLASS_DEVICE void
mma(MainloopPipeline pipeline,
PipelineState smem_pipe_read,
FrgTensorC& accum,
int k_tile_count,
int thread_idx,
TensorStorage& shared_tensors,
Params const& mainloop_params) {
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SwappedSmemLayoutAtomA{}) == 2, "SwappedSmemLayoutAtomA must be rank 2.");
static_assert(cute::rank(SwappedSmemLayoutAtomB{}) == 2, "SwappedSmemLayoutAtomB must be rank 2.");
static_assert(
!cute::is_void_v<SwappedSmemCopyAtomA>,
"SM90 GMMA mainloops with RF-sourced A must specify a non-void copy atom for the smem->RF load of A.");
static_assert(
cute::is_void_v<SwappedSmemCopyAtomB>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem-sourced B instructions.");
// Obtain warp index
int warp_idx = canonical_warp_idx_sync();
[[maybe_unused]] int warp_group_thread_idx = thread_idx % 128;
Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
//
// Define C accumulators and A/B partitioning
//
// Layout of warp group to thread mapping
static_assert(
stride<0>(typename TiledMma::BLayout{}) == 0 and
size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup,
"Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup;
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{}, Int<NumThreadsPerWarpGroup>{});
int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
TiledMma tiled_mma;
auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx);
Tensor tCsA = mma_thread_slice.partition_A(sA);
auto mma_warpgroup_slice = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
// Allocate fragments and descriptors
Tensor tCrA_mma = mma_thread_slice.partition_fragment_A(sA(_, _, Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrA_load = [&] {
if constexpr (not is_layout<InternalSwappedStrideA>::value) {
// Make register tensor with MMA layout
return make_fragment_like<RealSwappedElementA>(tCrA_mma);
} else {
// Make register tensor matching smem layout, converter will take care of de-swizzling
return make_tensor_like<RealSwappedElementA>(tCsA(_, _, _, Int<0>{}));
}
}();
Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
//
// Copy Atom A retiling
//
auto smem_tiled_copy_A = make_tiled_copy_A(SwappedSmemCopyAtomA{}, tiled_mma);
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(warp_group_thread_idx);
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA_load); // (CPY,CPY_M,CPY_K)
// Partition of thread -> shared and thread -> RF
auto partitioned_extra_info = Utils::partition_extra_mma_info(mma_thread_slice, shared_tensors);
auto copy_partitions_extra_info =
Utils::retile_extra_mma_info(tiled_mma, partitioned_extra_info, warp_group_thread_idx);
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K
CUTE_STATIC_ASSERT_V(size<1>(tCrA_mma) == size<1>(accum)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
//
// PIPELINED MAIN LOOP
//
// We release buffers to producer warps (DMA load) with some MMAs in flight
PipelineState smem_pipe_release = smem_pipe_read;
multiply_add<ElementAccumulator> fma;
constexpr int NumMMAsPerChunk = GROUP_SIZE / cute::get<0, 1>(tCsB.shape())();
constexpr int NumChunksPerTileK = cute::size<1>(sA.shape())() / GROUP_SIZE;
cute::array<decltype(make_fragment_like(accum)), NumChunksPerTileK> intermediate_array;
constexpr int K_BLOCK_MAX = size<2>(tCrA_load);
constexpr int K_WAIT_MAX = cute::min(K_BLOCK_MAX - 1, 7);
static_assert(K_BLOCK_MAX >= 4, "Consider increasing TileShapeK");
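// Illustration (assumed numbers, not guaranteed by this file): with a 512-wide K tile, GROUP_SIZE = 128 and a
// GMMA atom covering 32 elements of K, NumChunksPerTileK = 4, NumMMAsPerChunk = 4 and K_BLOCK_MAX = 16, i.e.
// each chunk's partial sums are accumulated over 4 MMAs and then scaled once.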
ConsumerToken barrier_token = {BarrierStatus::WaitAgain};
// First k tile
{
barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
int read_stage = smem_pipe_read.index();
++smem_pipe_read;
barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
// copy smem->rmem for A operand
Utils::copy_tensors_MK(
smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 0, read_stage);
if (K_BLOCK_MAX > 1) {
Utils::copy_tensors_MK(
smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 1, read_stage);
}
// src: tCrA_load, dst: tCrA_mma
Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0);
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int chunk_id = 0; chunk_id < NumChunksPerTileK; ++chunk_id) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int mma_id = 0; mma_id < NumMMAsPerChunk; ++mma_id) {
int k_block = chunk_id * NumMMAsPerChunk + mma_id;
warpgroup_arrive();
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), tCrB(_, _, k_block, read_stage), intermediate_array[chunk_id]);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
warpgroup_commit_batch();
if (k_block < K_BLOCK_MAX - 2) {
Utils::copy_tensors_MK(
smem_tiled_copy_A,
tCsA,
tCrA_copy_view,
partitioned_extra_info,
copy_partitions_extra_info,
k_block + 2,
read_stage);
}
if (k_block < K_BLOCK_MAX - 1) {
Utils::convert_A_kblock(tCrA_load, tCrA_mma, k_block + 1);
}
}
}
CUTLASS_PRAGMA_UNROLL
for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_) {
warpgroup_fence_operand(intermediate_array[chunk_id_]);
// Apply the group-wise scaling
// tCrS ((4, _2, _2), MMA_M, _1)
// accum ((2, _2, _2), MMA_M, _1)
auto tCrS = cute::get<1>(partitioned_extra_info);
for (int mma_m = 0; mma_m < size<1>(accum); mma_m++) {
for (int m = 0; m < size<0, 1>(accum); m++) {
for (int n = 0; n < size<0, 2>(accum); n++) {
for (int e = 0; e < size<0, 0>(accum); e++) {
auto accum_coord = make_coord(make_tuple(e, m, n), mma_m, 0);
auto scale_coord = make_coord(make_tuple(0, m, 0), mma_m, 0);
if (chunk_id_ == 0) {
accum(accum_coord) =
intermediate_array[chunk_id_](accum_coord) * static_cast<float>(tCrS(scale_coord)[0]);
} else {
accum(accum_coord) =
fma(intermediate_array[chunk_id_](accum_coord),
static_cast<float>(tCrS(scale_coord)[chunk_id_]),
accum(accum_coord));
}
}
}
}
}
}
--k_tile_count;
if (k_tile_count > 0) {
// Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to overwrite the A registers for the first
// mma.
pipeline.consumer_wait(smem_pipe_read, barrier_token);
Utils::copy_tensors_MK(
smem_tiled_copy_A,
tCsA,
tCrA_copy_view,
partitioned_extra_info,
copy_partitions_extra_info,
0,
smem_pipe_read.index());
Utils::copy_tensors_MK(
smem_tiled_copy_A,
tCsA,
tCrA_copy_view,
partitioned_extra_info,
copy_partitions_extra_info,
1,
smem_pipe_read.index());
warpgroup_wait<K_WAIT_MAX>();
Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0);
}
}
if (k_tile_count == 0) {
return;
}
// Mainloop GMMAs
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 1; --k_tile_count) {
//
// Compute on k_tile
//
int read_stage = smem_pipe_read.index();
++smem_pipe_read;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int chunk_id = 0; chunk_id < NumChunksPerTileK; ++chunk_id) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int mma_id = 0; mma_id < NumMMAsPerChunk; ++mma_id) {
int k_block = chunk_id * NumMMAsPerChunk + mma_id;
warpgroup_arrive();
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), tCrB(_, _, k_block, read_stage), intermediate_array[chunk_id]);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
warpgroup_commit_batch();
warpgroup_wait<K_WAIT_MAX>(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, so we can
// release prior barrier
if (k_block == K_BLOCK_MAX - 1) {
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
++smem_pipe_release;
}
if (k_block == 0) {
barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
}
if (k_block == K_BLOCK_MAX - 1) {
// The last k_block
CUTLASS_PRAGMA_UNROLL
for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_) {
warpgroup_fence_operand(intermediate_array[chunk_id_]);
// Apply the group-wise scaling
auto tCrS = cute::get<1>(partitioned_extra_info);
for (int mma_m = 0; mma_m < size<1>(accum); mma_m++) {
for (int m = 0; m < size<0, 1>(accum); m++) {
for (int n = 0; n < size<0, 2>(accum); n++) {
for (int e = 0; e < size<0, 0>(accum); e++) {
auto accum_coord = make_coord(make_tuple(e, m, n), mma_m, 0);
auto scale_coord = make_coord(make_tuple(0, m, 0), mma_m, 0);
accum(accum_coord) =
fma(intermediate_array[chunk_id_](accum_coord),
static_cast<float>(tCrS(scale_coord)[chunk_id_]),
accum(accum_coord));
}
}
}
}
}
pipeline.consumer_wait(smem_pipe_read, barrier_token);
// copy scales when passing k_block=0
Utils::copy_tensors_MK(
smem_tiled_copy_A,
tCsA,
tCrA_copy_view,
partitioned_extra_info,
copy_partitions_extra_info,
0,
smem_pipe_read.index());
Utils::copy_tensors_MK(
smem_tiled_copy_A,
tCsA,
tCrA_copy_view,
partitioned_extra_info,
copy_partitions_extra_info,
1,
smem_pipe_read.index());
Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0);
} else {
if (k_block < K_BLOCK_MAX - 2) {
Utils::copy_tensors_MK(
smem_tiled_copy_A,
tCsA,
tCrA_copy_view,
partitioned_extra_info,
copy_partitions_extra_info,
k_block + 2,
read_stage);
}
Utils::convert_A_kblock(tCrA_load, tCrA_mma, k_block + 1);
}
}
}
}
{
//
// Last k tile
//
Tensor intermediate = make_fragment_like(accum);
int read_stage = smem_pipe_read.index();
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
warpgroup_arrive();
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), tCrB(_, _, k_block, read_stage), intermediate);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
warpgroup_commit_batch();
warpgroup_wait<K_WAIT_MAX>();
if (k_block == K_BLOCK_MAX - 1) {
// release prior barrier
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
++smem_pipe_release;
}
if (k_block < K_BLOCK_MAX - 2) {
Utils::copy_tensors_MK(
smem_tiled_copy_A,
tCsA,
tCrA_copy_view,
partitioned_extra_info,
copy_partitions_extra_info,
k_block + 2,
read_stage);
}
if (k_block < K_BLOCK_MAX - 1) {
Utils::convert_A_kblock(tCrA_load, tCrA_mma, k_block + 1);
}
if ((k_block + 1) % NumMMAsPerChunk == 0) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
warpgroup_fence_operand(intermediate);
// Apply the group-wise scaling
auto tCrS = cute::get<1>(partitioned_extra_info);
for (int mma_m = 0; mma_m < size<1>(accum); mma_m++) {
for (int m = 0; m < size<0, 1>(accum); m++) {
for (int n = 0; n < size<0, 2>(accum); n++) {
for (int e = 0; e < size<0, 0>(accum); e++) {
auto accum_coord = make_coord(make_tuple(e, m, n), mma_m, 0);
auto scale_coord = make_coord(make_tuple(0, m, 0), mma_m, 0);
int scale_idx = k_block / NumMMAsPerChunk;
accum(accum_coord) = fma(
intermediate(accum_coord), static_cast<float>(tCrS(scale_coord)[scale_idx]), accum(accum_coord));
}
}
}
}
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// Perform a Consumer Epilogue to release all buffers
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
// Prologue GMMAs
int prologue_mma_count = 1;
k_tile_count -= prologue_mma_count;
smem_pipe_release.advance(k_tile_count);
// Wait on all GMMAs to complete
warpgroup_wait<0>();
for (int count = 0; count < prologue_mma_count; ++count) {
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
++smem_pipe_release;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Methods to perform different parts of TMA/Tensormap modifications
//
CUTLASS_DEVICE auto tensormaps_init(
Params const& mainloop_params, TensorMapStorage& shared_tensormaps, int32_t sm_count, int32_t sm_idx) {
cute::TmaDescriptor* gmem_tensormap = reinterpret_cast<cute::TmaDescriptor*>(mainloop_params.tensormaps);
cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx];
cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count];
cute::TmaDescriptor* tma_desc_scale = &gmem_tensormap[sm_idx + 2 * sm_count];
cute::TmaDescriptor* tma_desc_zero = &gmem_tensormap[sm_idx + 3 * sm_count];
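// The workspace sized in get_workspace_size is indexed here as four consecutive arrays of sm_count
// descriptors each: A tensormaps, then B, then scale, then zero (the last two are only touched when the
// conversion mode requires them).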
// Bringing tensormaps from params to smem for modification later
Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{});
Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{});
Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{});
Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{});
if (cute::elect_one_sync()) {
copy(recast<uint128_t>(pA_tensormap), recast<uint128_t>(sA_tensormap));
copy(recast<uint128_t>(pB_tensormap), recast<uint128_t>(sB_tensormap));
}
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
Tensor pS_tensormap = make_tensor(mainloop_params.tma_load_scale.get_tma_descriptor(), Int<1>{}, Int<1>{});
Tensor sS_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_scale), Int<1>{}, Int<1>{});
if (cute::elect_one_sync()) {
copy(recast<uint128_t>(pS_tensormap), recast<uint128_t>(sS_tensormap));
}
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
Tensor pZ_tensormap = make_tensor(mainloop_params.tma_load_zero.get_tma_descriptor(), Int<1>{}, Int<1>{});
Tensor sZ_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_zero), Int<1>{}, Int<1>{});
if (cute::elect_one_sync()) {
copy(recast<uint128_t>(pZ_tensormap), recast<uint128_t>(sZ_tensormap));
}
} else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in tensormaps_init.");
}
__syncwarp();
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return cute::make_tuple(tma_desc_a, tma_desc_b);
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return cute::make_tuple(tma_desc_a, tma_desc_b, tma_desc_scale);
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
return cute::make_tuple(tma_desc_a, tma_desc_b, tma_desc_scale, tma_desc_zero);
} else {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in tensormaps_init.");
}
}
// Replace address for the global tensor (to be done by single thread)
CUTLASS_DEVICE
void tensormaps_replace_global_address(
TensorMapStorage& shared_tensormaps, Params const& mainloop_params, int32_t next_batch) {
// Replacing global_address for the next batch
cute::tma_descriptor_replace_addr_in_shared_mem(
shared_tensormaps.smem_tensormap_A, mainloop_params.ptr_A[next_batch]);
cute::tma_descriptor_replace_addr_in_shared_mem(
shared_tensormaps.smem_tensormap_B, mainloop_params.ptr_B[next_batch]);
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
cute::tma_descriptor_replace_addr_in_shared_mem(
shared_tensormaps.smem_tensormap_scale, mainloop_params.ptr_S[next_batch]);
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
cute::tma_descriptor_replace_addr_in_shared_mem(
shared_tensormaps.smem_tensormap_zero, mainloop_params.ptr_Z[next_batch]);
} else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled in tensormaps_replace_global_address.");
}
}
// Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread)
template <class ProblemShape_MNKL>
CUTLASS_DEVICE void tensormaps_replace_global_tensor_properties(
TensorMapStorage& shared_tensormaps,
Params const& mainloop_params,
int32_t next_group,
ProblemShape_MNKL problem_shape_mnkl) {
const uint32_t M = get<0>(problem_shape_mnkl);
const uint32_t N = get<1>(problem_shape_mnkl);
const uint32_t K = get<2>(problem_shape_mnkl);
// Replace all dims for consistency
constexpr int MaxTensorRank = 5;
cute::array<uint32_t, MaxTensorRank> prob_shape_A = {1, 1, 1, 1, 1};
cute::array<uint64_t, MaxTensorRank> prob_stride_A = {0, 0, 0, 0, 0};
cute::array<uint32_t, MaxTensorRank> prob_shape_B = {1, 1, 1, 1, 1};
cute::array<uint64_t, MaxTensorRank> prob_stride_B = {0, 0, 0, 0, 0};
cute::array<uint32_t, MaxTensorRank> prob_shape_scale = {1, 1, 1, 1, 1};
cute::array<uint64_t, MaxTensorRank> prob_stride_scale = {0, 0, 0, 0, 0};
cute::array<uint32_t, MaxTensorRank> prob_shape_zero = {1, 1, 1, 1, 1};
cute::array<uint64_t, MaxTensorRank> prob_stride_zero = {0, 0, 0, 0, 0};
SwappedElementA const* ptr_A = nullptr;
Tensor tensor_a =
make_tensor(ptr_A, detail::get_gmem_layout(make_shape(M, K, Int<1>{}), mainloop_params.ptr_dA[next_group]));
SwappedElementB const* ptr_B = nullptr;
Tensor tensor_b =
make_tensor(ptr_B, detail::get_gmem_layout(make_shape(N, K, Int<1>{}), mainloop_params.ptr_dB[next_group]));
cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, prob_shape_A, prob_stride_A);
cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, prob_shape_B, prob_stride_B);
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
NonVoidElementScale const* ptr_S = nullptr;
// auto scale_k = K / mainloop_params.chunk_size;
auto scale_k = K / GROUP_SIZE;
Tensor tensor_scale =
make_tensor(detail::get_logical_ptr(ptr_S), make_shape(M, scale_k, Int<1>{}), mainloop_params.dS[next_group]);
cute::detail::fill_tma_gmem_shape_stride(
mainloop_params.tma_load_scale, tensor_scale, prob_shape_scale, prob_stride_scale);
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
ElementZero const* ptr_Z = nullptr;
// auto scale_k = K / mainloop_params.chunk_size;
auto scale_k = K / GROUP_SIZE;
Tensor tensor_zero =
make_tensor(detail::get_logical_ptr(ptr_Z), make_shape(M, scale_k, Int<1>{}), mainloop_params.dS[next_group]);
cute::detail::fill_tma_gmem_shape_stride(
mainloop_params.tma_load_zero, tensor_zero, prob_shape_zero, prob_stride_zero);
} else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled in tensormaps_replace_global_tensor_properties.");
}
// Convert strides to byte strides
for (uint64_t& stride : prob_stride_A) {
stride = (stride * sizeof_bits_v<SwappedElementA>) / 8;
}
for (uint64_t& stride : prob_stride_B) {
stride = (stride * sizeof_bits_v<SwappedElementB>) / 8;
}
for (uint64_t& stride : prob_stride_scale) {
stride = (stride * sizeof_bits_v<NonVoidElementScale>) / 8;
}
for (uint64_t& stride : prob_stride_zero) {
stride = (stride * sizeof_bits_v<NonVoidElementZero>) / 8;
}
cute::tma_descriptor_replace_dims_strides_in_shared_mem(
shared_tensormaps.smem_tensormap_A, prob_shape_A, prob_stride_A);
cute::tma_descriptor_replace_dims_strides_in_shared_mem(
shared_tensormaps.smem_tensormap_B, prob_shape_B, prob_stride_B);
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
cute::tma_descriptor_replace_dims_strides_in_shared_mem(
shared_tensormaps.smem_tensormap_scale, prob_shape_scale, prob_stride_scale);
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
cute::tma_descriptor_replace_dims_strides_in_shared_mem(
shared_tensormaps.smem_tensormap_zero, prob_shape_zero, prob_stride_zero);
} else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled in tensormaps_replace_global_tensor_properties.");
}
}
template <class... TMs, class ProblemShape_MNKL>
CUTLASS_DEVICE void tensormaps_perform_update(
TensorMapStorage& shared_tensormaps,
Params const& mainloop_params,
cute::tuple<TMs...> const& input_tensormaps,
ProblemShape_MNKL problem_shape_mnkl,
int32_t next_batch) {
if (cute::elect_one_sync()) {
// Replacing global_address for the next batch
tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch);
if constexpr (IsGroupedGemmKernel) {
// Replacing global dims and strides for the next batch
tensormaps_replace_global_tensor_properties(shared_tensormaps, mainloop_params, next_batch, problem_shape_mnkl);
}
}
}
template <class... TMs>
CUTLASS_DEVICE void
tensormaps_cp_fence_release(TensorMapStorage& shared_tensormaps, cute::tuple<TMs...> const& input_tensormaps) {
// Entire warp must do this (i.e. it's aligned)
tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A);
tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B);
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
tma_descriptor_cp_fence_release(get<2>(input_tensormaps), shared_tensormaps.smem_tensormap_scale);
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
tma_descriptor_cp_fence_release(get<3>(input_tensormaps), shared_tensormaps.smem_tensormap_zero);
} else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled in tensormaps_cp_fence_release.");
}
}
// The entire warp must call this function collectively (that is, the instructions are aligned)
template <class... TMs>
CUTLASS_DEVICE void tensormaps_fence_acquire(cute::tuple<TMs...> const& input_tensormaps) {
cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps));
cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps));
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
cute::tma_descriptor_fence_acquire(get<2>(input_tensormaps));
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
cute::tma_descriptor_fence_acquire(get<3>(input_tensormaps));
} else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) {
static_assert(
cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in tensormaps_fence_acquire.");
}
}
template <class InputTensors, class ProblemShape_MNKL>
CUTLASS_DEVICE InputTensors tensors_perform_update(
InputTensors const& input_tensors,
[[maybe_unused]] Params const& mainloop_params,
[[maybe_unused]] ProblemShape_MNKL problem_shape_mnkl,
[[maybe_unused]] int32_t next_batch) {
return input_tensors;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include <c10/cuda/CUDAGuard.h>
#include <cudaTypedefs.h>
#include <torch/all.h>
int32_t get_sm_version_num() {
int32_t major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, 0);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, 0);
int32_t version_num = major_capability * 10 + minor_capability;
return version_num;
}
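// A minimal sketch (illustrative only, not part of the original file) of how get_sm_version_num() could gate
// the SM90-only kernels below; the helper name is hypothetical.
[[maybe_unused]] static void assert_w4a8_moe_supported() {
  int32_t const version_num = get_sm_version_num();
  TORCH_CHECK(version_num >= 90, "cutlass_w4a8_moe_mm requires an SM90 (Hopper) GPU, got sm_", version_num);
}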
void cutlass_w4a8_moe_mm_sm90(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk);
void get_cutlass_w4a8_moe_mm_data_caller(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k);
void cutlass_w4a8_moe_mm(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk) {
cutlass_w4a8_moe_mm_sm90(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size,
topk);
return;
}
void get_cutlass_w4a8_moe_mm_data(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k) {
get_cutlass_w4a8_moe_mm_data_caller(
topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
input_permutation,
output_permutation,
num_experts,
n,
k);
return;
}
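// Expected call order for one MoE layer (a sketch of the contract between the two entry points above; the
// exact tensor shapes are defined by the caller and are not restated here):
//   1. get_cutlass_w4a8_moe_mm_data(topk_ids, ...) fills expert_offsets, the per-expert problem sizes for the
//      two grouped GEMMs, and the input/output permutations that group tokens by expert.
//   2. The caller gathers the fp8 activations with input_permutation, invokes cutlass_w4a8_moe_mm for the
//      first grouped GEMM, applies the activation, runs the second grouped GEMM, then scatters the result
//      back with output_permutation.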
#pragma once
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <torch/all.h>
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"
template <typename ElementA, typename ElementB, typename ElementC, typename ElementAccumulator>
__global__ void int4_fp8_get_group_gemm_starts(
int32_t* expert_offsets,
ElementA** a_offsets,
ElementB** b_offsets,
ElementC** out_offsets,
ElementAccumulator** a_scales_offsets,
cutlass::bfloat16_t** b_scales_offsets,
ElementA* a_base_as_int,
ElementB* b_base_as_int,
ElementC* out_base_as_int,
ElementAccumulator* a_scales_base_as_int,
cutlass::bfloat16_t* b_scales_base_as_int,
int64_t n,
int64_t k,
bool per_act_token,
bool per_out_ch) {
int expert_id = threadIdx.x;  // one thread per expert
int32_t expert_offset = expert_offsets[expert_id];  // first row of this expert's tokens in A/D
a_offsets[expert_id] = a_base_as_int + expert_offset * k;
b_offsets[expert_id] = b_base_as_int + expert_id * k * n / 2;  // int4 weights, two values packed per int8
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
a_scales_offsets[expert_id] = a_scales_base_as_int + (per_act_token ? expert_offset : 0);
// b_scales is laid out as [E, K/512, N*4], so each expert advances by n * 4 * k / 512 scale elements
b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * 4 * k / 512 : expert_id);
}
#define __CALL_W4A8_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
int4_fp8_get_group_gemm_starts<cutlass::float_e4m3_t, cutlass::int8_t, C_TYPE, float> \
<<<1, num_experts, 0, stream>>>( \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
static_cast<cutlass::int8_t**>(b_ptrs.data_ptr()), \
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
static_cast<float**>(a_scales_ptrs.data_ptr()), \
static_cast<cutlass::bfloat16_t**>(b_scales_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
static_cast<cutlass::int8_t*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<float*>(a_scales.data_ptr()), \
static_cast<cutlass::bfloat16_t*>(b_scales.data_ptr()), \
out_tensors.size(1), \
a_tensors.size(1), \
per_act_token, \
per_out_ch); \
}
namespace {
void run_int4_fp8_get_group_gemm_starts(
torch::Tensor const& expert_offsets,
torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs,
torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs,
torch::Tensor& b_scales_ptrs,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor& out_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b_tensors.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kBFloat16);
int num_experts = static_cast<int>(expert_offsets.size(0));
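// per_act_token: one activation scale per token row (otherwise a single tensor-wide scale);
// per_out_ch: full per-group weight scales (otherwise a single scale per expert).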
bool per_act_token = a_scales.numel() != 1;
bool per_out_ch = b_scales.numel() != num_experts;
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
// Dispatch on output dtype; each __CALL_W4A8_GET_STARTS_KERNEL below expands to an `else if` branch.
if (false) {
}
__CALL_W4A8_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
__CALL_W4A8_GET_STARTS_KERNEL(torch::kFloat16, half)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}
} // namespace
#include <c10/cuda/CUDAGuard.h>
#include <cudaTypedefs.h>
#include <torch/all.h>
#include "cutlass/cutlass.h"
#include "w4a8_grouped_mm_c3x.cuh"
using namespace cute;
namespace {
#define JOIN_STRUCT_NAME(m, n, k, a, b, c) sm90_fp8_config##_##m##_##n##_##k##_##a##_##b##_##c
#define JOIN_STRUCT_NAME_CO(m, n, k, a, b, c) sm90_fp8_co_config##_##m##_##n##_##k##_##a##_##b##_##c
#define GENERATE_SM90_W4A8_PP_CONFIG(M, N, K, A, B, C) \
struct JOIN_STRUCT_NAME(M, N, K, A, B, C) { \
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; \
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; \
using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
\
using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>; \
};
#define GENERATE_SM90_W4A8_CO_CONFIG(M, N, K, A, B, C) \
struct JOIN_STRUCT_NAME_CO(M, N, K, A, B, C) { \
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; \
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \
using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
\
using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>; \
};
GENERATE_SM90_W4A8_PP_CONFIG(64, 16, 512, 1, 1, 1)
GENERATE_SM90_W4A8_PP_CONFIG(64, 32, 512, 2, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 1, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 2, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 1, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 2, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 64, 512, 1, 1, 1)
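// Naming convention: JOIN_STRUCT_NAME(64, 32, 512, 2, 1, 1) expands to sm90_fp8_config_64_32_512_2_1_1 with a
// 64x32x512 tile and a 2x1x1 cluster; the _CO variants use the cooperative kernel/epilogue schedules.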
void dispatch_w4a8_moe_mm_sm90(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk) {
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
uint32_t const m = a_tensors.size(0) / topk;
uint32_t const n = d_tensors.size(1);
uint32_t const k = a_tensors.size(1);
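// m approximates the per-layer token count (total A rows divided by topk) and is only used to pick a tile
// configuration below; n and k identify which grouped GEMM of the MoE layer is being dispatched.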
if (n == 4096 && k == 7168) {
// group gemm 1
if (m <= 4) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 16) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 256) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 1024) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
}
} else if (n == 7168 && k == 2048) {
// group gemm 2
if (m <= 8) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 512) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
}
} else {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
}
}
} // namespace
void cutlass_w4a8_moe_mm_sm90(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk) {
dispatch_w4a8_moe_mm_sm90(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size,
topk);
}
#pragma once
/**
* @file w4a8_grouped_mm_c3x.cuh
* @brief Implementation of grouped GEMM operation with int4 and fp8 mixed
* precision
*
* This file implements a grouped GEMM operation that multiplies FP8 matrices
* (A) with quantized INT4 matrices (B), applying per-block scaling factors.
* The implementation is optimized for NVIDIA Hopper GPUs, leveraging Tensor
* Cores for mixed precision arithmetic.
*
* Key features:
* - Supports grouped GEMM operations with multiple experts
* - Uses FP8 (e4m3) for matrix A
* - Uses INT4 quantization for matrix B with per-block scaling
* - Implements preprocessing for INT4 encoding and scale packing
* - Optimized for Hopper architecture with Tensor Core operations
*/
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder_mixed_input.hpp"
#include "w4a8_get_group_starts.cuh"
using namespace cute;
namespace {
// Type definitions
using MmaType = cutlass::float_e4m3_t; // FP8 e4m3 type
using QuantType = cutlass::int4b_t; // 4-bit integer type
using ElementAccumulator = float; // Accumulator type
using ElementScale = cutlass::bfloat16_t; // Scale type
using ElementScalePacked = cutlass::Array<ElementScale, 4>;  // 4 scales packed together (matches the N*4 layout of b_scales)
using ElementC = cutlass::half_t; // Default output type (FP16)
using ElementD = ElementC; // Default output type (FP16)
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
// Architecture-specific configurations
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
// constexpr int TileShapeK = 512;
// using TileShape = Shape<_128, _32, cute::Int<TileShapeK>>;
// using ClusterShape = Shape<_1, _1, _1>;
// Layout configurations
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;
// Transposed layouts
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
using LayoutC_Transpose = typename cutlass::layout::LayoutTranspose<LayoutC>::type;
using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
// Alignments
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<MmaType>::value;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<QuantType>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
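// With 128-bit TMA alignment these evaluate to 16 fp8 elements for A, 32 int4 elements for B and 8 half
// elements for C/D.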
template <typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule>
struct cutlass_3x_w4a8_group_gemm {
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
TileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementAccumulator,
ElementC,
LayoutC_Transpose*,
AlignmentC,
ElementD,
LayoutD_Transpose*,
AlignmentD,
EpilogueSchedule>::CollectiveOp;
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilderMixedInput<
ArchTag,
OperatorClass,
cute::tuple<QuantType, ElementScalePacked>,
LayoutB_Transpose*,
AlignmentB,
MmaType,
LayoutA_Transpose*,
AlignmentA,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
// Define the final kernel and GEMM operation types
using GemmKernelScaleOnly =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloopScaleOnly, CollectiveEpilogue>;
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
using StrideA = cute::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutA*>>;
using StrideB = cute::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutB*>>;
using StrideC = typename GemmKernelScaleOnly::InternalStrideC;
using StrideD = typename GemmKernelScaleOnly::InternalStrideD;
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
};
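// Illustrative (hypothetical) instantiation of the configuration struct above.
// The tile/cluster shapes and schedules actually used are chosen by the
// dispatch logic in the accompanying .cu file; the aliases below only sketch
// how the template parameters fit together and are kept commented out:
// using ExampleTileShape = Shape<_128, _16, cute::Int<512>>;
// using ExampleClusterShape = Shape<_1, _1, _1>;
// using ExampleGemm = cutlass_3x_w4a8_group_gemm<
//     ExampleTileShape, ExampleClusterShape,
//     cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative,
//     cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative>;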
/**
* @brief Main function to run int4 * fp8 grouped GEMM from PyTorch
*
* This function performs multiple GEMM operations in parallel where each
* operation multiplies an FP8 matrix (A) with a quantized INT4 matrix (B),
 * applying per-block scaling factors. It is designed for efficient execution
* on NVIDIA Hopper GPUs, leveraging Tensor Cores for optimal performance with
* mixed precision arithmetic.
*
* The function includes preprocessing steps for both INT4 tensors and scale
* factors to ensure optimal performance and correct operation.
*
* @param d_tensors Output tensor D with shape [total_m, total_n]
* @param a_tensors Tensor containing all A matrices (fp8_e4m3) with shape
* [total_m, K]
* @param b_tensors Tensor containing all B matrices (int4 packed as int8) with
* shape [E, N, K/2]
* @param a_scales Tensor containing A matrix scale factors
* @param b_scales Tensor containing B matrix scale factors with shape [E,
* K//512, N*4]
* @param expert_offsets Tensor containing expert offsets for determining group
* boundaries (int32)
* @param problem_sizes Tensor containing problem sizes with shape [num_experts,
 * 3], one (N, M, K) triple per group (int32)
* @param a_strides Stride information for A tensors
* @param b_strides Stride information for B tensors
* @param d_strides Stride information for D tensors
* @param s_strides Stride information for scale tensors
* @param chunk_size Size of each chunk for scales (K / number of scale chunks)
*/
// template <typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule>
template <typename Gemm>
void cutlass_w4a8_group_gemm_caller(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size) {
// using Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>;
using Args = typename Gemm::GemmScaleOnly::Arguments;
int num_experts = static_cast<int>(expert_offsets.size(0));
bool per_act_token = a_scales.numel() != 1;
bool per_out_ch = b_scales.numel() != num_experts;
// Check inputs
TORCH_CHECK(a_tensors.dim() == 2, "A tensor must be 2D");
TORCH_CHECK(b_tensors.dim() == 3, "B tensor must be 3D [E, N, K/2]");
TORCH_CHECK(b_scales.dim() == 3, "Scale tensor must be 3D [E, K//512, N*4]");
TORCH_CHECK(a_scales.dim() == 1, "A Scale tensor must be 1D [1]");
TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be a 1D tensor");
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
// Check tensor shapes
TORCH_CHECK(problem_sizes.size(0) == num_experts, "problem_sizes must have num_experts rows");
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have 3 columns (N, M, K)");
TORCH_CHECK(b_tensors.size(0) == num_experts, "B tensor first dimension must match number of groups");
TORCH_CHECK(b_scales.size(0) == num_experts, "Scale tensor first dimension must match number of groups");
TORCH_CHECK(b_tensors.size(2) * 2 == a_tensors.size(1), "B tensor K/2 dimension must match A tensor K dimension");
TORCH_CHECK(b_scales.size(1) == a_tensors.size(1) / 512, "Scale tensor second dimension must be K//512");
TORCH_CHECK(b_scales.size(2) == 4 * b_tensors.size(1), "Scale tensor last dimension must be 4*N");
// Check tensor types
TORCH_CHECK(a_tensors.scalar_type() == torch::kFloat8_e4m3fn, "A tensor must be fp8 (float_e4m3_t) type");
TORCH_CHECK(b_tensors.scalar_type() == torch::kInt8, "B tensor must contain packed int4 values (stored as int8)");
TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, "Expert offsets must be int32 type");
TORCH_CHECK(problem_sizes.scalar_type() == torch::kInt32, "Problem sizes must be int32 type");
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = a_tensors.device().index();
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
Args arguments;
decltype(arguments.epilogue.thread) fusion_args;
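  // Epilogue fusion: beta is fixed to 0 and alpha is taken from the a_scales
  // device pointer (a single global activation scale), so the epilogue
  // effectively computes D = alpha * (A x dequant(B)).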
fusion_args.alpha = 1.0f;
fusion_args.beta = 0;
fusion_args.alpha_ptr = a_scales.data_ptr<float>();
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.beta_ptr_array = nullptr;
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes =
static_cast<ProblemShape::UnderlyingProblemShape*>(problem_sizes.data_ptr());
run_int4_fp8_get_group_gemm_starts(
expert_offsets,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a_tensors,
b_tensors,
d_tensors,
a_scales,
b_scales);
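  // Note: the operands are passed in "swap and transpose" order. The mixed-input
  // mainloop consumes the packed int4 weights (together with their packed scales)
  // as its first operand and the fp8 activations as its second, which is why the
  // transposed B/A layouts and strides appear here.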
arguments = Args{
cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, problem_sizes_as_shapes, nullptr},
{static_cast<const QuantType**>(b_ptrs.data_ptr()),
static_cast<typename Gemm::StrideB*>(b_strides.data_ptr()),
static_cast<const MmaType**>(a_ptrs.data_ptr()),
static_cast<typename Gemm::StrideA*>(a_strides.data_ptr()),
static_cast<const ElementScalePacked**>(b_scales_ptrs.data_ptr()),
static_cast<typename Gemm::StrideS*>(s_strides.data_ptr()),
static_cast<int>(chunk_size)},
{fusion_args,
nullptr,
nullptr,
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<typename Gemm::StrideD*>(d_strides.data_ptr())},
hw_info};
// Instantiate and run GEMM
typename Gemm::GemmScaleOnly gemm;
size_t workspace_size = Gemm::GemmScaleOnly::get_workspace_size(arguments);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device());
auto workspace = torch::empty(workspace_size, workspace_options);
cutlass::Status status = gemm.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
TORCH_CHECK(false, "GEMM implementation not supported");
}
status = gemm.initialize(arguments, workspace.data_ptr(), stream);
if (status != cutlass::Status::kSuccess) {
TORCH_CHECK(false, "GEMM initialization failed");
}
status = gemm.run(stream);
if (status != cutlass::Status::kSuccess) {
TORCH_CHECK(false, "GEMM execution failed");
}
}
} // namespace
#include <c10/cuda/CUDAGuard.h>
#include <cudaTypedefs.h>
#include <torch/all.h>
#include <iostream>
constexpr uint64_t THREADS_PER_EXPERT = 512;
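// Counts, per expert, how many entries of topk_ids were routed to it and writes
// one (N, M, K) row per expert for each of the two grouped GEMMs of the fused
// MoE layer: problem_sizes1 = (2*n, M, k) (first grouped GEMM, e.g. the fused
// gate/up projection) and problem_sizes2 = (k, M, n) (second grouped GEMM,
// e.g. the down projection). One thread block handles one expert;
// atomic_buffer accumulates the per-expert counts.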
__global__ void compute_problem_sizes_w4a8(
const int32_t* __restrict__ topk_ids,
int32_t* problem_sizes1,
int32_t* problem_sizes2,
int32_t* atomic_buffer,
const int topk_length,
const int n,
const int k) {
int expert_id = blockIdx.x;
int occurrences = 0;
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
occurrences += (topk_ids[i] == expert_id);
}
atomicAdd(&atomic_buffer[expert_id], occurrences);
__syncthreads();
if (threadIdx.x == 0) {
int final_occurrences = atomic_buffer[expert_id];
problem_sizes1[expert_id * 3] = 2 * n;
problem_sizes1[expert_id * 3 + 1] = final_occurrences;
problem_sizes1[expert_id * 3 + 2] = k;
problem_sizes2[expert_id * 3] = k;
problem_sizes2[expert_id * 3 + 1] = final_occurrences;
problem_sizes2[expert_id * 3 + 2] = n;
}
}
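// Single-thread exclusive prefix sum over the per-expert token counts
// (column M of problem_sizes1): expert_offsets receives num_experts + 1
// cumulative start indices, and atomic_buffer is reused to hold each
// expert's start offset.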
__global__ void compute_expert_offsets_w4a8(
const int32_t* __restrict__ problem_sizes1,
int32_t* expert_offsets,
int32_t* atomic_buffer,
const int num_experts) {
int32_t tot_offset = 0;
expert_offsets[0] = 0;
for (int i = 0; i < num_experts; ++i) {
atomic_buffer[i] = tot_offset;
tot_offset += problem_sizes1[i * 3 + 1];
expert_offsets[i + 1] = tot_offset;
}
}
void get_cutlass_w4a8_moe_mm_data_caller(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k) {
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
compute_problem_sizes_w4a8<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()),
topk_ids.numel(),
n,
k);
compute_expert_offsets_w4a8<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()),
num_experts);
}
...@@ -467,6 +467,35 @@ void transfer_kv_all_layer_mla_direct(
    int64_t page_size,
    int64_t num_layers);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
void get_cutlass_w4a8_moe_mm_data(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k);
void cutlass_w4a8_moe_mm(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk);
/*
 * From FlashInfer
 */
...
...@@ -19,6 +19,7 @@ from sgl_kernel.attention import (
    merge_state,
    merge_state_v2,
)
from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
from sgl_kernel.elementwise import (
    apply_rope_with_cos_sin_cache_inplace,
    fused_add_rmsnorm,
...
import torch
def get_cutlass_w4a8_moe_mm_data(
topk_ids: torch.Tensor,
expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
input_permutation: torch.Tensor,
output_permutation: torch.Tensor,
num_experts: int,
n: int,
k: int,
):
"""
Prepare data necessary to perform CUTLASS grouped matrix multiplications
used in CUTLASS-based fused MoE.
The function takes in topk_ids (token-expert mapping) and uses it to
compute:
- expert_offsets: Indices that mark at which token index each expert begins
its computation after the input is sorted with
input_permutation. The number of tokens computed with
expert E is expert_offsets[E + 1] - expert_offsets[E]
- problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
multiplication in two grouped MMs used in
the fused MoE operation.
- input_permutation: Permutation that must be used to shuffle the input
before executing the MMs.
- output_permutation: Permutation that must be used to shuffle the output
after executing the MMs.
"""
torch.ops.sgl_kernel.get_cutlass_w4a8_moe_mm_data.default(
topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
input_permutation,
output_permutation,
num_experts,
n,
k,
)
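# A minimal sketch (illustrative, not part of the kernel API) of how the output
# buffers for get_cutlass_w4a8_moe_mm_data might be allocated before calling it.
# Shapes follow the CUDA kernels above: expert_offsets holds num_experts + 1
# cumulative starts and each problem_sizes row is an (N, M, K) triple; the
# helper name below is hypothetical.
def _example_prepare_w4a8_moe_mm_data(
    topk_ids: torch.Tensor, num_experts: int, n: int, k: int
):
    opts = dict(dtype=torch.int32, device=topk_ids.device)
    expert_offsets = torch.empty(num_experts + 1, **opts)
    problem_sizes1 = torch.empty(num_experts, 3, **opts)
    problem_sizes2 = torch.empty(num_experts, 3, **opts)
    input_permutation = torch.empty(topk_ids.numel(), **opts)
    output_permutation = torch.empty(topk_ids.numel(), **opts)
    get_cutlass_w4a8_moe_mm_data(
        topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        input_permutation,
        output_permutation,
        num_experts,
        n,
        k,
    )
    return expert_offsets, problem_sizes1, problem_sizes2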
def cutlass_w4a8_moe_mm(
d: torch.Tensor,
a: torch.Tensor,
b: torch.Tensor,
a_scales: torch.Tensor,
b_scales: torch.Tensor,
    experts_offsets: torch.Tensor,
    problem_sizes: torch.Tensor,
    a_strides: torch.Tensor,
    b_strides: torch.Tensor,
    d_strides: torch.Tensor,
    s_strides: torch.Tensor,
chunk_size: int = 128,
topk: int = 8,
):
"""
Perform grouped matrix multiplication between int4 weights and fp8 activations.
This function executes multiple GEMM operations in parallel, which is useful for
scenarios like Mixture of Experts (MoE) where different inputs go through different
experts. The implementation leverages NVIDIA Hopper architecture features for
optimal performance with quantized weights.
Args:
d: Output matrices of shape [total_m, total_n]
a: Activation matrices in FP8 (float_e4m3_t) format
Each tensor should be of shape [total_m, K] in row-major layout
b: Weight matrices in packed int4 format
Each tensor should be of shape [E, N, K//2] in column-major layout
where each byte contains two 4-bit integers
a_scales: Scale factors for the inputs
b_scales: Scale factors for the quantized weights
            Each tensor should be of shape [E, K//512, N*4]
experts_offsets: Tensor containing expert offsets for determining group boundaries
        problem_sizes: Tensor of shape [num_experts, 3] holding (N, M, K) for each group (int32)
a_strides: Strides information for A matrices
b_strides: Strides information for B matrices
d_strides: Strides information for D matrices
s_strides: Strides information for b_scales matrices
        chunk_size: Number of K elements each scale value covers (the quantization group size); defaults to 128
        topk: Number of experts each token is routed to; defaults to 8
Requirements:
- All tensors must be on a CUDA device
- Requires an NVIDIA Hopper GPU (H100)
- A tensors must be in float8_e4m3fn format
- B tensors must contain packed int4 values (stored as int8)
Note:
The function computes: D = (A * (B * scales))
for each group of tensors in parallel
"""
torch.ops.sgl_kernel.cutlass_w4a8_moe_mm.default(
d,
a,
b,
a_scales,
b_scales,
experts_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size,
topk,
)
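# A minimal shape sketch (illustrative only; see the unit tests below for a
# complete example with real quantization) of the buffers cutlass_w4a8_moe_mm
# expects for a single expert, assuming a group size of 128 along K. The helper
# name and the all-zero data are hypothetical.
def _example_w4a8_moe_mm_call(m=4, n=1024, k=512, device="cuda"):
    e = 1  # number of experts
    d = torch.empty(m, n, dtype=torch.float16, device=device)
    a = torch.zeros(m, k, device=device).to(torch.float8_e4m3fn)      # fp8 activations
    b = torch.zeros(e, n, k // 2, dtype=torch.int8, device=device)    # packed int4 weights
    a_scales = torch.ones(1, dtype=torch.float32, device=device)
    b_scales = torch.ones(e, k // 512, n * 4, dtype=torch.bfloat16, device=device)
    expert_offsets = torch.zeros(e, dtype=torch.int32, device=device)
    problem_sizes = torch.tensor([[n, m, k]], dtype=torch.int32, device=device)
    a_strides = torch.full((e, 3), k, dtype=torch.int64, device=device)
    d_strides = torch.full((e, 3), n, dtype=torch.int64, device=device)
    # b_strides/s_strides reuse a_strides/d_strides, mirroring the tests below.
    cutlass_w4a8_moe_mm(
        d, a, b, a_scales, b_scales, expert_offsets, problem_sizes,
        a_strides, a_strides, d_strides, d_strides, 128, 8,
    )
    return d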
import pytest
import torch
from sgl_kernel import cutlass_w4a8_moe_mm
def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
if int4_values_interleaved.shape[-1] % 2 != 0:
raise ValueError(
"the last dim size of int4_values_interleaved tensor must be even."
)
input_tensor_int8 = int4_values_interleaved.to(torch.int8)
low_nibbles = input_tensor_int8[..., 0::2]
high_nibbles = input_tensor_int8[..., 1::2]
packed_tensor = (high_nibbles << 4) | (low_nibbles & 0x0F)
return packed_tensor.to(torch.int8)
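# Quick illustration of the packing convention above (hypothetical values):
# two signed int4 values share one byte, low nibble first, so the pair (-2, 3)
# packs to (3 << 4) | (-2 & 0x0F) = 0x3E = 62.
# >>> pack_int4_values_to_int8(torch.tensor([[-2, 3]])).item()
# 62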
def pack_interleave(num_experts, ref_weight, ref_scale):
n, k = ref_weight.shape[1], ref_weight.shape[2]
weight = pack_int4_values_to_int8(ref_weight.cpu()).cuda()
w_q = weight.view((num_experts, n, k // 2)).view(torch.int8)
w_q = w_q.contiguous()
scale_interleaved = ref_scale.reshape(
ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4
) # [E, N, K/4, 4]
scale_interleaved = scale_interleaved.permute(0, 2, 1, 3) # [E, K/4, N, 4]
scale_interleaved = scale_interleaved.reshape(
ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4
) # [E, K/4, N*4]
w_scale = scale_interleaved.contiguous()
return w_q, w_scale
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
def test_int4_fp8_grouped_gemm_single_expert(batch_size):
# Test parameters
num_experts = 1
m = batch_size # batch size
k = 512 # input dimension
n = 1024 # output dimension
torch.manual_seed(0)
dtype = torch.bfloat16
device = "cuda"
debug = False
print(f"\nTesting with batch_size={batch_size}")
# Create input tensors with ones
if debug:
a = torch.ones(m, k, dtype=torch.bfloat16, device=device)
ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
a_scale = torch.ones(1, dtype=torch.float, device=device)
ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
else:
a = torch.randn(m, k, dtype=dtype, device=device)
ref_w = torch.randint(
-8, 8, (num_experts, n, k), dtype=torch.int8, device=device
)
affine_coeff = 0.005
a_scale = torch.randn(1, dtype=torch.float32).cuda() * 0.02
ref_w_scale = (
torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
* affine_coeff
)
w, w_scale = pack_interleave(num_experts, ref_w, ref_w_scale)
# Create expert offsets and problem sizes
expert_offsets = torch.tensor([0, m], dtype=torch.int32, device=device)
problem_sizes = torch.tensor([[n, m, k]], dtype=torch.int32, device=device)
a_strides = torch.full((num_experts, 3), k, device=device, dtype=torch.int64)
c_strides = torch.full((num_experts, 3), n, device=device, dtype=torch.int64)
b_strides = a_strides
s_strides = c_strides
# Quantize input
a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn).to(device)
# Create output tensor
c = torch.empty((m, n), dtype=torch.float16, device=device)
cutlass_w4a8_moe_mm(
c,
a_q,
w,
a_scale,
w_scale,
expert_offsets[:-1],
problem_sizes,
a_strides,
b_strides,
c_strides,
s_strides,
128,
8,
)
c = c.to(dtype)
# Reference implementation
experts_selection_result = torch.full((m,), 0, dtype=torch.int64, device=device)
c_ref = ref_grouped_gemm(
c, a, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
)
# Compare results
try:
torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=0.1)
except AssertionError as e:
# torch.set_printoptions(threshold=10_000)
print(f" FAILURE: tensors are NOT close.")
print(f" Ref tensor: {c_ref.flatten()}")
print(f" Cutlass tensor: {c.flatten()}")
print(
f" Max absolute difference: {torch.max(torch.abs(c.to(c_ref.dtype) - c_ref))}"
)
print(
f" Mean absolute difference: {torch.mean(torch.abs(c.to(c_ref.dtype) - c_ref))}"
)
print(f" AssertionError: {e}")
raise
@pytest.mark.parametrize("batch_size", [2, 4, 8, 16])
@pytest.mark.parametrize("k", [512, 1024])
@pytest.mark.parametrize("n", [1024, 2048])
@pytest.mark.parametrize("num_experts", [2, 4, 6, 8])
def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
torch.manual_seed(0)
dtype = torch.bfloat16
device = "cuda"
debug = False
print(
f"\nTesting with batch_size={batch_size}, k={k}, n={n}, num_experts={num_experts}"
)
if debug:
a = torch.ones(batch_size, k, dtype=torch.bfloat16, device=device)
ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
a_scale = torch.ones(1, dtype=torch.float, device=device)
ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
else:
a = torch.randn(batch_size, k, dtype=dtype, device=device)
ref_w = torch.randint(
-8, 8, (num_experts, n, k), dtype=torch.int8, device=device
)
affine_coeff = 0.005
a_scale = torch.randn(1, dtype=torch.float32).cuda() * 0.02
ref_w_scale = (
torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
* affine_coeff
)
w, w_scale = pack_interleave(num_experts, ref_w, ref_w_scale)
# random select experts
experts_selection_result = torch.randint(
0, num_experts, (batch_size,), device=device
)
permutation = torch.argsort(experts_selection_result)
expert_token_counts = torch.bincount(
experts_selection_result, minlength=num_experts
)
# Create problem sizes and offsets for active experts
problem_sizes = []
for i in range(num_experts):
problem_sizes.append([n, expert_token_counts[i].item(), k])
problem_sizes = torch.tensor(problem_sizes, dtype=torch.int32, device=device)
expert_offsets = []
offset = 0
for i in range(num_experts):
expert_offsets.append(offset)
offset += problem_sizes[i][1].item()
expert_offsets = torch.tensor(expert_offsets, dtype=torch.int32, device=device)
# Permute input and quantize
a_perm = a[permutation]
a_q_perm = (
torch.clamp((a_perm / a_scale), -448.0, 448.0)
.to(torch.float8_e4m3fn)
.to(device)
)
# Create stride tensors
a_strides = torch.full((num_experts, 3), k, device=device, dtype=torch.int64)
c_strides = torch.full((num_experts, 3), n, device=device, dtype=torch.int64)
b_strides = a_strides
s_strides = c_strides
c_perm = torch.empty((batch_size, n), dtype=torch.float16, device=device)
cutlass_w4a8_moe_mm(
c_perm,
a_q_perm,
w,
a_scale,
w_scale,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
c_strides,
s_strides,
128,
8,
)
# Un-permute the result
c = torch.empty_like(c_perm)
c[permutation] = c_perm
c = c.to(dtype)
c_ref = ref_grouped_gemm(
c, a, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
)
# Compare results
try:
torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=0.1)
except AssertionError as e:
print(f" FAILURE: tensors are NOT close.")
print(
f" Max absolute difference: {torch.max(torch.abs(c.to(c_ref.dtype) - c_ref))}"
)
print(
f" Mean absolute difference: {torch.mean(torch.abs(c.to(c_ref.dtype) - c_ref))}"
)
print(f" AssertionError: {e}")
raise
def ref_grouped_gemm(c, a, a_scale, w, w_scale, num_experts, experts_selection_result):
dtype = torch.bfloat16
c_ref = torch.zeros_like(c)
a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn)
for i in range(num_experts):
token_idx = torch.where(experts_selection_result == i)[0]
if len(token_idx) == 0:
continue
a = a_q[token_idx]
ref_w_scale_repeat = w_scale[i].repeat_interleave(128, dim=1).to(float)
ref_w = (w[i].to(float) * ref_w_scale_repeat).to(dtype)
c = torch.matmul(a.to(dtype), ref_w.t().to(dtype)) * a_scale
c = c.to(dtype)
c_ref[token_idx] = c.to(dtype)
return c_ref
if __name__ == "__main__":
pytest.main([__file__])