Unverified commit fe70549f authored by FeiyangChen, committed by GitHub

[Feat] Support mma gemm with stride (#701)



* gemm_with_stride sm89

* fix offset issue

* bug fix

* format

* sm80 support

* add sm90

* add testing

* format

* add static_assert for wgmma

* Enhance error message for inner_box_dim validation in LowerBulkCopy

* lint fix

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent 569b0127
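A minimal sketch of what this change enables on the Python side, mirroring the test added at the end of this diff (tile sizes and buffer names are taken from that test and are illustrative only): T.gemm can now consume a sub-region of a shared buffer whose leading dimension is wider than the GEMM tile, and the frontend derives the stride and offset and forwards them to the C++ templates.

    # shared buffers twice as wide as the tile along the K / N dimension
    A_shared = T.alloc_shared((block_M, block_K * 2), dtype, scope="shared")  # leading dim = 2 * block_K
    B_shared = T.alloc_shared((block_K, block_N * 2), dtype, scope="shared")  # leading dim = 2 * block_N
    ...
    # GEMM on a strided/offset view: A uses the second K-half, B the first N columns
    T.gemm(A_shared[:, block_K:], B_shared[0:block_K, 0:block_N], C_local)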
......@@ -273,7 +273,9 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
<< "inner_box_dim: " << *inner_box_dim << " is not divisible by 256";
instruction_dim = 256;
}
ICHECK((*inner_box_dim) % instruction_dim == 0);
ICHECK((*inner_box_dim) % instruction_dim == 0)
<< "inner_box_dim: " << *inner_box_dim
<< " is not divisible by instruction_dim: " << instruction_dim;
desc.smem_box.Set(0, PrimExpr(instruction_dim));
int inner_box_dim_ = instruction_dim * shared_tensor->dtype.bytes();
......
......@@ -47,14 +47,18 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
K = args[7].as<IntImm>().value()->value;
policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
clear_accum = args[9].as<Bool>().value();
if (args.size() > 10) {
kPack = args[10].as<IntImm>().value()->value;
stride_A = args[10].as<IntImm>().value()->value;
stride_B = args[11].as<IntImm>().value()->value;
offset_A = args[12].as<IntImm>().value()->value;
offset_B = args[13].as<IntImm>().value()->value;
if (args.size() > 14) {
kPack = args[14].as<IntImm>().value()->value;
if (kPack != 1 && kPack != 2) {
ICHECK(false) << "kPack must be 1 or 2";
}
}
if (args.size() > 11) {
wg_wait = args[11].as<IntImm>().value()->value;
if (args.size() > 15) {
wg_wait = args[15].as<IntImm>().value()->value;
}
}
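With this change the tl_gemm intrinsic carries the strides and offsets ahead of the optional kPack and wg_wait arguments. A rough map of the positional layout: positions 10-15 follow directly from the args[] indices parsed above, while positions 0-6 are inferred from the member declarations and the frontend call later in this diff, so treat them as an assumption.

    # args[i] -> meaning (sketch)
    #  0..2  : Aptr, Bptr, Cptr
    #  3..4  : trans_A, trans_B
    #  5..7  : M, N, K
    #  8..9  : policy, clear_accum
    # 10..11 : stride_A, stride_B   (new)
    # 12..13 : offset_A, offset_B   (new)
    # 14     : kPack   (optional, must be 1 or 2)
    # 15     : wg_wait (optional)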
......@@ -284,6 +288,19 @@ bool Gemm::CheckWGMMA() const {
}
}
static int GetArchInt(Target target) {
int arch_int = 0;
auto s = target->GetAttr<String>("arch");
ICHECK(s.defined());
const char *arch_str = s.value().c_str();
if (arch_str[0] == 's' && arch_str[1] == 'm' && arch_str[2] == '_') {
arch_int = atoi(&arch_str[3]);
} else {
arch_int = 0;
}
return arch_int;
}
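GetArchInt extracts the numeric compute capability from the target's "sm_XX" arch string; because atoi stops at the first non-digit, suffixed arch names parse to their numeric prefix. A rough Python equivalent for illustration (get_arch_int is a name introduced here, not part of the codebase):

    def get_arch_int(arch: str) -> int:
        # "sm_89" -> 89, "sm_90a" -> 90, anything else -> 0
        if arch.startswith("sm_"):
            digits = ""
            for ch in arch[3:]:
                if not ch.isdigit():
                    break
                digits += ch
            return int(digits) if digits else 0
        return 0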
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto block_size = *as_const_int(T.thread_bounds->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
......@@ -301,6 +318,10 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
ss << warp_m << ", " << warp_n << ", ";
ss << trans_A << ", " << trans_B;
ss << ", " << clear_accum;
if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) {
ss << ", " << stride_A << ", " << stride_B;
ss << ", " << offset_A << ", " << offset_B;
}
if (TargetIsCDNA(T.target)) {
// for cdna gemm, we need to specify kPack
ss << ", " << kPack;
......
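The new stride/offset arguments are appended to the emitted template argument list only for CUDA targets with compute capability >= 75, while kPack is still appended only for CDNA targets. A small Python mirror of that decision, for illustration only (the function name is made up; the earlier arguments in the string are unchanged and omitted here):

    def extra_gemm_template_args(is_cuda, arch, is_cdna,
                                 stride_A, stride_B, offset_A, offset_B, k_pack):
        extra = []
        if is_cuda and arch >= 75:
            extra += [stride_A, stride_B, offset_A, offset_B]
        if is_cdna:
            extra.append(k_pack)
        return extra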
......@@ -45,6 +45,8 @@ private:
PrimExpr Aptr, Bptr, Cptr;
bool trans_A, trans_B;
int M, N, K;
int stride_A, stride_B;
int offset_A, offset_B;
bool clear_accum = false;
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions
......
......@@ -73,135 +73,143 @@ template <int N, int num_warp_n, bool transpose> struct SelectCopy {
DefaultCopy>;
};
template <int Bits, int N, int K, bool K_inner, int num_warp_n,
template <int Bits, int N, int K, bool K_inner, int num_warp_n, int leading_dim,
typename Enable = void>
struct OperandTraits {
// Primary template, use padded layout and default copy
static constexpr int stride = K_inner ? K : N;
static constexpr int stride = leading_dim;
static constexpr int padded =
stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
using Layout = typename std::conditional<
K_inner, Layout<Shape<Int<N>, Int<K>>, Shape<Int<padded>, _1>>,
Layout<Shape<Int<N>, Int<K>>, Shape<_1, Int<padded>>>>::type;
K_inner, Layout<Shape<Int<N>, Int<leading_dim>>, Shape<Int<padded>, _1>>,
Layout<Shape<Int<leading_dim>, Int<K>>, Shape<_1, Int<padded>>>>::type;
using Copy = DefaultCopy;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
typename std::enable_if<K % 64 == 32>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<16, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 64 == 32>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
typename std::enable_if<K % 64 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<16, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 64 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
typename std::enable_if<N % 64 == 32>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<16, N, K, false, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 64 == 32>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Layout = decltype(tile_to_shape(
LayoutAtom{}, Shape<Int<leading_dim>, Int<K>>{}, Step<_2, _1>{}));
using Copy = typename SelectCopy<N, num_warp_n, false>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
typename std::enable_if<N % 64 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<16, N, K, false, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 64 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Layout = decltype(tile_to_shape(
LayoutAtom{}, Shape<Int<leading_dim>, Int<K>>{}, Step<_2, _1>{}));
using Copy = typename SelectCopy<N, num_warp_n, false>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
typename std::enable_if<K % 32 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<32, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 32 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
typename std::enable_if<K % 32 == 16>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<32, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 32 == 16>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
typename std::enable_if<N % 32 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<32, N, K, false, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 32 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Layout = decltype(tile_to_shape(
LayoutAtom{}, Shape<Int<leading_dim>, Int<K>>{}, Step<_2, _1>{}));
using Copy = UniversalCopy<tfloat32_t>;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
typename std::enable_if<N % 32 == 16>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<32, N, K, false, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 32 == 16>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Layout = decltype(tile_to_shape(
LayoutAtom{}, Shape<Int<leading_dim>, Int<K>>{}, Step<_2, _1>{}));
using Copy = UniversalCopy<tfloat32_t>;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
typename std::enable_if<K % 128 == 64>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<8, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 128 == 64>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
typename std::enable_if<K % 128 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<8, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 128 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, true, num_warp_n,
typename std::enable_if<K % 16 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<64, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 16 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = DefaultCopy;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, false, num_warp_n,
typename std::enable_if<N % 16 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<64, N, K, false, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 16 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Layout = decltype(tile_to_shape(
LayoutAtom{}, Shape<Int<leading_dim>, Int<K>>{}, Step<_2, _1>{}));
using Copy = DefaultCopy;
};
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, typename A_type_raw,
typename B_type_raw, typename C_type_raw>
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type_raw, typename B_type_raw,
typename C_type_raw>
class GemmTensorOp {
public:
using A_type =
......@@ -215,10 +223,10 @@ public:
using Instruction =
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
using OperandATraits =
OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A, num_warp_m>;
using OperandATraits = OperandTraits<sizeof_bits<A_type>::value, M, K,
!trans_A, num_warp_m, lda>;
using OperandBTraits =
OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B, num_warp_n>;
OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B, num_warp_n, ldb>;
using SmemLayoutA = typename OperandATraits::Layout;
using SmemLayoutB = typename OperandBTraits::Layout;
......@@ -244,12 +252,38 @@ public:
return layout;
}
template <int offset, int NN, int KK, bool trans, int lddim, typename Engine0,
typename Layout0>
static CUTE_DEVICE auto get_region_tensor(Tensor<Engine0, Layout0> &sa) {
if constexpr (offset == 0) {
return composition(
sa,
Layout<Shape<Int<NN>, Int<KK>>,
Stride<_1, typename std::conditional<trans, Int<NN>,
Int<lddim>>::type>>{});
} else {
if constexpr (trans) {
static_assert(offset % KK == 0, "Offset must be a multiple of K");
constexpr int offset_n = offset / KK;
return flat_divide(sa, Shape<Int<NN>, Int<KK>>{})(_, _, _0{},
Int<offset_n>{});
} else {
static_assert(offset % NN == 0, "Offset must be a multiple of N");
constexpr int offset_n = offset / NN;
return flat_divide(sa, Shape<Int<NN>, Int<KK>>{})(_, _, Int<offset_n>{},
_0{});
}
}
}
static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
const int tid = threadIdx.x;
Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
SmemLayoutB{});
Tensor sA_all = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
SmemLayoutA{});
Tensor sB_all = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
SmemLayoutB{});
Tensor sA = get_region_tensor<offset_a, M, K, !trans_A, lda>(sA_all);
Tensor sB = get_region_tensor<offset_b, N, K, trans_B, ldb>(sB_all);
TileMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tid);
auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
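The new get_region_tensor helper carries the stride/offset support: with a zero offset it reinterprets the full shared tensor as an (NN, KK) view using the given leading dimension, and with a non-zero offset it partitions the shared tensor into (NN, KK) tiles and picks the tile the offset points at. A rough Python model of the index math, for illustration only (selected_tile is a made-up name; the real selection happens in CuTe at compile time):

    def selected_tile(offset, NN, KK, trans):
        # mirrors the static_asserts and integer divisions in get_region_tensor
        if offset == 0:
            return 0
        if trans:
            assert offset % KK == 0, "Offset must be a multiple of K"
            return offset // KK
        assert offset % NN == 0, "Offset must be a multiple of N"
        return offset // NN

    # e.g. the test at the end of this diff passes A as A_shared[:, block_K:],
    # so offset_a == block_K; with trans == !trans_A == True and KK == block_K,
    # the second K-tile of the shared buffer feeds the MMA.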
......@@ -287,8 +321,9 @@ public:
static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
C_type_raw *pC) {
const int tid = threadIdx.x;
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
SmemLayoutB{});
Tensor sB_all = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
SmemLayoutB{});
Tensor sB = get_region_tensor<offset_b, N, K, trans_B, ldb>(sB_all);
TileMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tid);
auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
......@@ -322,8 +357,9 @@ public:
static CUTE_DEVICE void body_sr(A_type_raw *pA, B_type_raw *pB,
C_type_raw *pC) {
const int tid = threadIdx.x;
Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
SmemLayoutA{});
Tensor sA_all = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
SmemLayoutA{});
Tensor sA = get_region_tensor<offset_a, M, K, !trans_A, lda>(sA_all);
TileMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tid);
auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
......@@ -360,29 +396,32 @@ public:
namespace tl {
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, typename A_type, typename B_type,
typename C_type>
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, A_type, B_type, C_type>;
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body(pA, pB, accum);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, typename A_type, typename B_type,
typename C_type>
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, A_type, B_type, C_type>;
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body_rs(pA, pB, accum);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, typename A_type, typename B_type,
typename C_type>
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, A_type, B_type, C_type>;
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body_sr(pA, pB, accum);
}
......
......@@ -91,135 +91,143 @@ template <int N, int num_warp_n, bool transpose> struct SelectCopy {
DefaultCopy>;
};
template <int Bits, int N, int K, bool K_inner, int num_warp_n,
template <int Bits, int N, int K, bool K_inner, int num_warp_n, int leading_dim,
typename Enable = void>
struct OperandTraits {
// Primary template, use padded layout and default copy
static constexpr int stride = K_inner ? K : N;
static constexpr int stride = leading_dim;
static constexpr int padded =
stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
using Layout = typename std::conditional<
K_inner, Layout<Shape<Int<N>, Int<K>>, Shape<Int<padded>, _1>>,
Layout<Shape<Int<N>, Int<K>>, Shape<_1, Int<padded>>>>::type;
K_inner, Layout<Shape<Int<N>, Int<leading_dim>>, Shape<Int<padded>, _1>>,
Layout<Shape<Int<leading_dim>, Int<K>>, Shape<_1, Int<padded>>>>::type;
using Copy = DefaultCopy;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
typename std::enable_if<K % 64 == 32>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<16, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 64 == 32>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
typename std::enable_if<K % 64 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<16, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 64 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
typename std::enable_if<N % 64 == 32>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<16, N, K, false, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 64 == 32>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Layout = decltype(tile_to_shape(
LayoutAtom{}, Shape<Int<leading_dim>, Int<K>>{}, Step<_2, _1>{}));
using Copy = typename SelectCopy<N, num_warp_n, false>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
typename std::enable_if<N % 64 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<16, N, K, false, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 64 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Layout = decltype(tile_to_shape(
LayoutAtom{}, Shape<Int<leading_dim>, Int<K>>{}, Step<_2, _1>{}));
using Copy = typename SelectCopy<N, num_warp_n, false>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
typename std::enable_if<K % 32 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<32, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 32 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
typename std::enable_if<K % 32 == 16>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<32, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 32 == 16>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
typename std::enable_if<N % 32 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<32, N, K, false, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 32 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Layout = decltype(tile_to_shape(
LayoutAtom{}, Shape<Int<leading_dim>, Int<K>>{}, Step<_2, _1>{}));
using Copy = UniversalCopy<tfloat32_t>;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
typename std::enable_if<N % 32 == 16>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<32, N, K, false, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 32 == 16>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Layout = decltype(tile_to_shape(
LayoutAtom{}, Shape<Int<leading_dim>, Int<K>>{}, Step<_2, _1>{}));
using Copy = UniversalCopy<tfloat32_t>;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
typename std::enable_if<K % 128 == 64>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<8, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 128 == 64>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
typename std::enable_if<K % 128 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<8, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 128 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, true, num_warp_n,
typename std::enable_if<K % 16 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<64, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 16 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = DefaultCopy;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, false, num_warp_n,
typename std::enable_if<N % 16 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<64, N, K, false, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 16 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Layout = decltype(tile_to_shape(
LayoutAtom{}, Shape<Int<leading_dim>, Int<K>>{}, Step<_2, _1>{}));
using Copy = DefaultCopy;
};
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, typename A_type_raw,
typename B_type_raw, typename C_type_raw>
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type_raw, typename B_type_raw,
typename C_type_raw>
class GemmTensorOp {
public:
using A_type =
......@@ -233,10 +241,10 @@ public:
using Instruction =
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
using OperandATraits =
OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A, num_warp_m>;
using OperandATraits = OperandTraits<sizeof_bits<A_type>::value, M, K,
!trans_A, num_warp_m, lda>;
using OperandBTraits =
OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B, num_warp_n>;
OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B, num_warp_n, ldb>;
using SmemLayoutA = typename OperandATraits::Layout;
using SmemLayoutB = typename OperandBTraits::Layout;
......@@ -262,12 +270,44 @@ public:
return layout;
}
template <int offset, int NN, int KK, bool trans, int lddim, typename Engine0,
typename Layout0>
static CUTE_DEVICE auto get_region_tensor(Tensor<Engine0, Layout0> &sa) {
if constexpr (offset == 0) {
return composition(
sa,
Layout<Shape<Int<NN>, Int<KK>>,
Stride<_1, typename std::conditional<trans, Int<NN>,
Int<lddim>>::type>>{});
} else {
if constexpr (trans) {
static_assert(offset % KK == 0, "Offset must be a multiple of K");
constexpr int offset_n = offset / KK;
return flat_divide(sa, Shape<Int<NN>, Int<KK>>{})(_, _, _0{},
Int<offset_n>{});
} else {
static_assert(offset % NN == 0, "Offset must be a multiple of N");
constexpr int offset_n = offset / NN;
return flat_divide(sa, Shape<Int<NN>, Int<KK>>{})(_, _, Int<offset_n>{},
_0{});
}
}
}
static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
const int tid = threadIdx.x;
Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
SmemLayoutB{});
Tensor sA_all = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
SmemLayoutA{});
Tensor sB_all = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
SmemLayoutB{});
// Tensor sA = composition(sA_all, Layout<Shape<Int<M>, Int<K>>,
// Stride<_1, typename std::conditional<trans_A, Int<lda>,
// Int<M>>::type>>{});
// Tensor sB = composition(sB_all, Layout<Shape<Int<N>, Int<K>>,
// Stride<_1, typename std::conditional<trans_B, Int<N>,
// Int<ldb>>::type>>{});
Tensor sA = get_region_tensor<offset_a, M, K, !trans_A, lda>(sA_all);
Tensor sB = get_region_tensor<offset_b, N, K, trans_B, ldb>(sB_all);
TileMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tid);
auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
......@@ -306,8 +346,11 @@ public:
static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
C_type_raw *pC) {
const int tid = threadIdx.x;
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
SmemLayoutB{});
Tensor sB_all = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
SmemLayoutB{});
// Tensor sB = flat_divide(sB_all, Shape<Int<N>, Int<K>>{})(_, _, _0{},
// _0{});
Tensor sB = get_region_tensor<offset_b, N, K, trans_B, ldb>(sB_all);
TileMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tid);
auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
......@@ -342,8 +385,11 @@ public:
static CUTE_DEVICE void body_sr(A_type_raw *pA, B_type_raw *pB,
C_type_raw *pC) {
const int tid = threadIdx.x;
Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
SmemLayoutA{});
Tensor sA_all = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
SmemLayoutA{});
// Tensor sA = flat_divide(sA_all, Shape<Int<M>, Int<K>>{})(_, _, _0{},
// _0{});
Tensor sA = get_region_tensor<offset_a, M, K, !trans_A, lda>(sA_all);
TileMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tid);
auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
......@@ -380,29 +426,32 @@ public:
namespace tl {
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, typename A_type, typename B_type,
typename C_type>
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, A_type, B_type, C_type>;
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body(pA, pB, accum);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, typename A_type, typename B_type,
typename C_type>
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, A_type, B_type, C_type>;
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body_rs(pA, pB, accum);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, typename A_type, typename B_type,
typename C_type>
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, A_type, B_type, C_type>;
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body_sr(pA, pB, accum);
}
......
......@@ -194,16 +194,16 @@ struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
};
#endif
template <int Bits, int N, int K, bool K_inner, int num_warp_n,
template <int Bits, int N, int K, bool K_inner, int num_warp_n, int leading_dim,
typename Enable = void>
struct OperandTraits {
// Primary template, use padded layout and default copy
static constexpr int stride = K_inner ? K : N;
static constexpr int stride = leading_dim;
static constexpr int padded =
stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
using Layout = typename std::conditional<
K_inner, Layout<Shape<Int<N>, Int<K>>, Shape<Int<padded>, _1>>,
Layout<Shape<Int<N>, Int<K>>, Shape<_1, Int<padded>>>>::type;
K_inner, Layout<Shape<Int<N>, Int<leading_dim>>, Shape<Int<padded>, _1>>,
Layout<Shape<Int<leading_dim>, Int<K>>, Shape<_1, Int<padded>>>>::type;
using Copy = DefaultCopy;
};
......@@ -224,124 +224,132 @@ template <int N, int num_warp_n, bool transpose> struct SelectCopy {
DefaultCopy>;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
typename std::enable_if<K % 64 == 32>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<16, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 64 == 32>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
typename std::enable_if<K % 64 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<16, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 64 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
typename std::enable_if<N % 64 == 32>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<16, N, K, false, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 64 == 32>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Layout = decltype(tile_to_shape(
LayoutAtom{}, Shape<Int<leading_dim>, Int<K>>{}, Step<_2, _1>{}));
using Copy = typename SelectCopy<N, num_warp_n, false>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
typename std::enable_if<N % 64 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<16, N, K, false, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 64 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Layout = decltype(tile_to_shape(
LayoutAtom{}, Shape<Int<leading_dim>, Int<K>>{}, Step<_2, _1>{}));
using Copy = typename SelectCopy<N, num_warp_n, false>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
typename std::enable_if<K % 32 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<32, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 32 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
typename std::enable_if<K % 32 == 16>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<32, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 32 == 16>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
typename std::enable_if<N % 32 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<32, N, K, false, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 32 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Layout = decltype(tile_to_shape(
LayoutAtom{}, Shape<Int<leading_dim>, Int<K>>{}, Step<_2, _1>{}));
using Copy = UniversalCopy<tfloat32_t>;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
typename std::enable_if<N % 32 == 16>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<32, N, K, false, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 32 == 16>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Layout = decltype(tile_to_shape(
LayoutAtom{}, Shape<Int<leading_dim>, Int<K>>{}, Step<_2, _1>{}));
using Copy = UniversalCopy<tfloat32_t>;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
typename std::enable_if<K % 128 == 64>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<8, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 128 == 64>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
typename std::enable_if<K % 128 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<8, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 128 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, true, num_warp_n,
typename std::enable_if<K % 16 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<64, N, K, true, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 16 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Layout =
decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
using Copy = DefaultCopy;
};
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, false, num_warp_n,
typename std::enable_if<N % 16 == 0>::type> {
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<64, N, K, false, num_warp_n, leading_dim,
typename std::enable_if<leading_dim % 16 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Layout = decltype(tile_to_shape(
LayoutAtom{}, Shape<Int<leading_dim>, Int<K>>{}, Step<_2, _1>{}));
using Copy = DefaultCopy;
};
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, typename A_type_raw,
typename B_type_raw, typename C_type_raw>
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type_raw, typename B_type_raw,
typename C_type_raw>
class GemmTensorOp {
public:
using A_type =
......@@ -355,10 +363,11 @@ public:
using Instruction =
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
using OperandATraits =
OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A, num_warp_m>;
using OperandATraits = OperandTraits<sizeof_bits<A_type>::value, M, K,
!trans_A, num_warp_m, lda>;
using OperandBTraits =
OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B, num_warp_n>;
OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B, num_warp_n, ldb>;
using SmemLayoutA = typename OperandATraits::Layout;
using SmemLayoutB = typename OperandBTraits::Layout;
using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
......@@ -383,12 +392,38 @@ public:
return layout;
}
template <int offset, int NN, int KK, bool trans, int lddim, typename Engine0,
typename Layout0>
static CUTE_DEVICE auto get_region_tensor(Tensor<Engine0, Layout0> &sa) {
if constexpr (offset == 0) {
return composition(
sa,
Layout<Shape<Int<NN>, Int<KK>>,
Stride<_1, typename std::conditional<trans, Int<NN>,
Int<lddim>>::type>>{});
} else {
if constexpr (trans) {
static_assert(offset % KK == 0, "Offset must be a multiple of K");
constexpr int offset_n = offset / KK;
return flat_divide(sa, Shape<Int<NN>, Int<KK>>{})(_, _, _0{},
Int<offset_n>{});
} else {
static_assert(offset % NN == 0, "Offset must be a multiple of N");
constexpr int offset_n = offset / NN;
return flat_divide(sa, Shape<Int<NN>, Int<KK>>{})(_, _, Int<offset_n>{},
_0{});
}
}
}
static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
const int tid = threadIdx.x;
Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
SmemLayoutB{});
Tensor sA_all = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
SmemLayoutA{});
Tensor sB_all = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
SmemLayoutB{});
Tensor sA = get_region_tensor<offset_a, M, K, !trans_A, lda>(sA_all);
Tensor sB = get_region_tensor<offset_b, N, K, trans_B, ldb>(sB_all);
TileMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tid);
auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
......@@ -426,8 +461,9 @@ public:
static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
C_type_raw *pC) {
const int tid = threadIdx.x;
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
SmemLayoutB{});
Tensor sB_all = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
SmemLayoutB{});
Tensor sB = get_region_tensor<offset_b, N, K, trans_B, ldb>(sB_all);
TileMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tid);
auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
......@@ -461,8 +497,9 @@ public:
static CUTE_DEVICE void body_sr(A_type_raw *pA, B_type_raw *pB,
C_type_raw *pC) {
const int tid = threadIdx.x;
Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
SmemLayoutA{});
Tensor sA_all = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
SmemLayoutA{});
Tensor sA = get_region_tensor<offset_a, M, K, !trans_A, lda>(sA_all);
TileMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tid);
auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
......@@ -503,67 +540,86 @@ namespace tl {
namespace tl_mma {
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, typename A_type, typename B_type,
typename C_type>
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, A_type, B_type, C_type>;
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body(pA, pB, accum);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, typename A_type, typename B_type,
typename C_type>
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, A_type, B_type, C_type>;
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body_rs(pA, pB, accum);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, typename A_type, typename B_type,
typename C_type>
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, A_type, B_type, C_type>;
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body_sr(pA, pB, accum);
}
} // namespace tl_mma
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum = false, bool use_wgmma = true,
bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
int wg_wait = 0, typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
if constexpr (use_wgmma) {
static_assert((trans_A && lda == M) || (!trans_A && lda == K),
"Hopper wgmma doesn't support custom stride for A");
static_assert((trans_B && ldb == K) || (!trans_B && ldb == N),
"Hopper wgmma doesn't support custom stride for B");
static_assert(offset_a == 0 && offset_b == 0,
"offset_a and offset_b must be zero for wgmma");
using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
trans_A, trans_B, clear_accum,
A_type, B_type, C_type>;
MMA::body<wg_wait>(pA, pB, accum);
} else {
using MMA = cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
trans_A, trans_B, clear_accum,
A_type, B_type, C_type>;
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body(pA, pB, accum);
}
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum = false, bool use_wgmma = true,
bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
int wg_wait = 0, typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
if constexpr (use_wgmma) {
static_assert((trans_A && lda == M) || (!trans_A && lda == K),
"Hopper wgmma doesn't support custom stride for A");
static_assert((trans_B && ldb == K) || (!trans_B && ldb == N),
"Hopper wgmma doesn't support custom stride for B");
static_assert(offset_a == 0 && offset_b == 0,
"offset_a and offset_b must be zero for wgmma");
using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
trans_A, trans_B, clear_accum,
A_type, B_type, C_type>;
MMA::body_rs<wg_wait>(pA, pB, accum);
} else {
using MMA = cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
trans_A, trans_B, clear_accum,
A_type, B_type, C_type>;
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body_rs(pA, pB, accum);
}
}
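On the Hopper wgmma path the new parameters are deliberately rejected: the static_asserts above require the leading dimensions to equal the natural tile extents and both offsets to be zero, so strided or offset operands must take the tl_mma path (use_wgmma == false) or compilation fails with these messages. The constraint restated as a small Python check, for illustration only (the function name is made up):

    def wgmma_operands_ok(trans_A, trans_B, M, N, K, lda, ldb, offset_a, offset_b):
        a_ok = (lda == M) if trans_A else (lda == K)
        b_ok = (ldb == K) if trans_B else (ldb == N)
        return a_ok and b_ok and offset_a == 0 and offset_b == 0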
......
import tilelang.testing
import tilelang
import tilelang.language as T
import torch

def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K * 2), dtype, scope="shared")
            B_shared = T.alloc_shared((block_K, block_N * 2), dtype, scope="shared")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Clear local accumulation
            T.clear(C_local)
            T.clear(B_shared)
            T.clear(A_shared)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                # Copy tile of A
                # T.copy(A[by * block_M, ko * block_K], A_shared)
                for i, k in T.Parallel(block_M, block_K):
                    A_shared[i, k + block_K] = A[by * block_M + i, ko * block_K + k]

                # Copy tile of B
                # T.copy(B[ko * block_K, bx * block_N], B_shared)
                for i, k in T.Parallel(block_K, block_N):
                    B_shared[i, k] = B[ko * block_K + i, bx * block_N + k]

                # Perform a tile-level GEMM on the shared buffers
                # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs
                T.gemm(A_shared[:, block_K:], B_shared[0:block_K, 0:block_N], C_local)

            # Copy result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm_with_stride_ss(M: int, N: int, K: int, block_M: int, block_N: int, block_K: int):
    # 1. Define the kernel (matmul) and compile/lower it into an executable module
    func = matmul(M, N, K, block_M, block_N, block_K)

    # 2. Compile the kernel into a torch function
    # out_idx specifies the index of the output buffer in the argument list
    # if out_idx is specified, the tensor will be created during runtime
    # target currently can be "cuda" or "hip" or "cpu".
    jit_kernel = tilelang.compile(
        func,
        out_idx=[2],
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        })

    # Create random input tensors on the GPU
    a = torch.randn(M, K, device="cuda", dtype=torch.float16)
    b = torch.randn(K, N, device="cuda", dtype=torch.float16)

    # Run the kernel through the Profiler
    c = jit_kernel(a, b)
    print(c)

    # Reference multiplication using PyTorch
    ref_c = a @ b

    # Validate correctness
    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("Kernel output matches PyTorch reference.")


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(7, 5)
def test_tilelang_kernel_gemm_with_stride():
    run_gemm_with_stride_ss(128, 128, 64, 32, 32, 32)


if __name__ == "__main__":
    tilelang.testing.main()
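For the configuration this test runs (block_M = block_N = block_K = 32), the frontend change below derives the following values, which is a quick way to sanity-check the stride/offset plumbing:

    # A_shared shape (32, 64)  -> strides [64, 1]     -> stride_a = 64
    # A_shared[:, block_K:]    -> region mins [0, 32] -> offset_a = 32
    # B_shared shape (32, 64)  -> strides [64, 1]     -> stride_b = 64
    # B_shared[0:32, 0:32]     -> region mins [0, 0]  -> offset_b = 0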
......@@ -69,10 +69,32 @@ def gemm(
        else:
            raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")

    def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
        if isinstance(object, tir.Buffer):
            strides = []
            stride = 1
            for s in reversed(object.shape):
                strides.insert(0, stride)
                stride *= s
            return strides
        elif isinstance(object, tir.BufferRegion):
            buffer, _ = object.buffer, object.region
            strides = []
            stride = 1
            for s in reversed(buffer.shape):
                strides.insert(0, stride)
                stride *= s
            return strides
        else:
            raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
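A quick standalone illustration of the stride computation above (row_major_strides is a name made up here to mirror retrieve_stride); only the second-to-last entry, the leading dimension of the 2-D tile, is used further below:

    def row_major_strides(shape):
        # same computation as retrieve_stride: the rightmost dimension is contiguous
        strides, stride = [], 1
        for s in reversed(shape):
            strides.insert(0, stride)
            stride *= s
        return strides

    assert row_major_strides((128, 64)) == [64, 1]
    assert row_major_strides((4, 128, 64)) == [8192, 64, 1]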
    A_shape = retrieve_shape(A)
    B_shape = retrieve_shape(B)
    C_shape = retrieve_shape(C)

    A_stride = retrieve_stride(A)
    B_stride = retrieve_stride(B)

    assert len(C_shape) == 2, "current only support C as a 2D tensor"
    assert len(A_shape) >= 2, "current only support A as a 2D or higher-order tensor"
    assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor"
......@@ -90,6 +112,9 @@ def gemm(
    K_B = B_shape[-1] if transpose_B else B_shape[-2]
    assert K == K_B, f"T.gemm K shape check failed: K_A = {K}, K_B = {K_B}"

    stride_a = A_stride[-2]
    stride_b = B_stride[-2]

    def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion],
                     access_type: str = "r") -> tir.PrimExpr:
        if isinstance(object, tir.Buffer):
......@@ -105,12 +130,33 @@ def gemm(
                strides.insert(0, stride)
                stride *= s
            offset = 0
            for i in range(len(indices)):
            # do not offset the last two dimensions
            for i in range(len(indices) - 2):
                offset += indices[i] * strides[i]
            return buffer.access_ptr(access_mask=access_type, offset=offset)
        else:
            raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
    def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr:
        """Retrieve the offset of the buffer or buffer region."""
        if isinstance(object, tir.Buffer):
            return [0] * len(object.shape)
        elif isinstance(object, tir.BufferRegion):
            _, region = object.buffer, object.region
            indices = []
            for r in region:
                indices.append(r.min)
            return indices
        else:
            raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")

    A_offset = retrieve_offset(A)
    B_offset = retrieve_offset(B)
    assert A_offset[-2] == 0, "The offset of the first dimension of A must be 0"
    assert B_offset[-2] == 0, "The offset of the first dimension of B must be 0"
    offset_a = A_offset[-1]
    offset_b = B_offset[-1]
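In other words, a plain tir.Buffer contributes all-zero offsets, while a BufferRegion contributes the minimums of its slice; only the last-dimension offset is forwarded, and the second-to-last one must be zero. For the test above, for example:

    # A passed as A_shared[:, block_K:]          -> region mins [0, block_K] -> offset_a = block_K
    # B passed as B_shared[0:block_K, 0:block_N] -> region mins [0, 0]       -> offset_b = 0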
    Aptr = retrieve_ptr(A, "r")
    Bptr = retrieve_ptr(B, "r")
    Cptr = retrieve_ptr(C, "rw")
......@@ -127,6 +173,10 @@ def gemm(
        K,
        policy,
        clear_accum,
        stride_a,
        stride_b,
        offset_a,
        offset_b,
        k_pack,
        wg_wait,
    )