"docs/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "785ff7d707ad3c7601a55d5c997ea98087111c4f"
Commit d1c15bc5 authored by Lei Wang, committed by LeiWang1999
Browse files

[Enhancement] Support cute mma tile mxn8ky (#434)

* [Enhancement] Improve error handling in layout inference and update profiler type in tests

* Added a detailed error message in the layout inference for local.fragment to clarify the requirement for trans_B.
* Updated the profiler type in the cumulative sum test from TensorSupplyType.One to TensorDistributionType.Randn for better profiling accuracy.

* lint fix

* [Refactor] Update OperandTraits to include num_warp_n parameter

* Modified OperandTraits templates across gemm_sm80.h, gemm_sm89.h, and gemm_sm90.h to include an additional num_warp_n parameter for improved flexibility in layout and copy operations.
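* Adjusted Copy type selection based on the new parameter: the updated K-inner specializations now select SM75_U32x2_LDSM_N when N == 8 * num_warp_n and keep SM75_U32x4_LDSM_N otherwise, to enhance performance and adaptability for small N tiles.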

* lint fix

* [Refactor] Update DispatchInstruction templates to include N parameter

* Modified DispatchInstruction templates in gemm_sm80.h, gemm_sm89.h, and gemm_sm90.h to include an additional N parameter, enhancing flexibility in tile size calculations.
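* Adjusted MMA_Group definitions to use std::min(num_warp_n * 16, N) for the N extent, so the tile is capped at N when N is smaller than num_warp_n * 16 (the double-precision specialization is left unchanged), improving handling of warp sizes and adaptability to small N tiles.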
parent 3d206235
...@@ -213,7 +213,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -213,7 +213,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
B->dtype.bits(), trans_B ? 2 : 1)); B->dtype.bits(), trans_B ? 2 : 1));
} else if (B.scope() == "local.fragment") { } else if (B.scope() == "local.fragment") {
ICHECK(trans_B == false); ICHECK(trans_B == false) << "B is local.fragment, trans_B must be false, "
"please raise an issue if you see this";
results.Set(B, makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n)); results.Set(B, makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n));
} else { } else {
ICHECK(0); ICHECK(0);
......
...@@ -10,53 +10,54 @@ ...@@ -10,53 +10,54 @@
namespace cute { namespace cute {
template <typename A_type, typename B_type, typename C_type, int num_warp_m, template <typename A_type, typename B_type, typename C_type, int num_warp_m,
int num_warp_n> int num_warp_n, int N>
struct DispatchInstruction; struct DispatchInstruction;
using _X = Underscore; using _X = Underscore;
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800)) #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n> { struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>; using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>; using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
}; };
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> { struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>; using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>; using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
}; };
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m, struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
num_warp_n> { num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>; using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>; using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
}; };
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m, struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
num_warp_n> { num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>; using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>; using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
}; };
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n> { struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>; using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>; using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
}; };
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n> { struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>; using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>; using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
}; };
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750)) #elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> { struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>; using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _16>; using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _16>;
}; };
#endif #endif
template <int Bits, int N, int K, bool K_inner, typename Enable = void> template <int Bits, int N, int K, bool K_inner, int num_warp_n,
typename Enable = void>
struct OperandTraits { struct OperandTraits {
// Primary template, use padded layout and default copy // Primary template, use padded layout and default copy
static constexpr int stride = K_inner ? K : N; static constexpr int stride = K_inner ? K : N;
...@@ -68,26 +69,28 @@ struct OperandTraits { ...@@ -68,26 +69,28 @@ struct OperandTraits {
using Copy = DefaultCopy; using Copy = DefaultCopy;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, struct OperandTraits<16, N, K, true, num_warp_n,
typename std::enable_if<K % 64 == 32>::type> { typename std::enable_if<K % 64 == 32>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{})); Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N; using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, struct OperandTraits<16, N, K, true, num_warp_n,
typename std::enable_if<K % 64 == 0>::type> { typename std::enable_if<K % 64 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{})); Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N; using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, struct OperandTraits<16, N, K, false, num_warp_n,
typename std::enable_if<N % 64 == 32>::type> { typename std::enable_if<N % 64 == 32>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{})); Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
...@@ -96,8 +99,8 @@ struct OperandTraits<16, N, K, false, ...@@ -96,8 +99,8 @@ struct OperandTraits<16, N, K, false,
using Copy = SM75_U16x8_LDSM_T; using Copy = SM75_U16x8_LDSM_T;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, struct OperandTraits<16, N, K, false, num_warp_n,
typename std::enable_if<N % 64 == 0>::type> { typename std::enable_if<N % 64 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{})); Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
...@@ -106,26 +109,28 @@ struct OperandTraits<16, N, K, false, ...@@ -106,26 +109,28 @@ struct OperandTraits<16, N, K, false,
using Copy = SM75_U16x8_LDSM_T; using Copy = SM75_U16x8_LDSM_T;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, struct OperandTraits<32, N, K, true, num_warp_n,
typename std::enable_if<K % 32 == 0>::type> { typename std::enable_if<K % 32 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{})); Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N; using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, struct OperandTraits<32, N, K, true, num_warp_n,
typename std::enable_if<K % 32 == 16>::type> { typename std::enable_if<K % 32 == 16>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{})); Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N; using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, struct OperandTraits<32, N, K, false, num_warp_n,
typename std::enable_if<N % 32 == 0>::type> { typename std::enable_if<N % 32 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{})); Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
...@@ -134,8 +139,8 @@ struct OperandTraits<32, N, K, false, ...@@ -134,8 +139,8 @@ struct OperandTraits<32, N, K, false,
using Copy = UniversalCopy<tfloat32_t>; using Copy = UniversalCopy<tfloat32_t>;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, struct OperandTraits<32, N, K, false, num_warp_n,
typename std::enable_if<N % 32 == 16>::type> { typename std::enable_if<N % 32 == 16>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{})); Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
...@@ -144,8 +149,8 @@ struct OperandTraits<32, N, K, false, ...@@ -144,8 +149,8 @@ struct OperandTraits<32, N, K, false,
using Copy = UniversalCopy<tfloat32_t>; using Copy = UniversalCopy<tfloat32_t>;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, struct OperandTraits<8, N, K, true, num_warp_n,
typename std::enable_if<K % 128 == 64>::type> { typename std::enable_if<K % 128 == 64>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{})); Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
...@@ -153,17 +158,18 @@ struct OperandTraits<8, N, K, true, ...@@ -153,17 +158,18 @@ struct OperandTraits<8, N, K, true,
using Copy = SM75_U32x4_LDSM_N; using Copy = SM75_U32x4_LDSM_N;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, struct OperandTraits<8, N, K, true, num_warp_n,
typename std::enable_if<K % 128 == 0>::type> { typename std::enable_if<K % 128 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{})); Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N; using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, true, struct OperandTraits<64, N, K, true, num_warp_n,
typename std::enable_if<K % 16 == 0>::type> { typename std::enable_if<K % 16 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{})); Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
...@@ -171,8 +177,8 @@ struct OperandTraits<64, N, K, true, ...@@ -171,8 +177,8 @@ struct OperandTraits<64, N, K, true,
using Copy = DefaultCopy; using Copy = DefaultCopy;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, false, struct OperandTraits<64, N, K, false, num_warp_n,
typename std::enable_if<N % 16 == 0>::type> { typename std::enable_if<N % 16 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{})); Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
...@@ -194,12 +200,13 @@ public: ...@@ -194,12 +200,13 @@ public:
tfloat32_t, A_type_raw>::type; tfloat32_t, A_type_raw>::type;
using C_type = C_type_raw; using C_type = C_type_raw;
using Instruction = using Instruction =
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n>; DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
using OperandATraits = using OperandATraits =
OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A>; OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A, num_warp_m>;
using OperandBTraits = using OperandBTraits =
OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B>; OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B, num_warp_n>;
using SmemLayoutA = typename OperandATraits::Layout; using SmemLayoutA = typename OperandATraits::Layout;
using SmemLayoutB = typename OperandBTraits::Layout; using SmemLayoutB = typename OperandBTraits::Layout;
using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>; using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
namespace cute { namespace cute {
template <typename A_type, typename B_type, typename C_type, int num_warp_m, template <typename A_type, typename B_type, typename C_type, int num_warp_m,
int num_warp_n> int num_warp_n, int N>
struct DispatchInstruction; struct DispatchInstruction;
using _X = Underscore; using _X = Underscore;
...@@ -106,58 +106,61 @@ template <> struct MMA_Traits<SM89_16x8x32_F32F8F8F32_E5M2_TN> { ...@@ -106,58 +106,61 @@ template <> struct MMA_Traits<SM89_16x8x32_F32F8F8F32_E5M2_TN> {
using CLayout = SM80_16x8_Row; using CLayout = SM80_16x8_Row;
}; };
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n> { struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n,
N> {
using MMA = MMA_Atom<SM89_16x8x32_F32F8F8F32_E4M3_TN>; using MMA = MMA_Atom<SM89_16x8x32_F32F8F8F32_E4M3_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>; using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
}; };
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, num_warp_m, num_warp_n> { struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, num_warp_m, num_warp_n,
N> {
using MMA = MMA_Atom<SM89_16x8x32_F32F8F8F32_E5M2_TN>; using MMA = MMA_Atom<SM89_16x8x32_F32F8F8F32_E5M2_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>; using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
}; };
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n> { struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>; using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>; using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
}; };
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> { struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>; using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>; using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
}; };
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m, struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
num_warp_n> { num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>; using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>; using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
}; };
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m, struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
num_warp_n> { num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>; using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>; using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
}; };
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n> { struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>; using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>; using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
}; };
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n> { struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>; using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>; using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
}; };
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750)) #elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> { struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>; using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _16>; using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _16>;
}; };
#endif #endif
template <int Bits, int N, int K, bool K_inner, typename Enable = void> template <int Bits, int N, int K, bool K_inner, int num_warp_n,
typename Enable = void>
struct OperandTraits { struct OperandTraits {
// Primary template, use padded layout and default copy // Primary template, use padded layout and default copy
static constexpr int stride = K_inner ? K : N; static constexpr int stride = K_inner ? K : N;
...@@ -169,26 +172,28 @@ struct OperandTraits { ...@@ -169,26 +172,28 @@ struct OperandTraits {
using Copy = DefaultCopy; using Copy = DefaultCopy;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, struct OperandTraits<16, N, K, true, num_warp_n,
typename std::enable_if<K % 64 == 32>::type> { typename std::enable_if<K % 64 == 32>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{})); Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N; using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, struct OperandTraits<16, N, K, true, num_warp_n,
typename std::enable_if<K % 64 == 0>::type> { typename std::enable_if<K % 64 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{})); Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N; using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, struct OperandTraits<16, N, K, false, num_warp_n,
typename std::enable_if<N % 64 == 32>::type> { typename std::enable_if<N % 64 == 32>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{})); Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
...@@ -197,8 +202,8 @@ struct OperandTraits<16, N, K, false, ...@@ -197,8 +202,8 @@ struct OperandTraits<16, N, K, false,
using Copy = SM75_U16x8_LDSM_T; using Copy = SM75_U16x8_LDSM_T;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, struct OperandTraits<16, N, K, false, num_warp_n,
typename std::enable_if<N % 64 == 0>::type> { typename std::enable_if<N % 64 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{})); Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
...@@ -207,26 +212,28 @@ struct OperandTraits<16, N, K, false, ...@@ -207,26 +212,28 @@ struct OperandTraits<16, N, K, false,
using Copy = SM75_U16x8_LDSM_T; using Copy = SM75_U16x8_LDSM_T;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, struct OperandTraits<32, N, K, true, num_warp_n,
typename std::enable_if<K % 32 == 0>::type> { typename std::enable_if<K % 32 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{})); Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N; using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, struct OperandTraits<32, N, K, true, num_warp_n,
typename std::enable_if<K % 32 == 16>::type> { typename std::enable_if<K % 32 == 16>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{})); Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N; using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, struct OperandTraits<32, N, K, false, num_warp_n,
typename std::enable_if<N % 32 == 0>::type> { typename std::enable_if<N % 32 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{})); Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
...@@ -235,8 +242,8 @@ struct OperandTraits<32, N, K, false, ...@@ -235,8 +242,8 @@ struct OperandTraits<32, N, K, false,
using Copy = UniversalCopy<tfloat32_t>; using Copy = UniversalCopy<tfloat32_t>;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, struct OperandTraits<32, N, K, false, num_warp_n,
typename std::enable_if<N % 32 == 16>::type> { typename std::enable_if<N % 32 == 16>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{})); Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
...@@ -245,26 +252,28 @@ struct OperandTraits<32, N, K, false, ...@@ -245,26 +252,28 @@ struct OperandTraits<32, N, K, false,
using Copy = UniversalCopy<tfloat32_t>; using Copy = UniversalCopy<tfloat32_t>;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, struct OperandTraits<8, N, K, true, num_warp_n,
typename std::enable_if<K % 128 == 64>::type> { typename std::enable_if<K % 128 == 64>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{})); Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N; using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, struct OperandTraits<8, N, K, true, num_warp_n,
typename std::enable_if<K % 128 == 0>::type> { typename std::enable_if<K % 128 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{})); Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N; using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, true, struct OperandTraits<64, N, K, true, num_warp_n,
typename std::enable_if<K % 16 == 0>::type> { typename std::enable_if<K % 16 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{})); Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
...@@ -272,8 +281,8 @@ struct OperandTraits<64, N, K, true, ...@@ -272,8 +281,8 @@ struct OperandTraits<64, N, K, true,
using Copy = DefaultCopy; using Copy = DefaultCopy;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, false, struct OperandTraits<64, N, K, false, num_warp_n,
typename std::enable_if<N % 16 == 0>::type> { typename std::enable_if<N % 16 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{})); Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
...@@ -295,12 +304,13 @@ public: ...@@ -295,12 +304,13 @@ public:
tfloat32_t, A_type_raw>::type; tfloat32_t, A_type_raw>::type;
using C_type = C_type_raw; using C_type = C_type_raw;
using Instruction = using Instruction =
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n>; DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
using OperandATraits = using OperandATraits =
OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A>; OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A, num_warp_m>;
using OperandBTraits = using OperandBTraits =
OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B>; OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B, num_warp_n>;
using SmemLayoutA = typename OperandATraits::Layout; using SmemLayoutA = typename OperandATraits::Layout;
using SmemLayoutB = typename OperandBTraits::Layout; using SmemLayoutB = typename OperandBTraits::Layout;
using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>; using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
......
...@@ -167,53 +167,54 @@ public: ...@@ -167,53 +167,54 @@ public:
namespace tl_mma { namespace tl_mma {
template <typename A_type, typename B_type, typename C_type, int num_warp_m, template <typename A_type, typename B_type, typename C_type, int num_warp_m,
int num_warp_n> int num_warp_n, int N>
struct DispatchInstruction; struct DispatchInstruction;
using _X = Underscore; using _X = Underscore;
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800)) #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n> { struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>; using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>; using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
}; };
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> { struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>; using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>; using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
}; };
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m, struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
num_warp_n> { num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>; using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>; using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
}; };
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m, struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
num_warp_n> { num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>; using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>; using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
}; };
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n> { struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>; using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>; using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
}; };
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n> { struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>; using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>; using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
}; };
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750)) #elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
template <int num_warp_m, int num_warp_n> template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> { struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>; using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _16>; using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _16>;
}; };
#endif #endif
template <int Bits, int N, int K, bool K_inner, typename Enable = void> template <int Bits, int N, int K, bool K_inner, int num_warp_n,
typename Enable = void>
struct OperandTraits { struct OperandTraits {
// Primary template, use padded layout and default copy // Primary template, use padded layout and default copy
static constexpr int stride = K_inner ? K : N; static constexpr int stride = K_inner ? K : N;
...@@ -225,26 +226,28 @@ struct OperandTraits { ...@@ -225,26 +226,28 @@ struct OperandTraits {
using Copy = DefaultCopy; using Copy = DefaultCopy;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, struct OperandTraits<16, N, K, true, num_warp_n,
typename std::enable_if<K % 64 == 32>::type> { typename std::enable_if<K % 64 == 32>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{})); Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N; using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, struct OperandTraits<16, N, K, true, num_warp_n,
typename std::enable_if<K % 64 == 0>::type> { typename std::enable_if<K % 64 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{})); Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N; using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, struct OperandTraits<16, N, K, false, num_warp_n,
typename std::enable_if<N % 64 == 32>::type> { typename std::enable_if<N % 64 == 32>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{})); Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
...@@ -253,8 +256,8 @@ struct OperandTraits<16, N, K, false, ...@@ -253,8 +256,8 @@ struct OperandTraits<16, N, K, false,
using Copy = SM75_U16x8_LDSM_T; using Copy = SM75_U16x8_LDSM_T;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, struct OperandTraits<16, N, K, false, num_warp_n,
typename std::enable_if<N % 64 == 0>::type> { typename std::enable_if<N % 64 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{})); Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
...@@ -263,26 +266,28 @@ struct OperandTraits<16, N, K, false, ...@@ -263,26 +266,28 @@ struct OperandTraits<16, N, K, false,
using Copy = SM75_U16x8_LDSM_T; using Copy = SM75_U16x8_LDSM_T;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, struct OperandTraits<32, N, K, true, num_warp_n,
typename std::enable_if<K % 32 == 0>::type> { typename std::enable_if<K % 32 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{})); Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N; using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, struct OperandTraits<32, N, K, true, num_warp_n,
typename std::enable_if<K % 32 == 16>::type> { typename std::enable_if<K % 32 == 16>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{})); Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N; using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, struct OperandTraits<32, N, K, false, num_warp_n,
typename std::enable_if<N % 32 == 0>::type> { typename std::enable_if<N % 32 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{})); Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
...@@ -291,8 +296,8 @@ struct OperandTraits<32, N, K, false, ...@@ -291,8 +296,8 @@ struct OperandTraits<32, N, K, false,
using Copy = UniversalCopy<tfloat32_t>; using Copy = UniversalCopy<tfloat32_t>;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, struct OperandTraits<32, N, K, false, num_warp_n,
typename std::enable_if<N % 32 == 16>::type> { typename std::enable_if<N % 32 == 16>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{})); Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
...@@ -301,26 +306,28 @@ struct OperandTraits<32, N, K, false, ...@@ -301,26 +306,28 @@ struct OperandTraits<32, N, K, false,
using Copy = UniversalCopy<tfloat32_t>; using Copy = UniversalCopy<tfloat32_t>;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, struct OperandTraits<8, N, K, true, num_warp_n,
typename std::enable_if<K % 128 == 64>::type> { typename std::enable_if<K % 128 == 64>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{})); Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N; using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, struct OperandTraits<8, N, K, true, num_warp_n,
typename std::enable_if<K % 128 == 0>::type> { typename std::enable_if<K % 128 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{})); Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N; using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, true, struct OperandTraits<64, N, K, true, num_warp_n,
typename std::enable_if<K % 16 == 0>::type> { typename std::enable_if<K % 16 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{})); Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
...@@ -328,8 +335,8 @@ struct OperandTraits<64, N, K, true, ...@@ -328,8 +335,8 @@ struct OperandTraits<64, N, K, true,
using Copy = DefaultCopy; using Copy = DefaultCopy;
}; };
template <int N, int K> template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, false, struct OperandTraits<64, N, K, false, num_warp_n,
typename std::enable_if<N % 16 == 0>::type> { typename std::enable_if<N % 16 == 0>::type> {
using LayoutAtom = decltype(composition( using LayoutAtom = decltype(composition(
Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{})); Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
...@@ -351,12 +358,12 @@ public: ...@@ -351,12 +358,12 @@ public:
tfloat32_t, A_type_raw>::type; tfloat32_t, A_type_raw>::type;
using C_type = C_type_raw; using C_type = C_type_raw;
using Instruction = using Instruction =
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n>; DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
using OperandATraits = using OperandATraits =
OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A>; OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A, num_warp_m>;
using OperandBTraits = using OperandBTraits =
OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B>; OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B, num_warp_n>;
using SmemLayoutA = typename OperandATraits::Layout; using SmemLayoutA = typename OperandATraits::Layout;
using SmemLayoutB = typename OperandBTraits::Layout; using SmemLayoutB = typename OperandBTraits::Layout;
using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>; using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment