"...hubert/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "9877f54491a7081266207e1a999dd47bc2bba17e"
Commit d1c15bc5 authored by Lei Wang, committed by LeiWang1999
Browse files

[Enhancement] Support cute mma tile mxn8ky (#434)

* [Enhancement] Improve error handling in layout inference and update profiler type in tests

* Added a detailed error message in the layout inference for local.fragment to clarify the requirement for trans_B.
* Updated the profiler type in the cumulative sum test from TensorSupplyType.One to TensorDistributionType.Randn for better profiling accuracy.

* lint fix

* [Refactor] Update OperandTraits to include num_warp_n parameter

* Modified OperandTraits templates across gemm_sm80.h, gemm_sm89.h, and gemm_sm90.h to include an additional num_warp_n parameter for improved flexibility in layout and copy operations.
* Adjusted Copy type selection based on the new parameter to enhance performance and adaptability in various scenarios.

* lint fix

* [Refactor] Update DispatchInstruction templates to include N parameter

* Modified DispatchInstruction templates in gemm_sm80.h, gemm_sm89.h, and gemm_sm90.h to include an additional N parameter, enhancing flexibility in tile size calculations.
* Adjusted MMA_Group definitions to use std::min for improved handling of warp sizes, ensuring better performance and adaptability in various scenarios.
parent 3d206235
......@@ -213,7 +213,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
B->dtype.bits(), trans_B ? 2 : 1));
} else if (B.scope() == "local.fragment") {
ICHECK(trans_B == false);
ICHECK(trans_B == false) << "B is local.fragment, trans_B must be false, "
"please raise an issue if you see this";
results.Set(B, makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n));
} else {
ICHECK(0);
......
......@@ -10,53 +10,54 @@
namespace cute {
template <typename A_type, typename B_type, typename C_type, int num_warp_m,
int num_warp_n>
int num_warp_n, int N>
struct DispatchInstruction;
using _X = Underscore;
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n> {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n>
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
num_warp_n> {
num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n>
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
num_warp_n> {
num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n> {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n> {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
};
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _16>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _16>;
};
#endif
template <int Bits, int N, int K, bool K_inner, typename Enable = void>
template <int Bits, int N, int K, bool K_inner, int num_warp_n,
typename Enable = void>
struct OperandTraits {
// Primary template, use padded layout and default copy
static constexpr int stride = K_inner ? K : N;
......@@ -68,26 +69,28 @@ struct OperandTraits {
using Copy = DefaultCopy;
};
template <int N, int K>
struct OperandTraits<16, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
typename std::enable_if<K % 64 == 32>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
};
template <int N, int K>
struct OperandTraits<16, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
typename std::enable_if<K % 64 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
};
template <int N, int K>
struct OperandTraits<16, N, K, false,
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
typename std::enable_if<N % 64 == 32>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
......@@ -96,8 +99,8 @@ struct OperandTraits<16, N, K, false,
using Copy = SM75_U16x8_LDSM_T;
};
template <int N, int K>
struct OperandTraits<16, N, K, false,
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
typename std::enable_if<N % 64 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
......@@ -106,26 +109,28 @@ struct OperandTraits<16, N, K, false,
using Copy = SM75_U16x8_LDSM_T;
};
template <int N, int K>
struct OperandTraits<32, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
typename std::enable_if<K % 32 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
};
template <int N, int K>
struct OperandTraits<32, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
typename std::enable_if<K % 32 == 16>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
};
template <int N, int K>
struct OperandTraits<32, N, K, false,
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
typename std::enable_if<N % 32 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
......@@ -134,8 +139,8 @@ struct OperandTraits<32, N, K, false,
using Copy = UniversalCopy<tfloat32_t>;
};
template <int N, int K>
struct OperandTraits<32, N, K, false,
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
typename std::enable_if<N % 32 == 16>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
......@@ -144,8 +149,8 @@ struct OperandTraits<32, N, K, false,
using Copy = UniversalCopy<tfloat32_t>;
};
template <int N, int K>
struct OperandTraits<8, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
typename std::enable_if<K % 128 == 64>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
......@@ -153,17 +158,18 @@ struct OperandTraits<8, N, K, true,
using Copy = SM75_U32x4_LDSM_N;
};
template <int N, int K>
struct OperandTraits<8, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
typename std::enable_if<K % 128 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
};
template <int N, int K>
struct OperandTraits<64, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, true, num_warp_n,
typename std::enable_if<K % 16 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
......@@ -171,8 +177,8 @@ struct OperandTraits<64, N, K, true,
using Copy = DefaultCopy;
};
template <int N, int K>
struct OperandTraits<64, N, K, false,
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, false, num_warp_n,
typename std::enable_if<N % 16 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
......@@ -194,12 +200,13 @@ public:
tfloat32_t, A_type_raw>::type;
using C_type = C_type_raw;
using Instruction =
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n>;
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
using OperandATraits =
OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A>;
OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A, num_warp_m>;
using OperandBTraits =
OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B>;
OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B, num_warp_n>;
using SmemLayoutA = typename OperandATraits::Layout;
using SmemLayoutB = typename OperandBTraits::Layout;
using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
......
......@@ -11,7 +11,7 @@
namespace cute {
template <typename A_type, typename B_type, typename C_type, int num_warp_m,
int num_warp_n>
int num_warp_n, int N>
struct DispatchInstruction;
using _X = Underscore;
......@@ -106,58 +106,61 @@ template <> struct MMA_Traits<SM89_16x8x32_F32F8F8F32_E5M2_TN> {
using CLayout = SM80_16x8_Row;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n> {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n,
N> {
using MMA = MMA_Atom<SM89_16x8x32_F32F8F8F32_E4M3_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, num_warp_m, num_warp_n> {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, num_warp_m, num_warp_n,
N> {
using MMA = MMA_Atom<SM89_16x8x32_F32F8F8F32_E5M2_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n> {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n>
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
num_warp_n> {
num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n>
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
num_warp_n> {
num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n> {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n> {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
};
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _16>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _16>;
};
#endif
template <int Bits, int N, int K, bool K_inner, typename Enable = void>
template <int Bits, int N, int K, bool K_inner, int num_warp_n,
typename Enable = void>
struct OperandTraits {
// Primary template, use padded layout and default copy
static constexpr int stride = K_inner ? K : N;
......@@ -169,26 +172,28 @@ struct OperandTraits {
using Copy = DefaultCopy;
};
template <int N, int K>
struct OperandTraits<16, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
typename std::enable_if<K % 64 == 32>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
};
template <int N, int K>
struct OperandTraits<16, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
typename std::enable_if<K % 64 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
};
template <int N, int K>
struct OperandTraits<16, N, K, false,
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
typename std::enable_if<N % 64 == 32>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
......@@ -197,8 +202,8 @@ struct OperandTraits<16, N, K, false,
using Copy = SM75_U16x8_LDSM_T;
};
template <int N, int K>
struct OperandTraits<16, N, K, false,
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
typename std::enable_if<N % 64 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
......@@ -207,26 +212,28 @@ struct OperandTraits<16, N, K, false,
using Copy = SM75_U16x8_LDSM_T;
};
template <int N, int K>
struct OperandTraits<32, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
typename std::enable_if<K % 32 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
};
template <int N, int K>
struct OperandTraits<32, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
typename std::enable_if<K % 32 == 16>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
};
template <int N, int K>
struct OperandTraits<32, N, K, false,
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
typename std::enable_if<N % 32 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
......@@ -235,8 +242,8 @@ struct OperandTraits<32, N, K, false,
using Copy = UniversalCopy<tfloat32_t>;
};
template <int N, int K>
struct OperandTraits<32, N, K, false,
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
typename std::enable_if<N % 32 == 16>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
......@@ -245,26 +252,28 @@ struct OperandTraits<32, N, K, false,
using Copy = UniversalCopy<tfloat32_t>;
};
template <int N, int K>
struct OperandTraits<8, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
typename std::enable_if<K % 128 == 64>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
};
template <int N, int K>
struct OperandTraits<8, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
typename std::enable_if<K % 128 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
};
template <int N, int K>
struct OperandTraits<64, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, true, num_warp_n,
typename std::enable_if<K % 16 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
......@@ -272,8 +281,8 @@ struct OperandTraits<64, N, K, true,
using Copy = DefaultCopy;
};
template <int N, int K>
struct OperandTraits<64, N, K, false,
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, false, num_warp_n,
typename std::enable_if<N % 16 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
......@@ -295,12 +304,13 @@ public:
tfloat32_t, A_type_raw>::type;
using C_type = C_type_raw;
using Instruction =
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n>;
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
using OperandATraits =
OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A>;
OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A, num_warp_m>;
using OperandBTraits =
OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B>;
OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B, num_warp_n>;
using SmemLayoutA = typename OperandATraits::Layout;
using SmemLayoutB = typename OperandBTraits::Layout;
using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
......
......@@ -167,53 +167,54 @@ public:
namespace tl_mma {
template <typename A_type, typename B_type, typename C_type, int num_warp_m,
int num_warp_n>
int num_warp_n, int N>
struct DispatchInstruction;
using _X = Underscore;
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n> {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n>
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
num_warp_n> {
num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n>
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
num_warp_n> {
num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n> {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n> {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
};
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _16>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _16>;
};
#endif
template <int Bits, int N, int K, bool K_inner, typename Enable = void>
template <int Bits, int N, int K, bool K_inner, int num_warp_n,
typename Enable = void>
struct OperandTraits {
// Primary template, use padded layout and default copy
static constexpr int stride = K_inner ? K : N;
......@@ -225,26 +226,28 @@ struct OperandTraits {
using Copy = DefaultCopy;
};
template <int N, int K>
struct OperandTraits<16, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
typename std::enable_if<K % 64 == 32>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
};
template <int N, int K>
struct OperandTraits<16, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
typename std::enable_if<K % 64 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
};
template <int N, int K>
struct OperandTraits<16, N, K, false,
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
typename std::enable_if<N % 64 == 32>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
......@@ -253,8 +256,8 @@ struct OperandTraits<16, N, K, false,
using Copy = SM75_U16x8_LDSM_T;
};
template <int N, int K>
struct OperandTraits<16, N, K, false,
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
typename std::enable_if<N % 64 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
......@@ -263,26 +266,28 @@ struct OperandTraits<16, N, K, false,
using Copy = SM75_U16x8_LDSM_T;
};
template <int N, int K>
struct OperandTraits<32, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
typename std::enable_if<K % 32 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
};
template <int N, int K>
struct OperandTraits<32, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
typename std::enable_if<K % 32 == 16>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
};
template <int N, int K>
struct OperandTraits<32, N, K, false,
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
typename std::enable_if<N % 32 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
......@@ -291,8 +296,8 @@ struct OperandTraits<32, N, K, false,
using Copy = UniversalCopy<tfloat32_t>;
};
template <int N, int K>
struct OperandTraits<32, N, K, false,
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
typename std::enable_if<N % 32 == 16>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
......@@ -301,26 +306,28 @@ struct OperandTraits<32, N, K, false,
using Copy = UniversalCopy<tfloat32_t>;
};
template <int N, int K>
struct OperandTraits<8, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
typename std::enable_if<K % 128 == 64>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
};
template <int N, int K>
struct OperandTraits<8, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
typename std::enable_if<K % 128 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
SM75_U32x4_LDSM_N>::type;
};
template <int N, int K>
struct OperandTraits<64, N, K, true,
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, true, num_warp_n,
typename std::enable_if<K % 16 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
......@@ -328,8 +335,8 @@ struct OperandTraits<64, N, K, true,
using Copy = DefaultCopy;
};
template <int N, int K>
struct OperandTraits<64, N, K, false,
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, false, num_warp_n,
typename std::enable_if<N % 16 == 0>::type> {
using LayoutAtom = decltype(composition(
Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
......@@ -351,12 +358,12 @@ public:
tfloat32_t, A_type_raw>::type;
using C_type = C_type_raw;
using Instruction =
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n>;
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
using OperandATraits =
OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A>;
OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A, num_warp_m>;
using OperandBTraits =
OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B>;
OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B, num_warp_n>;
using SmemLayoutA = typename OperandATraits::Layout;
using SmemLayoutB = typename OperandBTraits::Layout;
using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment