"...composable_kernel_rocm.git" did not exist on "88833bd9ad99721fdc9f636e096710acf7e0b14f"
Commit f3e61c0a authored by danyao12's avatar danyao12
Browse files

datatype of bwd output can be selected

parent f7e05f9e
...@@ -28,7 +28,8 @@ namespace tensor_operation { ...@@ -28,7 +28,8 @@ namespace tensor_operation {
namespace device { namespace device {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename DataType, typename InputDataType,
typename OutputDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
...@@ -53,16 +54,16 @@ __global__ void ...@@ -53,16 +54,16 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif #endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1( kernel_batched_multihead_attention_backward_xdl_cshuffle_v1(
const DataType* __restrict__ p_a_grid, const InputDataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid, const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
const DataType* __restrict__ p_b1_grid, const InputDataType* __restrict__ p_b1_grid,
const DataType* __restrict__ p_c_grid, const InputDataType* __restrict__ p_c_grid,
const LSEDataType* __restrict__ p_lse_grid, const LSEDataType* __restrict__ p_lse_grid,
const DataType* __restrict__ p_ygrad_grid, const InputDataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid, OutputDataType* __restrict__ p_qgrad_grid,
DataType* __restrict__ p_kgrad_grid, OutputDataType* __restrict__ p_kgrad_grid,
DataType* __restrict__ p_vgrad_grid, OutputDataType* __restrict__ p_vgrad_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op, const AccElementwiseOperation acc_element_op,
...@@ -171,7 +172,8 @@ template <index_t NumDimG, ...@@ -171,7 +172,8 @@ template <index_t NumDimG,
index_t NumDimN, index_t NumDimN,
index_t NumDimK, index_t NumDimK,
index_t NumDimO, // NumDimGemm1N index_t NumDimO, // NumDimGemm1N
typename DataType, typename InputDataType,
typename OutputDataType,
typename GemmDataType, typename GemmDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
...@@ -597,7 +599,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -597,7 +599,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
DataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
GemmDataType, GemmDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
...@@ -666,16 +669,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -666,16 +669,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument( Argument(
const DataType* p_a_grid, const InputDataType* p_a_grid,
const DataType* p_b_grid, const InputDataType* p_b_grid,
ZDataType* p_z_grid, ZDataType* p_z_grid,
const DataType* p_b1_grid, const InputDataType* p_b1_grid,
const DataType* p_c_grid, // for dS const InputDataType* p_c_grid, // for dS
const LSEDataType* p_lse_grid, const LSEDataType* p_lse_grid,
const DataType* p_ygrad_grid, const InputDataType* p_ygrad_grid,
DataType* p_qgrad_grid, OutputDataType* p_qgrad_grid,
DataType* p_kgrad_grid, OutputDataType* p_kgrad_grid,
DataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
...@@ -820,16 +823,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -820,16 +823,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
} }
// pointers // pointers
const DataType* p_a_grid_; const InputDataType* p_a_grid_;
const DataType* p_b_grid_; const InputDataType* p_b_grid_;
ZDataType* p_z_grid_; ZDataType* p_z_grid_;
const DataType* p_b1_grid_; const InputDataType* p_b1_grid_;
const DataType* p_c_grid_; const InputDataType* p_c_grid_;
const LSEDataType* p_lse_grid_; const LSEDataType* p_lse_grid_;
const DataType* p_ygrad_grid_; const InputDataType* p_ygrad_grid_;
DataType* p_qgrad_grid_; OutputDataType* p_qgrad_grid_;
DataType* p_kgrad_grid_; OutputDataType* p_kgrad_grid_;
DataType* p_vgrad_grid_; OutputDataType* p_vgrad_grid_;
// tensor descriptor // tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...@@ -901,7 +904,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -901,7 +904,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1< const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
GridwiseGemm, GridwiseGemm,
DataType, InputDataType,
OutputDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
AElementwiseOperation, AElementwiseOperation,
...@@ -1067,16 +1071,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1067,16 +1071,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
} }
static auto MakeArgument( static auto MakeArgument(
const DataType* p_a, const InputDataType* p_a,
const DataType* p_b, const InputDataType* p_b,
ZDataType* p_z, ZDataType* p_z,
const DataType* p_b1, const InputDataType* p_b1,
const DataType* p_c, const InputDataType* p_c,
const LSEDataType* p_lse, const LSEDataType* p_lse,
const DataType* p_ygrad_grid, const InputDataType* p_ygrad_grid,
DataType* p_qgrad_grid, OutputDataType* p_qgrad_grid,
DataType* p_kgrad_grid, OutputDataType* p_kgrad_grid,
DataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
...@@ -1182,16 +1186,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1182,16 +1186,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
float p_drop, float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override std::tuple<unsigned long long, unsigned long long> seeds) // override
{ {
return std::make_unique<Argument>(static_cast<const DataType*>(p_a), return std::make_unique<Argument>(static_cast<const InputDataType*>(p_a),
static_cast<const DataType*>(p_b), static_cast<const InputDataType*>(p_b),
static_cast<ZDataType*>(p_z), static_cast<ZDataType*>(p_z),
static_cast<const DataType*>(p_b1), static_cast<const InputDataType*>(p_b1),
static_cast<const DataType*>(p_c), static_cast<const InputDataType*>(p_c),
static_cast<const LSEDataType*>(p_lse), static_cast<const LSEDataType*>(p_lse),
static_cast<const DataType*>(p_ygrad_grid), static_cast<const InputDataType*>(p_ygrad_grid),
static_cast<DataType*>(p_qgrad_grid), static_cast<OutputDataType*>(p_qgrad_grid),
static_cast<DataType*>(p_kgrad_grid), static_cast<OutputDataType*>(p_kgrad_grid),
static_cast<DataType*>(p_vgrad_grid), static_cast<OutputDataType*>(p_vgrad_grid),
p_acc0_biases, // cast in struct Argument p_acc0_biases, // cast in struct Argument
p_acc1_biases, // cast in struct Argument p_acc1_biases, // cast in struct Argument
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
......
...@@ -27,7 +27,8 @@ namespace tensor_operation { ...@@ -27,7 +27,8 @@ namespace tensor_operation {
namespace device { namespace device {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename DataType, typename InputDataType,
typename OutputDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
...@@ -52,16 +53,16 @@ __global__ void ...@@ -52,16 +53,16 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif #endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v2( kernel_batched_multihead_attention_backward_xdl_cshuffle_v2(
const DataType* __restrict__ p_a_grid, const InputDataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid, const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
const DataType* __restrict__ p_b1_grid, const InputDataType* __restrict__ p_b1_grid,
const DataType* __restrict__ p_c_grid, const InputDataType* __restrict__ p_c_grid,
const LSEDataType* __restrict__ p_lse_grid, const LSEDataType* __restrict__ p_lse_grid,
const DataType* __restrict__ p_ygrad_grid, const InputDataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid, OutputDataType* __restrict__ p_qgrad_grid,
DataType* __restrict__ p_kgrad_grid, OutputDataType* __restrict__ p_kgrad_grid,
DataType* __restrict__ p_vgrad_grid, OutputDataType* __restrict__ p_vgrad_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op, const AccElementwiseOperation acc_element_op,
...@@ -170,7 +171,8 @@ template <index_t NumDimG, ...@@ -170,7 +171,8 @@ template <index_t NumDimG,
index_t NumDimN, index_t NumDimN,
index_t NumDimK, index_t NumDimK,
index_t NumDimO, // NumDimGemm1N index_t NumDimO, // NumDimGemm1N
typename DataType, typename InputDataType,
typename OutputDataType,
typename GemmDataType, typename GemmDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
...@@ -596,7 +598,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -596,7 +598,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
DataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
GemmDataType, GemmDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
...@@ -665,16 +668,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -665,16 +668,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument( Argument(
const DataType* p_a_grid, const InputDataType* p_a_grid,
const DataType* p_b_grid, const InputDataType* p_b_grid,
ZDataType* p_z_grid, ZDataType* p_z_grid,
const DataType* p_b1_grid, const InputDataType* p_b1_grid,
const DataType* p_c_grid, // for dS const InputDataType* p_c_grid, // for dS
const LSEDataType* p_lse_grid, const LSEDataType* p_lse_grid,
const DataType* p_ygrad_grid, const InputDataType* p_ygrad_grid,
DataType* p_qgrad_grid, OutputDataType* p_qgrad_grid,
DataType* p_kgrad_grid, OutputDataType* p_kgrad_grid,
DataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
...@@ -818,16 +821,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -818,16 +821,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
} }
// pointers // pointers
const DataType* p_a_grid_; const InputDataType* p_a_grid_;
const DataType* p_b_grid_; const InputDataType* p_b_grid_;
ZDataType* p_z_grid_; ZDataType* p_z_grid_;
const DataType* p_b1_grid_; const InputDataType* p_b1_grid_;
const DataType* p_c_grid_; const InputDataType* p_c_grid_;
const LSEDataType* p_lse_grid_; const LSEDataType* p_lse_grid_;
const DataType* p_ygrad_grid_; const InputDataType* p_ygrad_grid_;
DataType* p_qgrad_grid_; OutputDataType* p_qgrad_grid_;
DataType* p_kgrad_grid_; OutputDataType* p_kgrad_grid_;
DataType* p_vgrad_grid_; OutputDataType* p_vgrad_grid_;
// tensor descriptor // tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...@@ -903,7 +906,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -903,7 +906,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2< const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2<
GridwiseGemm, GridwiseGemm,
DataType, InputDataType,
OutputDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
AElementwiseOperation, AElementwiseOperation,
...@@ -1067,16 +1071,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1067,16 +1071,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
} }
static auto MakeArgument( static auto MakeArgument(
const DataType* p_a, const InputDataType* p_a,
const DataType* p_b, const InputDataType* p_b,
ZDataType* p_z, ZDataType* p_z,
const DataType* p_b1, const InputDataType* p_b1,
const DataType* p_c, const InputDataType* p_c,
const LSEDataType* p_lse, const LSEDataType* p_lse,
const DataType* p_ygrad_grid, const InputDataType* p_ygrad_grid,
DataType* p_qgrad_grid, OutputDataType* p_qgrad_grid,
DataType* p_kgrad_grid, OutputDataType* p_kgrad_grid,
DataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
...@@ -1182,16 +1186,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1182,16 +1186,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
float p_drop, float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override std::tuple<unsigned long long, unsigned long long> seeds) // override
{ {
return std::make_unique<Argument>(static_cast<const DataType*>(p_a), return std::make_unique<Argument>(static_cast<const InputDataType*>(p_a),
static_cast<const DataType*>(p_b), static_cast<const InputDataType*>(p_b),
static_cast<ZDataType*>(p_z), static_cast<ZDataType*>(p_z),
static_cast<const DataType*>(p_b1), static_cast<const InputDataType*>(p_b1),
static_cast<const DataType*>(p_c), static_cast<const InputDataType*>(p_c),
static_cast<const LSEDataType*>(p_lse), static_cast<const LSEDataType*>(p_lse),
static_cast<const DataType*>(p_ygrad_grid), static_cast<const InputDataType*>(p_ygrad_grid),
static_cast<DataType*>(p_qgrad_grid), static_cast<OutputDataType*>(p_qgrad_grid),
static_cast<DataType*>(p_kgrad_grid), static_cast<OutputDataType*>(p_kgrad_grid),
static_cast<DataType*>(p_vgrad_grid), static_cast<OutputDataType*>(p_vgrad_grid),
p_acc0_biases, // cast in struct Argument p_acc0_biases, // cast in struct Argument
p_acc1_biases, // cast in struct Argument p_acc1_biases, // cast in struct Argument
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
......
...@@ -150,7 +150,8 @@ template <index_t NumDimG, ...@@ -150,7 +150,8 @@ template <index_t NumDimG,
index_t NumDimN, index_t NumDimN,
index_t NumDimK, index_t NumDimK,
index_t NumDimO, // NumDimGemm1N index_t NumDimO, // NumDimGemm1N
typename DataType, typename InputDataType,
typename OutputDataType,
typename GemmDataType, typename GemmDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
...@@ -534,7 +535,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -534,7 +535,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
DataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
GemmDataType, GemmDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
...@@ -604,16 +606,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -604,16 +606,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct GroupKernelArg struct GroupKernelArg
{ {
// pointers // pointers
const DataType* p_a_grid_; const InputDataType* p_a_grid_;
const DataType* p_b_grid_; const InputDataType* p_b_grid_;
ZDataType* p_z_grid_; ZDataType* p_z_grid_;
const DataType* p_b1_grid_; const InputDataType* p_b1_grid_;
const DataType* p_c_grid_; const InputDataType* p_c_grid_;
const LSEDataType* p_lse_grid_; const LSEDataType* p_lse_grid_;
const DataType* p_ygrad_grid_; const InputDataType* p_ygrad_grid_;
DataType* p_qgrad_grid_; OutputDataType* p_qgrad_grid_;
DataType* p_kgrad_grid_; OutputDataType* p_kgrad_grid_;
DataType* p_vgrad_grid_; OutputDataType* p_vgrad_grid_;
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...@@ -712,16 +714,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -712,16 +714,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
grid_size_ = 0; grid_size_ = 0;
for(index_t i = 0; i < group_count_; i++) for(index_t i = 0; i < group_count_; i++)
{ {
const auto p_a_grid = static_cast<const DataType*>(p_As[i]); const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
const auto p_b_grid = static_cast<const DataType*>(p_Bs[i]); const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]);
auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]); auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]);
const auto p_b1_grid = static_cast<const DataType*>(p_B1s[i]); const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]);
const auto p_c_grid = static_cast<const DataType*>(p_Cs[i]); const auto p_c_grid = static_cast<const InputDataType*>(p_Cs[i]);
const auto p_lse_grid = static_cast<const LSEDataType*>(p_LSEs[i]); const auto p_lse_grid = static_cast<const LSEDataType*>(p_LSEs[i]);
const auto p_ygrad_grid = static_cast<const DataType*>(p_Ygrads[i]); const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<DataType*>(p_Qgrads[i]); auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<DataType*>(p_Kgrads[i]); auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]);
auto p_vgrad_grid = static_cast<DataType*>(p_Vgrads[i]); auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
const auto& problem_desc = problem_desc_vec[i]; const auto& problem_desc = problem_desc_vec[i];
......
...@@ -150,7 +150,8 @@ template <index_t NumDimG, ...@@ -150,7 +150,8 @@ template <index_t NumDimG,
index_t NumDimN, index_t NumDimN,
index_t NumDimK, index_t NumDimK,
index_t NumDimO, // NumDimGemm1N index_t NumDimO, // NumDimGemm1N
typename DataType, typename InputDataType,
typename OutputDataType,
typename GemmDataType, typename GemmDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
...@@ -527,7 +528,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -527,7 +528,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
DataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
GemmDataType, GemmDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
...@@ -597,16 +599,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -597,16 +599,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
struct GroupKernelArg struct GroupKernelArg
{ {
// pointers // pointers
const DataType* p_a_grid_; const InputDataType* p_a_grid_;
const DataType* p_b_grid_; const InputDataType* p_b_grid_;
ZDataType* p_z_grid_; ZDataType* p_z_grid_;
const DataType* p_b1_grid_; const InputDataType* p_b1_grid_;
const DataType* p_c_grid_; const InputDataType* p_c_grid_;
const LSEDataType* p_lse_grid_; const LSEDataType* p_lse_grid_;
const DataType* p_ygrad_grid_; const InputDataType* p_ygrad_grid_;
DataType* p_qgrad_grid_; OutputDataType* p_qgrad_grid_;
DataType* p_kgrad_grid_; OutputDataType* p_kgrad_grid_;
DataType* p_vgrad_grid_; OutputDataType* p_vgrad_grid_;
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...@@ -705,16 +707,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -705,16 +707,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
grid_size_ = 0; grid_size_ = 0;
for(index_t i = 0; i < group_count_; i++) for(index_t i = 0; i < group_count_; i++)
{ {
const auto p_a_grid = static_cast<const DataType*>(p_As[i]); const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
const auto p_b_grid = static_cast<const DataType*>(p_Bs[i]); const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]);
auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]); auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]);
const auto p_b1_grid = static_cast<const DataType*>(p_B1s[i]); const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]);
const auto p_c_grid = static_cast<const DataType*>(p_Cs[i]); const auto p_c_grid = static_cast<const InputDataType*>(p_Cs[i]);
const auto p_lse_grid = static_cast<const LSEDataType*>(p_LSEs[i]); const auto p_lse_grid = static_cast<const LSEDataType*>(p_LSEs[i]);
const auto p_ygrad_grid = static_cast<const DataType*>(p_Ygrads[i]); const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<DataType*>(p_Qgrads[i]); auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<DataType*>(p_Kgrads[i]); auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]);
auto p_vgrad_grid = static_cast<DataType*>(p_Vgrads[i]); auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
const auto& problem_desc = problem_desc_vec[i]; const auto& problem_desc = problem_desc_vec[i];
......
...@@ -20,7 +20,8 @@ ...@@ -20,7 +20,8 @@
namespace ck { namespace ck {
template <typename DataType, template <typename InputDataType,
typename OutputDataType,
typename GemmDataType, typename GemmDataType,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
...@@ -381,7 +382,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -381,7 +382,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence<AK0, MPerBlock, AK1>, Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
DataType, InputDataType,
GemmDataType, GemmDataType,
GridDesc_K0_M_K1, GridDesc_K0_M_K1,
decltype(q_block_desc_k0_m_k1), decltype(q_block_desc_k0_m_k1),
...@@ -406,7 +407,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -406,7 +407,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence<BK0, NPerBlock, BK1>, Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
DataType, InputDataType,
GemmDataType, GemmDataType,
GridDesc_K0_N_K1, GridDesc_K0_N_K1,
decltype(k_block_desc_k0_n_k1), decltype(k_block_desc_k0_n_k1),
...@@ -431,7 +432,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -431,7 +432,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence<BK0, NPerBlock, BK1>, Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
DataType, InputDataType,
GemmDataType, GemmDataType,
GridDesc_K0_N_K1, GridDesc_K0_N_K1,
decltype(v_block_desc_k0_n_k1), decltype(v_block_desc_k0_n_k1),
...@@ -456,7 +457,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -456,7 +457,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence<AK0, MPerBlock, AK1>, Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
DataType, InputDataType,
GemmDataType, GemmDataType,
GridDesc_K0_M_K1, GridDesc_K0_M_K1,
decltype(ygrad_block_desc_k0_m_k1), decltype(ygrad_block_desc_k0_m_k1),
...@@ -1043,7 +1044,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1043,7 +1044,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename ElementwiseOp = tensor_operation::element_wise::PassThrough> typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3< using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc, FloatGemmAcc,
DataType, OutputDataType,
decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4), decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4),
CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4, CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4,
ElementwiseOp, // CElementwiseOperation ElementwiseOp, // CElementwiseOperation
...@@ -1059,7 +1060,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1059,7 +1060,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_> template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_ struct YDotYGrad_M_O_
{ {
static constexpr index_t SrcScalarPerVector = 16 / sizeof(DataType); static constexpr index_t SrcScalarPerVector = 16 / sizeof(InputDataType);
static constexpr auto ThreadClusterLength_O = static constexpr auto ThreadClusterLength_O =
Number<BlockSliceLength_O_ / SrcScalarPerVector>{}; Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{}; static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
...@@ -1234,16 +1235,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1234,16 +1235,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename C0MatrixMask, typename C0MatrixMask,
typename VGradGridDescriptor_N_O, typename VGradGridDescriptor_N_O,
typename YGradGridDesc_O0_M_O1> typename YGradGridDesc_O0_M_O1>
__device__ static void Run(const DataType* __restrict__ p_q_grid, __device__ static void Run(const InputDataType* __restrict__ p_q_grid,
const DataType* __restrict__ p_k_grid, const InputDataType* __restrict__ p_k_grid,
unsigned short* __restrict__ p_z_grid, unsigned short* __restrict__ p_z_grid,
const DataType* __restrict__ p_v_grid, const InputDataType* __restrict__ p_v_grid,
const DataType* __restrict__ p_y_grid, const InputDataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid, const FloatLSE* __restrict__ p_lse_grid,
const DataType* __restrict__ p_ygrad_grid, const InputDataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid, OutputDataType* __restrict__ p_qgrad_grid,
DataType* __restrict__ p_kgrad_grid, OutputDataType* __restrict__ p_kgrad_grid,
DataType* __restrict__ p_vgrad_grid, OutputDataType* __restrict__ p_vgrad_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
...@@ -1723,7 +1724,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1723,7 +1724,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// performs for y // performs for y
auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2< auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
DataType, InputDataType,
FloatGemmAcc, FloatGemmAcc,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1), decltype(y_thread_desc_m0_m1_o0_o1),
...@@ -2307,7 +2308,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -2307,7 +2308,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData, FloatCShuffle, // typename SrcData,
DataType, // typename DstData, OutputDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(qgrad_grid_desc_mblock_mperblock_kblock_kperblock), decltype(qgrad_grid_desc_mblock_mperblock_kblock_kperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder, Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
......
...@@ -20,7 +20,8 @@ ...@@ -20,7 +20,8 @@
namespace ck { namespace ck {
template <typename DataType, template <typename InputDataType,
typename OutputDataType,
typename GemmDataType, typename GemmDataType,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
...@@ -457,7 +458,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -457,7 +458,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
Sequence<AK0, MPerBlock, AK1>, Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
DataType, InputDataType,
GemmDataType, GemmDataType,
GridDesc_K0_M_K1, GridDesc_K0_M_K1,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
...@@ -482,7 +483,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -482,7 +483,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
Sequence<BK0, NPerBlock, BK1>, Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
DataType, InputDataType,
GemmDataType, GemmDataType,
GridDesc_K0_N_K1, GridDesc_K0_N_K1,
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
...@@ -585,7 +586,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -585,7 +586,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
Sequence<B1K0, Gemm1NPerBlock, B1K1>, Sequence<B1K0, Gemm1NPerBlock, B1K1>,
B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferThreadClusterArrangeOrder,
DataType, InputDataType,
GemmDataType, GemmDataType,
GridDesc_K0_N_K1, GridDesc_K0_N_K1,
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
...@@ -823,7 +824,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -823,7 +824,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename Gemm2Params_N_O_M::BBlockSliceLengths, typename Gemm2Params_N_O_M::BBlockSliceLengths,
typename Gemm2Params_N_O_M::BThreadClusterLengths, typename Gemm2Params_N_O_M::BThreadClusterLengths,
typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder, typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder,
DataType, InputDataType,
GemmDataType, GemmDataType,
GridDesc_M0_O_M1, GridDesc_M0_O_M1,
decltype(b_block_desc_m0_o_m1), decltype(b_block_desc_m0_o_m1),
...@@ -892,7 +893,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -892,7 +893,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename ElementwiseOp = tensor_operation::element_wise::PassThrough> typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3< using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc, FloatGemmAcc,
DataType, OutputDataType,
decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4), decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4),
CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4, CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4,
ElementwiseOp, // CElementwiseOperation ElementwiseOp, // CElementwiseOperation
...@@ -908,7 +909,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -908,7 +909,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_> template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_ struct YDotYGrad_M_O_
{ {
static constexpr index_t SrcScalarPerVector = 16 / sizeof(DataType); static constexpr index_t SrcScalarPerVector = 16 / sizeof(InputDataType);
static constexpr auto ThreadClusterLength_O = static constexpr auto ThreadClusterLength_O =
Number<BlockSliceLength_O_ / SrcScalarPerVector>{}; Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{}; static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
...@@ -1144,16 +1145,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1144,16 +1145,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename C0MatrixMask, typename C0MatrixMask,
typename VGradGridDescriptor_N_O, typename VGradGridDescriptor_N_O,
typename YGradGridDesc_M0_O_M1> typename YGradGridDesc_M0_O_M1>
__device__ static void Run(const DataType* __restrict__ p_q_grid, __device__ static void Run(const InputDataType* __restrict__ p_q_grid,
const DataType* __restrict__ p_k_grid, const InputDataType* __restrict__ p_k_grid,
unsigned short* __restrict__ p_z_grid, unsigned short* __restrict__ p_z_grid,
const DataType* __restrict__ p_v_grid, const InputDataType* __restrict__ p_v_grid,
const DataType* __restrict__ p_y_grid, const InputDataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid, const FloatLSE* __restrict__ p_lse_grid,
const DataType* __restrict__ p_ygrad_grid, const InputDataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid, OutputDataType* __restrict__ p_qgrad_grid,
DataType* __restrict__ p_kgrad_grid, OutputDataType* __restrict__ p_kgrad_grid,
DataType* __restrict__ p_vgrad_grid, OutputDataType* __restrict__ p_vgrad_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
...@@ -1646,7 +1647,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1646,7 +1647,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// performs double duty for both y and ygrad // performs double duty for both y and ygrad
auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2< auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
DataType, InputDataType,
FloatGemmAcc, FloatGemmAcc,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1), decltype(y_thread_desc_m0_m1_o0_o1),
...@@ -2257,7 +2258,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -2257,7 +2258,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData, FloatCShuffle, // typename SrcData,
DataType, // typename DstData, OutputDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(qgrad_grid_desc_mblock_mperblock_kblock_kperblock), decltype(qgrad_grid_desc_mblock_mperblock_kblock_kperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder, Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment