"src/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "04257666394edc3a2b98ba4f24b6f2cda5da026f"
Commit 45559a1f authored by Lei Wang, committed by LeiWang1999

[Enhancement] Allow mma fallback when wgmma is not supported (#206)

* Enhance error message for constant size stack allocation in CUDA codegen. Include the actual constant size and buffer variable name in the error output for better debugging.

* Refactor GEMM and Bulk Copy operations to enhance layout handling and support for Hopper architecture

- Update `ComputeWarpPartition` to include a new parameter for Hopper WGMMA support.
- Modify layout checks in `LowerBulkCopy` to accommodate new GEMM layout types.
- Enhance layout inference logic in `InferLayout` for better compatibility with the Hopper architecture.
- Include necessary header files for built-in operations and layout inference improvements.

* lint fix

* Remove unused builtin.h include directive

* Update include path for builtin.h
parent 6ffae0f2
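The core of the change is a fallback predicate: WGMMA is only attempted when the target is Hopper, the tile's M dimension can hold a full warp group, and the launch uses whole warp groups. A minimal standalone sketch of that rule (illustrative names, not the exact tilelang API):

// Sketch only: mirrors the allow_wgmma / maybe_wgmma conditions added in the
// GEMM lowering below; when this returns false, lowering falls back to the
// SM80-style MMA path instead of Hopper WGMMA.
bool allow_wgmma(bool target_is_hopper, int M, int num_warps) {
  return target_is_hopper && (M >= 64) && (num_warps % 4 == 0);
}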
@@ -169,9 +169,14 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
   auto stride = as_const_int(shared_layout->InputShape()[0]);
   auto continuous = as_const_int(shared_layout->InputShape()[1]);
   ICHECK(stride != nullptr && continuous != nullptr);
-  if (StructuralEqual()(shared_layout, makeHalfBankSwizzleLayout(
+  if (StructuralEqual()(shared_layout, makeGemmABLayoutPadded(
                                            *stride, *continuous,
                                            shared_tensor->dtype.bits()))) {
+    desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
+  } else if (StructuralEqual()(
+                 shared_layout,
+                 makeHalfBankSwizzleLayout(*stride, *continuous,
+                                           shared_tensor->dtype.bits()))) {
     desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B);
   } else if (StructuralEqual()(
                  shared_layout,
......
@@ -9,9 +9,11 @@
 #include "gemm.h"

+#include "builtin.h"
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/op.h>
 #include <tvm/tir/op_attr_types.h>
+#include <tvm/tir/transform.h>

 #include "../target/utils.h"
@@ -52,10 +54,12 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
   }
 }

-std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps,
-                                               Target target) const {
+std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
+                                               bool maybe_hopper_wgmma) const {
   int m_warp = 1, n_warp = 1;
-  if (TargetIsHopper(target)) {
+  bool allow_wgmma = TargetIsHopper(target) && maybe_hopper_wgmma &&
+                     (this->M >= 64) && (num_warps % 4 == 0);
+  if (allow_wgmma) {
     ICHECK(num_warps % 4 == 0) << "Use Warp Group MMA requires 128*N threads.";
     if (this->policy == GemmWarpPolicy::kFullRow ||
         this->policy == GemmWarpPolicy::kSquare) {
@@ -108,9 +112,12 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     warp_size = 64;
   }
-  ICHECK(T.block_size % warp_size == 0);
+  bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) &&
+                     (T.block_size / warp_size % 4 == 0);
+
   auto [warp_m, warp_n] =
-      ComputeWarpPartition(T.block_size / warp_size, T.target);
+      ComputeWarpPartition(T.block_size / warp_size, T.target, maybe_wgmma);
+
   std::stringstream ss;
   std::string op_name = "tl::gemm_ss";
   if (A.scope() == "local.fragment") {
@@ -125,6 +132,8 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   if (TargetIsCDNA(T.target)) {
     // for cdna gemm, we need to specify kPack
     ss << ", " << kPack;
+  } else if (TargetIsHopper(T.target)) {
+    ss << ", " << (maybe_wgmma ? "true" : "false");
   }
   ss << ">";
   auto A_buffer = T.buffer_remap.count(A) ? T.buffer_remap[A] : A;
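For illustration only, a minimal sketch of how the new flag ends up in the emitted intrinsic string; the prefix arguments (tile sizes, warp counts, transpose flags) are placeholders, since their construction lies outside this hunk:

#include <iostream>
#include <sstream>

int main() {
  std::stringstream ss;
  // Placeholder prefix; the real lowering appends M, N, K, warp counts and
  // transpose flags before this point.
  ss << "tl::gemm_ss<64, 64, 32, 1, 4, false, false";
  bool maybe_wgmma = false;  // e.g. M < 64 on a Hopper target
  ss << ", " << (maybe_wgmma ? "true" : "false");  // new use_wgmma argument
  ss << ">";
  std::cout << ss.str() << "\n";
  // prints: tl::gemm_ss<64, 64, 32, 1, 4, false, false, false>
}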
@@ -199,10 +208,18 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
     }
   } else if (TargetIsHopper(T.target)) {
     const int warp_size = 32;
+    bool maybe_wgmma = (this->M >= 64) && (T.block_size / warp_size % 4 == 0);
+    if (!maybe_wgmma) {
+      LOG(WARNING)
+          << "WGMMA is not enabled because M < 64 or block_size % 128 != 0";
+    }
     auto [warp_m, warp_n] =
-        ComputeWarpPartition(T.block_size / warp_size, T.target);
+        ComputeWarpPartition(T.block_size / warp_size, T.target, maybe_wgmma);
     auto fragment =
-        makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n, C->dtype.bits());
+        maybe_wgmma
+            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
+                                      C->dtype.bits())
+            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
     results.Set(C, fragment);
     if (A.scope() == "shared" || A.scope() == "shared.dyn") {
       results.Set(A, makeGemmABLayout(*as_const_int(A->shape[0]),
......
@@ -30,7 +30,9 @@ public:
   } policy;

 private:
-  std::pair<int, int> ComputeWarpPartition(int num_warps, Target target) const;
+  std::pair<int, int>
+  ComputeWarpPartition(int num_warps, Target target,
+                       bool maybe_hopper_wgmma = true) const;

   Array<PrimExpr> call_args;
   tir::Buffer A, B, C;
......
@@ -1305,7 +1305,8 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
   } else {
     size_t constant_size = op->ConstantAllocationSize();
     ICHECK_GT(constant_size, 0)
-        << "Can only handle constant size stack allocation for now";
+        << "Can only handle constant size stack allocation for now, but get "
+        << constant_size << " for " << op->buffer_var->name_hint;
     if (scope.find("wmma.") == 0) {
       constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
     }
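As a rough illustration of the enriched diagnostic (the buffer name and size below are invented, not taken from the commit), the check would now report the offending allocation like this:

#include <cstddef>
#include <iostream>
#include <string>

int main() {
  // Mirrors the message composition in the hunk above with made-up values.
  size_t constant_size = 0;
  std::string buffer_name = "C_local";
  std::cout << "Can only handle constant size stack allocation for now, but get "
            << constant_size << " for " << buffer_name << "\n";
}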
......
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 #pragma once

+#include <cute/arch/mma_sm80.hpp>
 #include <cute/arch/mma_sm90.hpp>
 #include <cute/atom/mma_atom.hpp>
 #include <cutlass/arch/barrier.h>
@@ -13,6 +14,7 @@ namespace cute {

 using namespace SM90;

+namespace tl_wgmma {
 template <GMMA::Major major, class ElementType, class BLK_MN, class BLK_K>
 CUTE_HOST_DEVICE constexpr auto ss_smem_selector() {
   auto BLK_MN0 = size<0>(BLK_MN{});
@@ -101,8 +103,7 @@ public:
       SmemLayoutAtomB{}, Shape<Int<N>, Int<K>>{},
       conditional_t<trans_B, Step<_1, _2>, Step<_2, _1>>{}));

-  // static_assert(num_warp_n == 1);
-  static_assert(num_warp_m % 4 == 0);
+  static_assert(num_warp_m % 4 == 0, "num_warp_m must be a multiple of 4");

   template <int wg_wait = 0>
   static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
@@ -209,26 +210,398 @@ public:
   }
 };

+} // namespace tl_wgmma
+
+namespace tl_mma {
template <typename A_type, typename B_type, typename C_type, int num_warp_m,
int num_warp_n>
struct DispatchInstruction;
using _X = Underscore;
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n> {
using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> {
using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
num_warp_n> {
using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
num_warp_n> {
using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n> {
using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n> {
using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
};
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> {
using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _16>;
};
#endif
template <int Bits, int N, int K, bool K_inner, typename Enable = void>
struct OperandTraits {
// Primary template, use padded layout and default copy
static constexpr int stride = K_inner ? K : N;
static constexpr int padded =
stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
using Layout = typename std::conditional<
K_inner, Layout<Shape<Int<N>, Int<K>>, Shape<Int<padded>, _1>>,
Layout<Shape<Int<N>, Int<K>>, Shape<_1, Int<padded>>>>::type;
using Copy = DefaultCopy;
};
template <int N, int K>
struct OperandTraits<16, N, K, true,
typename std::enable_if<K % 64 == 32>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
};
template <int N, int K>
struct OperandTraits<16, N, K, true,
typename std::enable_if<K % 64 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
};
template <int N, int K>
struct OperandTraits<16, N, K, false,
typename std::enable_if<N % 64 == 32>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = SM75_U16x8_LDSM_T;
};
template <int N, int K>
struct OperandTraits<16, N, K, false,
typename std::enable_if<N % 64 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = SM75_U16x8_LDSM_T;
};
template <int N, int K>
struct OperandTraits<32, N, K, true,
typename std::enable_if<K % 32 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
};
template <int N, int K>
struct OperandTraits<32, N, K, true,
typename std::enable_if<K % 32 == 16>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
};
template <int N, int K>
struct OperandTraits<32, N, K, false,
typename std::enable_if<N % 32 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = UniversalCopy<tfloat32_t>;
};
template <int N, int K>
struct OperandTraits<32, N, K, false,
typename std::enable_if<N % 32 == 16>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = UniversalCopy<tfloat32_t>;
};
template <int N, int K>
struct OperandTraits<8, N, K, true,
typename std::enable_if<K % 128 == 64>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
};
template <int N, int K>
struct OperandTraits<8, N, K, true,
typename std::enable_if<K % 128 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
};
template <int N, int K>
struct OperandTraits<64, N, K, true,
typename std::enable_if<K % 16 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = DefaultCopy;
};
template <int N, int K>
struct OperandTraits<64, N, K, false,
typename std::enable_if<N % 16 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = DefaultCopy;
};
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type_raw, typename B_type_raw,
typename C_type_raw>
class GemmTensorOp {
public:
using A_type =
typename std::conditional<std::is_same<A_type_raw, float>::value,
tfloat32_t, A_type_raw>::type;
using B_type =
typename std::conditional<std::is_same<B_type_raw, float>::value,
tfloat32_t, B_type_raw>::type;
using C_type = C_type_raw;
using Instruction =
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n>;
using OperandATraits =
OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A>;
using OperandBTraits =
OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B>;
using SmemLayoutA = typename OperandATraits::Layout;
using SmemLayoutB = typename OperandBTraits::Layout;
using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
using SmemCopyB = Copy_Atom<typename OperandBTraits::Copy, B_type>;
using TileMma = TiledMMA<typename Instruction::MMA,
Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>,
typename Instruction::MMA_Group>;
template <class... Args>
static CUTE_DEVICE auto remove_swizzle(Layout<Args...> const &layout) {
return layout;
}
// In fp16, when the layout is KxN, n_warp is 1, and N % 64 == 0,
// the original layout fails to compile; currently using this as a workaround.
template <class... Args>
static CUTE_DEVICE auto
remove_swizzle(ComposedLayout<Args...> const &layout) {
if constexpr (sizeof(A_type) == 2)
return layout.layout_b();
else
return layout;
}
static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
const int tid = threadIdx.x;
Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
SmemLayoutB{});
TileMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tid);
auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);
auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);
Tensor tCrA = thr_mma.partition_fragment_A(sA);
Tensor tCrB = thr_mma.partition_fragment_B(sB);
Tensor tCsA = thr_copy_A.partition_S(sA);
Tensor tCsB = thr_copy_B.partition_S(sB);
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);
Tensor acc =
make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
// When the layout is KxN and n_warp is 1, there seems to be a bug; use this
// as a workaround.
auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
CUTE_UNROLL
for (int k = 0; k < size<2>(tCrA); ++k) {
copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
}
}
static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
C_type_raw *pC) {
const int tid = threadIdx.x;
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
SmemLayoutB{});
TileMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tid);
auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);
Tensor tCrB = thr_mma.partition_fragment_B(sB);
Tensor tCsB = thr_copy_B.partition_S(sB);
Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);
Tensor acc =
make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
Tensor tCrA =
make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
CUTE_UNROLL
for (int k = 0; k < size<2>(tCrA); ++k) {
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
}
gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
}
}
static CUTE_DEVICE void body_sr(A_type_raw *pA, B_type_raw *pB,
C_type_raw *pC) {
const int tid = threadIdx.x;
Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
SmemLayoutA{});
TileMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tid);
auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);
Tensor tCrA = thr_mma.partition_fragment_A(sA);
Tensor tCsA = thr_copy_A.partition_S(sA);
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
Tensor acc =
make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
Tensor tCrB =
make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
CUTE_UNROLL
for (int k = 0; k < size<2>(tCrA); ++k) {
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
}
gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
}
}
};
} // namespace tl_mma
 } // namespace cute

 namespace tl {

+namespace tl_mma {
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, int wg_wait = 0, typename A_type, typename B_type,
-          typename C_type>
-TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
-  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                                 trans_B, A_type, B_type, C_type>;
-  MMA::body<wg_wait>(pA, pB, accum);
+          bool trans_B, typename A_type, typename B_type, typename C_type>
+CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
+  using MMA =
+      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
+                                 trans_B, A_type, B_type, C_type>;
+  MMA::body(pA, pB, accum);
 }

 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, int wg_wait = 0, typename A_type, typename B_type,
-          typename C_type>
-TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
-  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                                 trans_B, A_type, B_type, C_type>;
-  MMA::body_rs<wg_wait>(pA, pB, accum);
+          bool trans_B, typename A_type, typename B_type, typename C_type>
+CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
+  using MMA =
+      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
+                                 trans_B, A_type, B_type, C_type>;
+  MMA::body_rs(pA, pB, accum);
+}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
MMA::body_sr(pA, pB, accum);
}
} // namespace tl_mma
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool use_wgmma = true, int wg_wait = 0, typename A_type,
typename B_type, typename C_type>
TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
if constexpr (use_wgmma) {
using MMA =
cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
MMA::body<wg_wait>(pA, pB, accum);
} else {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
MMA::body(pA, pB, accum);
}
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool use_wgmma = true, int wg_wait = 0, typename A_type,
typename B_type, typename C_type>
TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
if constexpr (use_wgmma) {
using MMA =
cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
MMA::body_rs<wg_wait>(pA, pB, accum);
} else {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, A_type, B_type, C_type>;
MMA::body_rs(pA, pB, accum);
}
}

template <int num_mma> TL_DEVICE void wait_wgmma() {
......
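A hypothetical device-side call showing the new use_wgmma switch in the wrappers above; tile sizes, warp counts, and pointer names are invented, and the tl gemm header is assumed to be included in the translation unit:

// use_wgmma = false routes to cute::tl_mma::GemmTensorOp (SM80-style MMA);
// use_wgmma = true keeps the Hopper WGMMA path in cute::tl_wgmma.
__device__ void gemm_tile_example(cute::half_t *A_smem, cute::half_t *B_smem,
                                  float *acc_frag) {
  tl::gemm_ss<64, 64, 32, /*num_warp_m=*/4, /*num_warp_n=*/1,
              /*trans_A=*/false, /*trans_B=*/false,
              /*use_wgmma=*/false, /*wg_wait=*/0>(A_smem, B_smem, acc_frag);
}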