Commit f3e61c0a authored by danyao12's avatar danyao12
Browse files

datatype of bwd output can be selected

parent f7e05f9e
......@@ -28,7 +28,8 @@ namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename DataType,
typename InputDataType,
typename OutputDataType,
typename ZDataType,
typename LSEDataType,
typename AElementwiseOperation,
......@@ -53,16 +54,16 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1(
const DataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid,
const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid,
const DataType* __restrict__ p_b1_grid,
const DataType* __restrict__ p_c_grid,
const InputDataType* __restrict__ p_b1_grid,
const InputDataType* __restrict__ p_c_grid,
const LSEDataType* __restrict__ p_lse_grid,
const DataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid,
DataType* __restrict__ p_kgrad_grid,
DataType* __restrict__ p_vgrad_grid,
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
......@@ -171,7 +172,8 @@ template <index_t NumDimG,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename DataType,
typename InputDataType,
typename OutputDataType,
typename GemmDataType,
typename ZDataType,
typename LSEDataType,
......@@ -597,7 +599,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
DataType, // TODO: distinguish A/B datatype
InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
GemmDataType,
GemmAccDataType,
CShuffleDataType,
......@@ -666,16 +669,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct Argument : public BaseArgument
{
Argument(
const DataType* p_a_grid,
const DataType* p_b_grid,
const InputDataType* p_a_grid,
const InputDataType* p_b_grid,
ZDataType* p_z_grid,
const DataType* p_b1_grid,
const DataType* p_c_grid, // for dS
const InputDataType* p_b1_grid,
const InputDataType* p_c_grid, // for dS
const LSEDataType* p_lse_grid,
const DataType* p_ygrad_grid,
DataType* p_qgrad_grid,
DataType* p_kgrad_grid,
DataType* p_vgrad_grid,
const InputDataType* p_ygrad_grid,
OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
......@@ -820,16 +823,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
// pointers
const DataType* p_a_grid_;
const DataType* p_b_grid_;
const InputDataType* p_a_grid_;
const InputDataType* p_b_grid_;
ZDataType* p_z_grid_;
const DataType* p_b1_grid_;
const DataType* p_c_grid_;
const InputDataType* p_b1_grid_;
const InputDataType* p_c_grid_;
const LSEDataType* p_lse_grid_;
const DataType* p_ygrad_grid_;
DataType* p_qgrad_grid_;
DataType* p_kgrad_grid_;
DataType* p_vgrad_grid_;
const InputDataType* p_ygrad_grid_;
OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_;
OutputDataType* p_vgrad_grid_;
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
......@@ -901,7 +904,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
GridwiseGemm,
DataType,
InputDataType,
OutputDataType,
ZDataType,
LSEDataType,
AElementwiseOperation,
......@@ -1067,16 +1071,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
static auto MakeArgument(
const DataType* p_a,
const DataType* p_b,
const InputDataType* p_a,
const InputDataType* p_b,
ZDataType* p_z,
const DataType* p_b1,
const DataType* p_c,
const InputDataType* p_b1,
const InputDataType* p_c,
const LSEDataType* p_lse,
const DataType* p_ygrad_grid,
DataType* p_qgrad_grid,
DataType* p_kgrad_grid,
DataType* p_vgrad_grid,
const InputDataType* p_ygrad_grid,
OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
......@@ -1182,16 +1186,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override
{
return std::make_unique<Argument>(static_cast<const DataType*>(p_a),
static_cast<const DataType*>(p_b),
return std::make_unique<Argument>(static_cast<const InputDataType*>(p_a),
static_cast<const InputDataType*>(p_b),
static_cast<ZDataType*>(p_z),
static_cast<const DataType*>(p_b1),
static_cast<const DataType*>(p_c),
static_cast<const InputDataType*>(p_b1),
static_cast<const InputDataType*>(p_c),
static_cast<const LSEDataType*>(p_lse),
static_cast<const DataType*>(p_ygrad_grid),
static_cast<DataType*>(p_qgrad_grid),
static_cast<DataType*>(p_kgrad_grid),
static_cast<DataType*>(p_vgrad_grid),
static_cast<const InputDataType*>(p_ygrad_grid),
static_cast<OutputDataType*>(p_qgrad_grid),
static_cast<OutputDataType*>(p_kgrad_grid),
static_cast<OutputDataType*>(p_vgrad_grid),
p_acc0_biases, // cast in struct Argument
p_acc1_biases, // cast in struct Argument
a_gs_ms_ks_lengths,
......
......@@ -27,7 +27,8 @@ namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename DataType,
typename InputDataType,
typename OutputDataType,
typename ZDataType,
typename LSEDataType,
typename AElementwiseOperation,
......@@ -52,16 +53,16 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v2(
const DataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid,
const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid,
const DataType* __restrict__ p_b1_grid,
const DataType* __restrict__ p_c_grid,
const InputDataType* __restrict__ p_b1_grid,
const InputDataType* __restrict__ p_c_grid,
const LSEDataType* __restrict__ p_lse_grid,
const DataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid,
DataType* __restrict__ p_kgrad_grid,
DataType* __restrict__ p_vgrad_grid,
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
......@@ -170,7 +171,8 @@ template <index_t NumDimG,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename DataType,
typename InputDataType,
typename OutputDataType,
typename GemmDataType,
typename ZDataType,
typename LSEDataType,
......@@ -596,7 +598,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
DataType, // TODO: distinguish A/B datatype
InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
GemmDataType,
GemmAccDataType,
CShuffleDataType,
......@@ -665,16 +668,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
struct Argument : public BaseArgument
{
Argument(
const DataType* p_a_grid,
const DataType* p_b_grid,
const InputDataType* p_a_grid,
const InputDataType* p_b_grid,
ZDataType* p_z_grid,
const DataType* p_b1_grid,
const DataType* p_c_grid, // for dS
const InputDataType* p_b1_grid,
const InputDataType* p_c_grid, // for dS
const LSEDataType* p_lse_grid,
const DataType* p_ygrad_grid,
DataType* p_qgrad_grid,
DataType* p_kgrad_grid,
DataType* p_vgrad_grid,
const InputDataType* p_ygrad_grid,
OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
......@@ -818,16 +821,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
// pointers
const DataType* p_a_grid_;
const DataType* p_b_grid_;
const InputDataType* p_a_grid_;
const InputDataType* p_b_grid_;
ZDataType* p_z_grid_;
const DataType* p_b1_grid_;
const DataType* p_c_grid_;
const InputDataType* p_b1_grid_;
const InputDataType* p_c_grid_;
const LSEDataType* p_lse_grid_;
const DataType* p_ygrad_grid_;
DataType* p_qgrad_grid_;
DataType* p_kgrad_grid_;
DataType* p_vgrad_grid_;
const InputDataType* p_ygrad_grid_;
OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_;
OutputDataType* p_vgrad_grid_;
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
......@@ -903,7 +906,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2<
GridwiseGemm,
DataType,
InputDataType,
OutputDataType,
ZDataType,
LSEDataType,
AElementwiseOperation,
......@@ -1067,16 +1071,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
static auto MakeArgument(
const DataType* p_a,
const DataType* p_b,
const InputDataType* p_a,
const InputDataType* p_b,
ZDataType* p_z,
const DataType* p_b1,
const DataType* p_c,
const InputDataType* p_b1,
const InputDataType* p_c,
const LSEDataType* p_lse,
const DataType* p_ygrad_grid,
DataType* p_qgrad_grid,
DataType* p_kgrad_grid,
DataType* p_vgrad_grid,
const InputDataType* p_ygrad_grid,
OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
......@@ -1182,16 +1186,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override
{
return std::make_unique<Argument>(static_cast<const DataType*>(p_a),
static_cast<const DataType*>(p_b),
return std::make_unique<Argument>(static_cast<const InputDataType*>(p_a),
static_cast<const InputDataType*>(p_b),
static_cast<ZDataType*>(p_z),
static_cast<const DataType*>(p_b1),
static_cast<const DataType*>(p_c),
static_cast<const InputDataType*>(p_b1),
static_cast<const InputDataType*>(p_c),
static_cast<const LSEDataType*>(p_lse),
static_cast<const DataType*>(p_ygrad_grid),
static_cast<DataType*>(p_qgrad_grid),
static_cast<DataType*>(p_kgrad_grid),
static_cast<DataType*>(p_vgrad_grid),
static_cast<const InputDataType*>(p_ygrad_grid),
static_cast<OutputDataType*>(p_qgrad_grid),
static_cast<OutputDataType*>(p_kgrad_grid),
static_cast<OutputDataType*>(p_vgrad_grid),
p_acc0_biases, // cast in struct Argument
p_acc1_biases, // cast in struct Argument
a_gs_ms_ks_lengths,
......
......@@ -150,7 +150,8 @@ template <index_t NumDimG,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename DataType,
typename InputDataType,
typename OutputDataType,
typename GemmDataType,
typename ZDataType,
typename LSEDataType,
......@@ -534,7 +535,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
DataType, // TODO: distinguish A/B datatype
InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
GemmDataType,
GemmAccDataType,
CShuffleDataType,
......@@ -604,16 +606,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct GroupKernelArg
{
// pointers
const DataType* p_a_grid_;
const DataType* p_b_grid_;
const InputDataType* p_a_grid_;
const InputDataType* p_b_grid_;
ZDataType* p_z_grid_;
const DataType* p_b1_grid_;
const DataType* p_c_grid_;
const InputDataType* p_b1_grid_;
const InputDataType* p_c_grid_;
const LSEDataType* p_lse_grid_;
const DataType* p_ygrad_grid_;
DataType* p_qgrad_grid_;
DataType* p_kgrad_grid_;
DataType* p_vgrad_grid_;
const InputDataType* p_ygrad_grid_;
OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_;
OutputDataType* p_vgrad_grid_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
......@@ -712,16 +714,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
grid_size_ = 0;
for(index_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const DataType*>(p_As[i]);
const auto p_b_grid = static_cast<const DataType*>(p_Bs[i]);
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]);
auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]);
const auto p_b1_grid = static_cast<const DataType*>(p_B1s[i]);
const auto p_c_grid = static_cast<const DataType*>(p_Cs[i]);
const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]);
const auto p_c_grid = static_cast<const InputDataType*>(p_Cs[i]);
const auto p_lse_grid = static_cast<const LSEDataType*>(p_LSEs[i]);
const auto p_ygrad_grid = static_cast<const DataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<DataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<DataType*>(p_Kgrads[i]);
auto p_vgrad_grid = static_cast<DataType*>(p_Vgrads[i]);
const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]);
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
const auto& problem_desc = problem_desc_vec[i];
......
......@@ -150,7 +150,8 @@ template <index_t NumDimG,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename DataType,
typename InputDataType,
typename OutputDataType,
typename GemmDataType,
typename ZDataType,
typename LSEDataType,
......@@ -527,7 +528,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
DataType, // TODO: distinguish A/B datatype
InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
GemmDataType,
GemmAccDataType,
CShuffleDataType,
......@@ -597,16 +599,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
struct GroupKernelArg
{
// pointers
const DataType* p_a_grid_;
const DataType* p_b_grid_;
const InputDataType* p_a_grid_;
const InputDataType* p_b_grid_;
ZDataType* p_z_grid_;
const DataType* p_b1_grid_;
const DataType* p_c_grid_;
const InputDataType* p_b1_grid_;
const InputDataType* p_c_grid_;
const LSEDataType* p_lse_grid_;
const DataType* p_ygrad_grid_;
DataType* p_qgrad_grid_;
DataType* p_kgrad_grid_;
DataType* p_vgrad_grid_;
const InputDataType* p_ygrad_grid_;
OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_;
OutputDataType* p_vgrad_grid_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
......@@ -705,16 +707,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
grid_size_ = 0;
for(index_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const DataType*>(p_As[i]);
const auto p_b_grid = static_cast<const DataType*>(p_Bs[i]);
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]);
auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]);
const auto p_b1_grid = static_cast<const DataType*>(p_B1s[i]);
const auto p_c_grid = static_cast<const DataType*>(p_Cs[i]);
const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]);
const auto p_c_grid = static_cast<const InputDataType*>(p_Cs[i]);
const auto p_lse_grid = static_cast<const LSEDataType*>(p_LSEs[i]);
const auto p_ygrad_grid = static_cast<const DataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<DataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<DataType*>(p_Kgrads[i]);
auto p_vgrad_grid = static_cast<DataType*>(p_Vgrads[i]);
const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]);
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
const auto& problem_desc = problem_desc_vec[i];
......
......@@ -20,7 +20,8 @@
namespace ck {
template <typename DataType,
template <typename InputDataType,
typename OutputDataType,
typename GemmDataType,
typename FloatGemmAcc,
typename FloatCShuffle,
......@@ -381,7 +382,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
DataType,
InputDataType,
GemmDataType,
GridDesc_K0_M_K1,
decltype(q_block_desc_k0_m_k1),
......@@ -406,7 +407,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
DataType,
InputDataType,
GemmDataType,
GridDesc_K0_N_K1,
decltype(k_block_desc_k0_n_k1),
......@@ -431,7 +432,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
DataType,
InputDataType,
GemmDataType,
GridDesc_K0_N_K1,
decltype(v_block_desc_k0_n_k1),
......@@ -456,7 +457,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
DataType,
InputDataType,
GemmDataType,
GridDesc_K0_M_K1,
decltype(ygrad_block_desc_k0_m_k1),
......@@ -1043,7 +1044,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
DataType,
OutputDataType,
decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4),
CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4,
ElementwiseOp, // CElementwiseOperation
......@@ -1059,7 +1060,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
static constexpr index_t SrcScalarPerVector = 16 / sizeof(DataType);
static constexpr index_t SrcScalarPerVector = 16 / sizeof(InputDataType);
static constexpr auto ThreadClusterLength_O =
Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
......@@ -1234,16 +1235,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename C0MatrixMask,
typename VGradGridDescriptor_N_O,
typename YGradGridDesc_O0_M_O1>
__device__ static void Run(const DataType* __restrict__ p_q_grid,
const DataType* __restrict__ p_k_grid,
__device__ static void Run(const InputDataType* __restrict__ p_q_grid,
const InputDataType* __restrict__ p_k_grid,
unsigned short* __restrict__ p_z_grid,
const DataType* __restrict__ p_v_grid,
const DataType* __restrict__ p_y_grid,
const InputDataType* __restrict__ p_v_grid,
const InputDataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid,
const DataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid,
DataType* __restrict__ p_kgrad_grid,
DataType* __restrict__ p_vgrad_grid,
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
......@@ -1723,7 +1724,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// performs for y
auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
DataType,
InputDataType,
FloatGemmAcc,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1),
......@@ -2307,7 +2308,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
DataType, // typename DstData,
OutputDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(qgrad_grid_desc_mblock_mperblock_kblock_kperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
......
......@@ -20,7 +20,8 @@
namespace ck {
template <typename DataType,
template <typename InputDataType,
typename OutputDataType,
typename GemmDataType,
typename FloatGemmAcc,
typename FloatCShuffle,
......@@ -457,7 +458,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
DataType,
InputDataType,
GemmDataType,
GridDesc_K0_M_K1,
decltype(a_block_desc_ak0_m_ak1),
......@@ -482,7 +483,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
DataType,
InputDataType,
GemmDataType,
GridDesc_K0_N_K1,
decltype(b_block_desc_bk0_n_bk1),
......@@ -585,7 +586,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
Sequence<B1K0, Gemm1NPerBlock, B1K1>,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
DataType,
InputDataType,
GemmDataType,
GridDesc_K0_N_K1,
decltype(b_block_desc_bk0_n_bk1),
......@@ -823,7 +824,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename Gemm2Params_N_O_M::BBlockSliceLengths,
typename Gemm2Params_N_O_M::BThreadClusterLengths,
typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder,
DataType,
InputDataType,
GemmDataType,
GridDesc_M0_O_M1,
decltype(b_block_desc_m0_o_m1),
......@@ -892,7 +893,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
DataType,
OutputDataType,
decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4),
CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4,
ElementwiseOp, // CElementwiseOperation
......@@ -908,7 +909,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
static constexpr index_t SrcScalarPerVector = 16 / sizeof(DataType);
static constexpr index_t SrcScalarPerVector = 16 / sizeof(InputDataType);
static constexpr auto ThreadClusterLength_O =
Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
......@@ -1144,16 +1145,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename C0MatrixMask,
typename VGradGridDescriptor_N_O,
typename YGradGridDesc_M0_O_M1>
__device__ static void Run(const DataType* __restrict__ p_q_grid,
const DataType* __restrict__ p_k_grid,
__device__ static void Run(const InputDataType* __restrict__ p_q_grid,
const InputDataType* __restrict__ p_k_grid,
unsigned short* __restrict__ p_z_grid,
const DataType* __restrict__ p_v_grid,
const DataType* __restrict__ p_y_grid,
const InputDataType* __restrict__ p_v_grid,
const InputDataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid,
const DataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid,
DataType* __restrict__ p_kgrad_grid,
DataType* __restrict__ p_vgrad_grid,
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
......@@ -1646,7 +1647,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// performs double duty for both y and ygrad
auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
DataType,
InputDataType,
FloatGemmAcc,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1),
......@@ -2257,7 +2258,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
DataType, // typename DstData,
OutputDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(qgrad_grid_desc_mblock_mperblock_kblock_kperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment