"...composable_kernel_rocm.git" did not exist on "e1cd41215586ef9fa80be6d3372deeb920e7fb65"
Unverified commit 78f637e4, authored by zjing14, committed by GitHub

Merge pull request #58 from ROCm/navi4x_conv_fwd

Navi4x Conv and MHA enablement
parents 7e147c64 5cb59d36
 list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
-list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
+list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1200)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
......
 list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
-list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102 gfx1103)
+list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
......
-if(GPU_TARGETS MATCHES "gfx11")
+if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
     add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp)
     add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
     add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp)
......
 list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
-list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
+list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1200)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
......
@@ -70,6 +70,9 @@ struct BlockwiseGemmWMMA
     static constexpr index_t A_KRow = 2;
     static constexpr index_t B_KRow = 2;
+
+    static constexpr index_t A_KRow_ = AEnableLds ? 1 : 2;
+    static constexpr index_t B_KRow_ = BEnableLds ? 1 : 2;
     static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
     static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
@@ -191,9 +194,6 @@ struct BlockwiseGemmWMMA
         static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
                           NPerBlock % (NPerWMMA * NRepeat) == 0,
                       "wrong!");
-
-        static_assert(AEnableLds == true, "only support EnableLds");
-        static_assert(BEnableLds == true, "only support EnableLds");
     }
     // transposed WMMA output C' = B' * A'
@@ -316,7 +316,7 @@ struct BlockwiseGemmWMMA
                 // read A
                 a_thread_copy_.Run(
                     a_block_desc_k0_m0_m1_m2_k1,
-                    make_tuple(Number<k * KPack / A_K1>{}, m0, I0, I0, I0, I0),
+                    make_tuple(Number<k * KPack / A_K1 / A_KRow_>{}, m0, I0, I0, I0, I0),
                     a_block_buf,
                     a_thread_desc_,
                     make_tuple(I0, m0, I0, I0, I0, I0),
@@ -326,7 +326,8 @@ struct BlockwiseGemmWMMA
                 // read B
                 b_thread_copy_.Run(
                     b_block_desc_k0_n0_n1_n2_k1,
-                    make_tuple(Number<k * KPack / B_K1>{}, n0, I0, I0, I0, I0),
+                    make_tuple(
+                        Number<k * KPack / B_K1 / B_KRow_>{}, n0, I0, I0, I0, I0),
                     b_block_buf,
                     b_thread_desc_,
                     make_tuple(I0, n0, I0, I0, I0, I0),
@@ -372,7 +373,7 @@ struct BlockwiseGemmWMMA
                 // read B
                 b_thread_copy_.Run(
                     b_block_desc_k0_n0_n1_n2_k1,
-                    make_tuple(Number<k * KPack / B_K1>{}, n0, I0, I0, I0, I0),
+                    make_tuple(Number<k * KPack / B_K1 / B_KRow_>{}, n0, I0, I0, I0, I0),
                     b_block_buf,
                     b_thread_desc_,
                     make_tuple(I0, n0, I0, I0, I0, I0),
@@ -380,7 +381,7 @@ struct BlockwiseGemmWMMA
                 // read A
                 a_thread_copy_.Run(
                     a_block_desc_k0_m0_m1_m2_k1,
-                    make_tuple(Number<k * KPack / A_K1>{}, m0, I0, I0, I0, I0),
+                    make_tuple(Number<k * KPack / A_K1 / A_KRow_>{}, m0, I0, I0, I0, I0),
                     a_block_buf,
                     a_thread_desc_,
                     make_tuple(I0, m0, I0, I0, I0, I0),
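The new / A_KRow_ and / B_KRow_ factors slow down how fast the K0 source coordinate advances per WMMA step when LDS is bypassed. A standalone sketch of that arithmetic, with assumed tile parameters (KPack = 16 and A_K1 = 8 are illustrative values, not taken from this diff):

// Standalone sketch of the offset arithmetic above. KPack and A_K1 are
// assumed values for illustration. With the LDS-bypass layout, K is
// additionally split across A_KRow_ = 2 rows, so the K0 coordinate
// advances half as fast per inner iteration k.
#include <cstdio>

int main()
{
    constexpr int KPack   = 16; // K elements consumed per WMMA step (assumed)
    constexpr int A_K1    = 8;  // innermost K vector length (assumed)
    constexpr int A_KRow_ = 2;  // extra K split when AEnableLds == false

    for(int k = 0; k < 4; ++k)
    {
        const int old_k0 = k * KPack / A_K1;           // offset before this change
        const int new_k0 = k * KPack / A_K1 / A_KRow_; // offset after this change
        std::printf("k=%d: old K0 = %d, new K0 = %d\n", k, old_k0, new_k0);
    }
    return 0;
}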
@@ -442,44 +443,30 @@ struct BlockwiseGemmWMMA
     static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
         make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
-    template <bool EnableLds>
-    struct AThreadCopySelector;
-
-    template <>
-    struct AThreadCopySelector<true>
-    {
-        using type =
-            ThreadwiseTensorSliceTransfer_v4<FloatA,
-                                             FloatA,
-                                             decltype(a_block_desc_k0_m0_m1_m2_k1),
-                                             decltype(a_thread_desc_),
-                                             Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
-                                             Sequence<0, 1, 2, 3, 4, 5>,
-                                             5,
-                                             A_K1,
-                                             A_K1>;
-    };
-
-    template <bool EnableLds>
-    struct BThreadCopySelector;
-
-    template <>
-    struct BThreadCopySelector<true>
-    {
-        using type =
-            ThreadwiseTensorSliceTransfer_v4<FloatB,
-                                             FloatB,
-                                             decltype(b_block_desc_k0_n0_n1_n2_k1),
-                                             decltype(b_thread_desc_),
-                                             Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
-                                             Sequence<0, 1, 2, 3, 4, 5>,
-                                             5,
-                                             B_K1,
-                                             B_K1>;
-    };
-
-    typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
-    typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
+    using AThreadCopyType =
+        ThreadwiseTensorSliceTransfer_v4<FloatA,
+                                         FloatA,
+                                         decltype(a_block_desc_k0_m0_m1_m2_k1),
+                                         decltype(a_thread_desc_),
+                                         Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
+                                         Sequence<0, 1, 2, 3, 4, 5>,
+                                         5,
+                                         A_K1,
+                                         A_K1>;
+
+    using BThreadCopyType =
+        ThreadwiseTensorSliceTransfer_v4<FloatB,
+                                         FloatB,
+                                         decltype(b_block_desc_k0_n0_n1_n2_k1),
+                                         decltype(b_thread_desc_),
+                                         Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
+                                         Sequence<0, 1, 2, 3, 4, 5>,
+                                         5,
+                                         B_K1,
+                                         B_K1>;
+
+    AThreadCopyType a_thread_copy_;
+    BThreadCopyType b_thread_copy_;
 };
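With both the LDS and LDS-bypass paths now using the same transfer type, the bool-specialized selector structs collapse into plain aliases. A toy illustration of the pattern (Transfer and CopySelector are made-up names, not CK code):

// Toy illustration of the refactor above: a selector struct specialized
// on a bool reduces to a plain alias once every specialization names the
// same type.
#include <type_traits>

template <typename T>
struct Transfer
{
    using value_type = T;
};

// Before: one specialization per EnableLds value.
template <bool EnableLds>
struct CopySelector;

template <>
struct CopySelector<true>
{
    using type = Transfer<float>;
};

// After: the selector is gone and the member type is named directly.
using CopyType = Transfer<float>;

static_assert(std::is_same_v<CopySelector<true>::type, CopyType>, "same type either way");

int main() { return 0; }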
 #else
 template <index_t BlockSize,
@@ -537,9 +524,8 @@ struct BlockwiseGemmWMMA
     // permutation
     static constexpr index_t A_KRow = AEnableLds ? 1 : 2;
     static constexpr index_t B_KRow = BEnableLds ? 1 : 2;
-
     static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
     static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
     static constexpr auto wmma_gemm =
         WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
......
@@ -829,7 +829,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(ck::is_navi3_supported())
+        if(ck::is_navi3_supported() || ck::is_navi4_supported())
        {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
             {
......
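ck::is_navi4_supported() is the new runtime gate for gfx12 (Navi4x) devices, used alongside the existing Navi3x check in every IsSupportedArgument below. A hedged sketch of how such a check can be written against the HIP runtime (this is an assumed implementation for illustration, not CK's actual one):

// Assumed illustration, not CK's implementation: a Navi3x/Navi4x runtime
// check via the HIP runtime, matching on the device's gcnArchName, which
// looks like "gfx1201:sramecc-:xnack-".
#include <hip/hip_runtime.h>
#include <cstring>
#include <cstdio>

static bool arch_starts_with(const char* prefix)
{
    int dev = 0;
    hipDeviceProp_t props{};
    if(hipGetDevice(&dev) != hipSuccess || hipGetDeviceProperties(&props, dev) != hipSuccess)
        return false;
    return std::strncmp(props.gcnArchName, prefix, std::strlen(prefix)) == 0;
}

int main()
{
    const bool navi3 = arch_starts_with("gfx11"); // Navi3x: gfx1100/1101/1102/1103
    const bool navi4 = arch_starts_with("gfx12"); // Navi4x: gfx1200/1201
    std::printf("navi3: %d, navi4: %d\n", navi3, navi4);
    return 0;
}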
@@ -56,7 +56,7 @@ __global__ void
     bool input_permute,
     bool output_permute)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     // clang-format off
     // ***************************************************
@@ -159,6 +159,7 @@ __global__ void
     ignore = O;
     ignore = G0;
     ignore = G1;
+    ignore = alpha;
     ignore = input_permute;
     ignore = output_permute;
 #endif // end of if (defined(__gfx11__))
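The guard above compiles the kernel body only for the host pass and for gfx11/gfx12 device passes; every other target gets an empty kernel whose parameters are consumed (CK uses its ignore helper) to silence unused-parameter warnings. A minimal sketch of the pattern; guarded_kernel is a made-up example, and the __gfx11__/__gfx12__ family macros are assumed to be provided for the matching device passes, as in CK's headers:

// Minimal sketch of the arch-guard pattern. guarded_kernel is a made-up
// example; __gfx11__/__gfx12__ are assumed to be defined during the
// matching device compilation passes.
#include <hip/hip_runtime.h>

__global__ void guarded_kernel(const float* in, float* out, int n, float alpha)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        out[i] = alpha * in[i];
#else
    // Unsupported architectures compile an empty body; consuming the
    // parameters avoids unused-parameter warnings.
    (void)in;
    (void)out;
    (void)n;
    (void)alpha;
#endif
}

int main() { return 0; } // launch omitted; compile target selection is the point here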
@@ -187,7 +188,7 @@ __global__ void
     index_t head_size,
     float alpha)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     // clang-format off
     // ***************************************************
@@ -858,7 +859,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
     static bool IsSupportedArgument(const RawArg& arg)
     {
-        if(ck::is_navi3_supported())
+        if(ck::is_navi3_supported() || ck::is_navi4_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
......
@@ -94,8 +94,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         (MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
     // If true, LDS is used unconditionally
-    static constexpr auto AEnableLds_manu = true;
-    static constexpr auto BEnableLds_manu = true;
+    static constexpr auto AEnableLds_manu = false;
+    static constexpr auto BEnableLds_manu = false;
     static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
     static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
......
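Flipping the manual overrides to false means LDS staging is no longer forced on; it now engages only when the auto heuristic or multi-stage prefetch requires it. A sketch of the resulting logic with assumed inputs:

// Sketch of the enable logic above with assumed inputs: with the manual
// override now false, LDS staging engages only if the layout heuristic
// (AEnableLds_auto) or multi-stage prefetch requires it.
#include <cstdio>

int main()
{
    constexpr bool AEnableLds_auto = false; // assumed heuristic result for this layout
    constexpr bool AEnableLds_manu = false; // was true before this change
    constexpr int  NumPrefetch     = 1;

    constexpr bool AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
    std::printf("AEnableLds = %d\n", AEnableLds); // now false: the direct, LDS-bypass path
    return 0;
}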
@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
         // check device
-        if(ck::is_navi3_supported())
+        if(ck::is_navi3_supported() || ck::is_navi4_supported())
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
             {
......
@@ -702,7 +702,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
         // check device
-        if(ck::is_navi3_supported())
+        if(ck::is_navi3_supported() || ck::is_navi4_supported())
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
             {
......
@@ -581,7 +581,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
         namespace ctc = tensor_layout::convolution;
         // check device
-        if(ck::is_navi3_supported())
+        if(ck::is_navi3_supported() || ck::is_navi4_supported())
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
             {
......
@@ -61,7 +61,7 @@ __global__ void
     bool input_permute,
     bool output_permute)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     // clang-format off
     // ***************************************************
@@ -166,6 +166,7 @@ __global__ void
     ignore = O;
     ignore = G0;
     ignore = G1;
+    ignore = alpha;
     ignore = input_permute;
     ignore = output_permute;
 #endif // end of if (defined(__gfx11__))
@@ -596,7 +597,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
     static bool IsSupportedArgument(const RawArg& arg)
     {
-        if(ck::is_navi3_supported())
+        if(ck::is_navi3_supported() || ck::is_navi4_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
......
@@ -60,7 +60,8 @@ __global__ void
     bool input_permute,
     bool output_permute)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx11__) || \
+    defined(__gfx12__))
     // clang-format off
     // ***************************************************
@@ -165,6 +166,7 @@ __global__ void
     ignore = O;
     ignore = G0;
     ignore = G1;
+    ignore = alpha;
     ignore = input_permute;
     ignore = output_permute;
 #endif // end of if (defined(__gfx11__))
@@ -594,7 +596,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
     static bool IsSupportedArgument(const RawArg& arg)
     {
-        if(ck::is_navi3_supported())
+        if(ck::is_navi3_supported() || ck::is_navi4_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
......
@@ -571,15 +571,12 @@ struct GridwiseGemmMultipleD_Wmma
     // *Caution Here repeat is shuffle repeat
     GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
     {
-        constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
-        constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);
-
         constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
             make_naive_tensor_descriptor_packed(
                 make_tuple(I1,
-                           Number<CShuffleMRepeatPerShuffle * MWave * MPerWmma>{},
+                           Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
                            I1,
-                           Number<CShuffleNRepeatPerShuffle * NWave * NPerWmma>{}));
+                           Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));
         return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
     }
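The removed local MWave/NWave recomputation duplicated the struct-level MWaves/NWaves constants; both reduce to the same arithmetic. A worked example with assumed tile sizes:

// Worked example of the constants above with assumed tile sizes
// (MPerBlock = 128, MRepeat = 2, MPerWmma = 16 are illustrative only).
constexpr int MPerBlock = 128;
constexpr int MRepeat   = 2;
constexpr int MPerWmma  = 16;

// The struct-level MWaves and the removed local MWave compute the same value:
constexpr int MWaves = MPerBlock / (MRepeat * MPerWmma); // = 4
static_assert(MWaves == 4, "128 / (2 * 16)");

int main() { return 0; }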
@@ -799,8 +796,9 @@ struct GridwiseGemmMultipleD_Wmma
         const auto M = e_grid_desc_m_n.GetLength(I0);
         const auto N = e_grid_desc_m_n.GetLength(I1);
         const auto MBlock = M / MPerBlock;
         const auto NBlock = N / NPerBlock;
+
         const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
             e_grid_desc_m_n,
             make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
......
@@ -522,12 +522,6 @@ struct GridwiseGemm_Wmma
             c_grid_desc_m_n);
     }
-
-    using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
-        remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-            CGridDesc_M_N{}))>;
-
-    using DefaultBlock2CTileMap =
-        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
     struct SharedMemTrait
     {
         // LDS allocation for A and B: be careful of alignment
@@ -559,6 +553,12 @@ struct GridwiseGemm_Wmma
                 b_block_space_size_aligned * sizeof(BDataType));
     };
+
+    using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
+        remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            CGridDesc_M_N{}))>;
+
+    using DefaultBlock2CTileMap =
+        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
     template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
     __device__ static void Run(const ADataType* __restrict__ p_a_grid,
                                const BDataType* __restrict__ p_b_grid,
......
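After the move, the two aliases sit below SharedMemTrait but still above Run, which uses DefaultBlock2CTileMap as a default template argument. That relative order matters: unlike member function bodies, default template arguments of member templates only see names already declared. A toy sketch of the constraint (Gridwise and Block2TileMap are made-up names, not CK code):

// Toy sketch of the ordering constraint preserved above: a default
// template argument of a member template is looked up at its point of
// declaration, so Block2TileMap must be declared before Run.
struct Gridwise
{
    struct SharedMemTrait
    {
    };

    using Block2TileMap = int; // must precede its use as a default argument below

    template <typename Map = Block2TileMap>
    static void Run()
    {
    }
};

int main()
{
    Gridwise::Run();
    return 0;
}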