Unverified commit 4033f5df, authored by Dan Yao and committed by GitHub

Merge pull request #949 from ROCmSoftwarePlatform/mha-train-mqagqa

Grouped Query Attention/Multi Query Attention
parents 18be6bc9 957d5dee
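
This change adds Multi Query Attention (MQA) and Grouped Query Attention (GQA) support to the grouped attention training kernels: several query heads may now share a single key/value head, and a new h_ratio parameter (the number of query heads per key/value head) is threaded from the device ops into the kernels, which map a query-head batch index g_idx to a key/value batch index gkv_idx = g_idx / h_ratio. A minimal sketch of that mapping with illustrative values (this is not CK code; names other than h_ratio, g_idx, and gkv_idx are made up):

    #include <cassert>
    #include <cstdio>

    int main()
    {
        const int num_q_heads  = 8; // illustrative; MQA is the special case num_kv_heads == 1
        const int num_kv_heads = 2;
        assert(num_q_heads % num_kv_heads == 0); // divisibility is assumed throughout
        const int h_ratio = num_q_heads / num_kv_heads;

        for(int g_idx = 0; g_idx < num_q_heads; ++g_idx)
        {
            const int gkv_idx = g_idx / h_ratio; // the mapping used in the kernels below
            std::printf("query head %d -> kv head %d\n", g_idx, gkv_idx);
        }
        return 0;
    }
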
......@@ -44,6 +44,7 @@ __global__ void
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v2(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const index_t h_ratio,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
......@@ -82,19 +83,26 @@ __global__ void
const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
const index_t g_idx = __builtin_amdgcn_readfirstlane(
(block_id - arg_ptr[group_id].block_start_) / (Deterministic ? 1 : num_blocks_per_batch));
const index_t gkv_idx = __builtin_amdgcn_readfirstlane(g_idx / h_ratio);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(gkv_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetZBasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(gkv_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
const long_index_t bgrad_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetBGradBasePtr(g_idx)));
const long_index_t b1grad_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1GradBasePtr(g_idx)));
const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset);
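
In the hunk above, the B (K) and B1 (V) base pointers are now computed from gkv_idx, so every query head in a group of h_ratio reads the same K/V batch slice, while the new bgrad/b1grad base pointers come from dedicated gradient descriptors indexed by g_idx. Whether dK/dV are expanded per query head or kept per KV head is determined by the caller-supplied bgrad/b1grad lengths; if they are expanded, folding them back to the per-KV-head shape is a sum over each group. A hedged host-side illustration of such a reduction, under exactly that assumption (not CK code):

    // Assumption: dK holds one [n, k] slice per *query* head, contiguous in memory.
    void reduce_dk_over_group(const float* dk_per_q,  // num_q_heads slices
                              float* dk_per_kv,       // num_q_heads / h_ratio slices
                              int num_q_heads, int h_ratio, long slice_elems)
    {
        const int num_kv_heads = num_q_heads / h_ratio;
        for(int kv = 0; kv < num_kv_heads; ++kv)
            for(long e = 0; e < slice_elems; ++e)
            {
                float acc = 0.f;
                for(int r = 0; r < h_ratio; ++r)
                    acc += dk_per_q[(long(kv) * h_ratio + r) * slice_elems + e];
                dk_per_kv[long(kv) * slice_elems + e] = acc;
            }
    }
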
......@@ -128,9 +136,9 @@ __global__ void
arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + bgrad_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
arg_ptr[group_id].p_vgrad_grid_ + b1grad_batch_offset,
p_shared,
a_element_op,
b_element_op,
......@@ -139,9 +147,11 @@ __global__ void
c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].bgrad_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].b1grad_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_,
......@@ -167,9 +177,9 @@ __global__ void
arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + bgrad_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
arg_ptr[group_id].p_vgrad_grid_ + b1grad_batch_offset,
p_shared,
a_element_op,
b_element_op,
......@@ -178,9 +188,11 @@ __global__ void
c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].bgrad_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].b1grad_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_,
......@@ -196,6 +208,7 @@ __global__ void
#else
ignore = group_kernel_args;
ignore = group_count;
ignore = h_ratio;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
......@@ -313,6 +326,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides;
std::vector<index_t> bgrad_gs_ns_ks_lengths;
std::vector<index_t> bgrad_gs_ns_ks_strides;
std::vector<index_t> b1grad_gs_gemm1ns_gemm1ks_lengths;
std::vector<index_t> b1grad_gs_gemm1ns_gemm1ks_strides;
std::vector<index_t> acc0_bias_gs_ms_ns_lengths;
std::vector<index_t> acc0_bias_gs_ms_ns_strides;
......@@ -557,7 +576,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
......@@ -566,7 +584,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
......@@ -613,6 +630,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n,
const BGridDesc_G_N_K& bgrad_grid_desc_g_n_k,
const B1GridDesc_G_N_K& b1grad_grid_desc_g_n_k,
index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
......@@ -620,6 +639,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
bgrad_grid_desc_g_n_k_(bgrad_grid_desc_g_n_k),
b1grad_grid_desc_g_n_k_(b1grad_grid_desc_g_n_k),
BatchStrideLSE_(BatchStrideLSE)
{
}
......@@ -659,6 +680,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return g_idx * static_cast<long_index_t>(BatchStrideLSE_);
}
__host__ __device__ constexpr long_index_t GetBGradBasePtr(index_t g_idx) const
{
return bgrad_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetB1GradBasePtr(index_t g_idx) const
{
return b1grad_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
private:
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
......@@ -666,6 +697,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
BGridDesc_G_N_K bgrad_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1grad_grid_desc_g_n_k_;
index_t BatchStrideLSE_;
};
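
The two new accessors mirror the existing GetBBasePtr/GetB1BasePtr helpers: each converts a batch index into a flat element offset through the G dimension of its gradient descriptor. For a plain strided layout, CalculateOffset(make_multi_index(g_idx, 0, 0)) reduces to g_idx times the G stride, as the stand-in below shows (illustrative only; SimpleDesc3D is not a CK type):

    // Stand-in for the descriptor machinery: an offset is the dot product of
    // the index with the strides, so a (g_idx, 0, 0) index picks out the G stride.
    struct SimpleDesc3D
    {
        long stride_g, stride_n, stride_k;
        long CalculateOffset(int g, int n, int k) const
        {
            return long(g) * stride_g + long(n) * stride_n + long(k) * stride_k;
        }
    };

    long get_bgrad_base_ptr(const SimpleDesc3D& d, int g_idx)
    {
        return d.CalculateOffset(g_idx, 0, 0); // == g_idx * stride_g
    }
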
......@@ -765,9 +798,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
BGridDesc_BK0_N_BK1 bgrad_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1grad_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_;
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
......@@ -802,6 +837,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<index_t> c_mz_gemm1nz_strides_;
// for gridwise gemm check
BGridDesc_G_N_K b_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
index_t batch_count_;
......@@ -870,6 +906,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
index_t z_random_matrix_offset = 0;
h_ratio_ = problem_desc_vec[0].a_gs_ms_ks_lengths[NumDimG - 1] /
problem_desc_vec[0].b_gs_ns_ks_lengths[NumDimG - 1];
for(index_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
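
h_ratio_ is derived from the last G-dimension length of the Q (A) lengths divided by that of the K (B) lengths, i.e. the per-group head counts. Only group 0 is consulted, so all groups are presumably expected to share one ratio; shape consistency is then checked in IsSupportedArgument. A worked example with illustrative shapes (NumDimG == 2, lengths laid out as {batch, heads, seq, head_dim}):

    #include <vector>

    int main()
    {
        constexpr int NumDimG = 2;
        std::vector<int> a_gs_ms_ks_lengths{2, 8, 1024, 64}; // {batch, Hq,  M, K}
        std::vector<int> b_gs_ns_ks_lengths{2, 2, 1024, 64}; // {batch, Hkv, N, K}
        const int h_ratio =
            a_gs_ms_ks_lengths[NumDimG - 1] / b_gs_ns_ks_lengths[NumDimG - 1];
        return h_ratio == 4 ? 0 : 1; // 8 query heads / 2 kv heads
    }
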
......@@ -897,6 +936,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
const auto bgrad_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
problem_desc.bgrad_gs_ns_ks_lengths, problem_desc.bgrad_gs_ns_ks_strides);
std::vector<index_t> tmp_d0_gs_ms_ns_lengths;
std::vector<index_t> tmp_d0_gs_ms_ns_strides;
......@@ -919,6 +960,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1(
problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
problem_desc.b1_gs_gemm1ns_gemm1ks_strides);
const auto b1grad_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1(
problem_desc.b1grad_gs_gemm1ns_gemm1ks_lengths,
problem_desc.b1grad_gs_gemm1ns_gemm1ks_strides);
const auto y_grid_desc_m_o = Transform::MakeCGridDescriptor_M_N(
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
......@@ -942,6 +986,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
problem_desc.b1_gs_gemm1ns_gemm1ks_strides);
const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
const auto bgrad_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K(
problem_desc.bgrad_gs_ns_ks_lengths, problem_desc.bgrad_gs_ns_ks_strides);
const auto b1grad_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1grad_gs_gemm1ns_gemm1ks_lengths,
problem_desc.b1grad_gs_gemm1ns_gemm1ks_strides);
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
......@@ -975,6 +1024,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
z_grid_desc_g_m_n,
b1_grid_desc_g_n_k,
c_grid_desc_g_m_n,
bgrad_grid_desc_g_n_k,
b1grad_grid_desc_g_n_k,
type_convert<index_t>(problem_desc.lse_gs_ms_strides[NumDimG - 1]));
// C0 mask
......@@ -1002,9 +1053,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_vgrad_grid,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
bgrad_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3,
z_grid_desc_m_n,
b1_grid_desc_bk0_n_bk1,
b1grad_grid_desc_bk0_n_bk1,
y_grid_desc_m_o,
y_grid_desc_mblock_mperblock_oblock_operblock,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
......@@ -1042,6 +1095,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
problem_desc.b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
{problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
b_grid_desc_g_n_k,
c_grid_desc_g_m_n,
batch_count,
d0_n_length_stride});
......@@ -1068,6 +1122,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
index_t grid_size_;
index_t group_count_;
index_t h_ratio_;
std::vector<GroupKernelArg> group_kernel_args_;
std::vector<GroupDeviceArg> group_device_args_;
......@@ -1126,6 +1181,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.group_count_,
arg.h_ratio_,
arg.a_element_op_,
arg.b_element_op_,
arg.acc_element_op_,
......@@ -1194,13 +1250,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const auto& device_arg = arg.group_device_args_[i];
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t b_g = device_arg.b_grid_desc_g_n_k_.GetLength(I0);
const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
const index_t c_gemm1n = kernel_arg.y_grid_desc_m_o_.GetLength(I1);
const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) *
kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n &&
c_g % b_g == 0 && c_g / b_g == arg.h_ratio_))
{
return false;
}
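
The strengthened predicate requires the unpadded G length of the output (c_g, batch times query heads) to be an integer multiple of K's G length (b_g, batch times KV heads), with the quotient equal to the single h_ratio_ computed from group 0. Extracted as a standalone check, using the example shapes from above (illustrative):

    bool gqa_shape_ok(int c_g, int b_g, int h_ratio)
    {
        return c_g % b_g == 0 && c_g / b_g == h_ratio;
    }
    // e.g. c_g = 2 * 8 = 16, b_g = 2 * 2 = 4, h_ratio = 4:
    // 16 % 4 == 0 and 16 / 4 == 4, so the group is accepted.
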
......
......@@ -43,6 +43,7 @@ __global__ void
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const index_t h_ratio,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
......@@ -87,13 +88,14 @@ __global__ void
const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
const index_t g_idx = __builtin_amdgcn_readfirstlane(
(block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch);
const index_t gkv_idx = __builtin_amdgcn_readfirstlane(g_idx / h_ratio);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(gkv_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(gkv_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
......@@ -147,6 +149,7 @@ __global__ void
#else
ignore = group_kernel_args;
ignore = group_count;
ignore = h_ratio;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
......@@ -367,7 +370,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
{
return Transform::MakeC0GridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
}
......@@ -376,7 +378,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
{
return Transform::MakeC0GridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
}
......@@ -606,6 +607,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
// for gridwise gemm check
CGridDesc_M_N c_grid_desc_m_n_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
// raw data
std::vector<ck::index_t> d0_n_length_stride_;
......@@ -654,6 +657,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
index_t z_random_matrix_offset = 0;
h_ratio_ = problem_desc_vec[0].a_gs_ms_ks_lengths[NumDimG - 1] /
problem_desc_vec[0].b0_gs_ns_ks_lengths[NumDimG - 1];
for(std::size_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
......@@ -805,6 +811,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
{problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
c_grid_desc_m_n,
b_grid_desc_g_n_k,
c_grid_desc_g_m_n,
d0_n_length_stride});
}
......@@ -830,6 +838,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_;
index_t h_ratio_;
float p_dropout_;
uint8_t p_dropout_in_uint8_t_;
unsigned long long seed_;
......@@ -918,6 +927,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.group_count_,
arg.h_ratio_,
arg.a_element_op_,
arg.b_element_op_,
arg.acc_element_op_,
......@@ -1040,11 +1050,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
const auto& device_arg = arg.group_device_args_[i];
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t b_g = device_arg.b_grid_desc_g_n_k_.GetLength(I0);
const index_t c_m = device_arg.c_grid_desc_m_n_.GetLength(I0);
const index_t c_gemm1n = device_arg.c_grid_desc_m_n_.GetLength(I1);
const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
if(!(c_m == a_m && c_gemm1n == b1_gemm1n))
if(!(c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 &&
c_g / b_g == arg.h_ratio_))
{
return false;
}
......
......@@ -1443,10 +1443,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const CElementwiseOperation& c_element_op,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const KGridDesc_K0_N_K1& kgrad_grid_desc_k0_n_k1,
const D0GridDescriptor_M0_N0_M1_M2_N1_M3& d0_grid_desc_m0_n0_m1_m2_n1_m3,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
const VGridDesc_O0_N_O1& vgrad_grid_desc_o0_n_o1,
const LSEGridDesc_M& lse_grid_desc_m,
const YGradGridDesc_O0_M_O1& ygrad_grid_desc_o0_m_o1,
const Block2CTileMap& block_2_ctile_map,
......@@ -1477,11 +1479,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, ygrad_grid_desc_o0_m_o1.GetElementSpaceSize());
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_vgrad_grid, v_grid_desc_o0_n_o1.GetElementSpaceSize());
p_vgrad_grid, vgrad_grid_desc_o0_n_o1.GetElementSpaceSize());
auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
auto kgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_kgrad_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());
p_kgrad_grid, kgrad_grid_desc_k0_n_k1.GetElementSpaceSize());
// divide block work by [N, K]
const auto block_work_idx =
......@@ -1631,7 +1633,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// dV: transform input and output tensor descriptors
auto vgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(v_grid_desc_o0_n_o1);
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(vgrad_grid_desc_o0_n_o1);
// dK: A matrix blockwise copy
auto kgrad_gemm_tile_sgrad_blockwise_copy =
......@@ -1660,7 +1662,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// dK: transform input and output tensor descriptors
auto kgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(k_grid_desc_k0_n_k1);
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(kgrad_grid_desc_k0_n_k1);
//
// set up dQ Gemm (type 3 crr)
......
......@@ -1535,10 +1535,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const CElementwiseOperation& c_element_op,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const KGridDesc_K0_N_K1& kgrad_grid_desc_k0_n_k1,
const D0GridDescriptor_M0_N0_M1_M2_N1_M3& d0_grid_desc_m0_n0_m1_m2_n1_m3,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
const VGridDesc_O0_N_O1& vgrad_grid_desc_o0_n_o1,
const LSEGridDesc_M& lse_grid_desc_m,
const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
const Block2CTileMap& block_2_ctile_map,
......@@ -1569,11 +1571,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_vgrad_grid, v_grid_desc_o0_n_o1.GetElementSpaceSize());
p_vgrad_grid, vgrad_grid_desc_o0_n_o1.GetElementSpaceSize());
auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
auto kgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_kgrad_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());
p_kgrad_grid, kgrad_grid_desc_k0_n_k1.GetElementSpaceSize());
// divide block work by [N, K]
const auto block_work_idx =
......@@ -1746,7 +1748,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// dV: transform input and output tensor descriptors
auto vgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(v_grid_desc_o0_n_o1);
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(vgrad_grid_desc_o0_n_o1);
// dK: transform input and output tensor descriptors
const auto q_grid_desc_m0_k_m1 =
......@@ -1779,7 +1781,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// dK: transform input and output tensor descriptors
auto kgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(k_grid_desc_k0_n_k1);
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(kgrad_grid_desc_k0_n_k1);
//
// set up dQ Gemm (type 3 crr)
......
......@@ -1524,10 +1524,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const CElementwiseOperation& c_element_op,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const KGridDesc_K0_N_K1& kgrad_grid_desc_k0_n_k1,
const D0GridDescriptor_M0_N0_M1_M2_N1_M3& d0_grid_desc_m0_n0_m1_m2_n1_m3,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
const VGridDesc_O0_N_O1& vgrad_grid_desc_o0_n_o1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock,
const LSEGridDesc_M& lse_grid_desc_m,
......@@ -1560,11 +1562,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, ygrad_grid_desc_o0_m_o1.GetElementSpaceSize());
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_vgrad_grid, v_grid_desc_o0_n_o1.GetElementSpaceSize());
p_vgrad_grid, vgrad_grid_desc_o0_n_o1.GetElementSpaceSize());
auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
auto kgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_kgrad_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());
p_kgrad_grid, kgrad_grid_desc_k0_n_k1.GetElementSpaceSize());
// divide block work by [N, K]
const auto block_work_idx =
......@@ -1714,7 +1716,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// dV: transform input and output tensor descriptors
auto vgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(v_grid_desc_o0_n_o1);
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(vgrad_grid_desc_o0_n_o1);
// dK: A matrix blockwise copy
auto kgrad_gemm_tile_sgrad_blockwise_copy =
......@@ -1743,7 +1745,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// dK: transform input and output tensor descriptors
auto kgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(k_grid_desc_k0_n_k1);
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(kgrad_grid_desc_k0_n_k1);
//
// set up dQ Gemm (type 3 crr)
......
......@@ -1600,10 +1600,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const CElementwiseOperation& c_element_op,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const KGridDesc_K0_N_K1& kgrad_grid_desc_k0_n_k1,
const D0GridDescriptor_M0_N0_M1_M2_N1_M3& d0_grid_desc_m0_n0_m1_m2_n1_m3,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
const VGridDesc_O0_N_O1& vgrad_grid_desc_o0_n_o1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock,
const LSEGridDesc_M& lse_grid_desc_m,
......@@ -1636,11 +1638,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_vgrad_grid, v_grid_desc_o0_n_o1.GetElementSpaceSize());
p_vgrad_grid, vgrad_grid_desc_o0_n_o1.GetElementSpaceSize());
auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
auto kgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_kgrad_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());
p_kgrad_grid, kgrad_grid_desc_k0_n_k1.GetElementSpaceSize());
// divide block work by [N, K]
const auto block_work_idx =
......@@ -1813,7 +1815,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// dV: transform input and output tensor descriptors
auto vgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(v_grid_desc_o0_n_o1);
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(vgrad_grid_desc_o0_n_o1);
// dK: transform input and output tensor descriptors
const auto q_grid_desc_m0_k_m1 =
......@@ -1846,7 +1848,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// dK: transform input and output tensor descriptors
auto kgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(k_grid_desc_k0_n_k1);
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(kgrad_grid_desc_k0_n_k1);
//
// set up dQ Gemm (type 3 crr)
......
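
The four gridwise backward variants above (Light_V1, Light_V2, V1, V2) all receive the same mechanical change: the dK/dV output buffers and copy descriptors are built from the new kgrad/vgrad descriptors instead of reusing the K/V input descriptors. Since the gradient descriptors are supplied independently of the inputs', their strides and element-space sizes can differ once GQA is in play, so sizing dK's buffer from K's descriptor would no longer be safe. A toy illustration of that failure mode (hypothetical numbers; FakeDesc stands in for the CK descriptors):

    struct FakeDesc
    {
        long elems;
        long GetElementSpaceSize() const { return elems; }
    };

    int main()
    {
        // K read through a padded layout vs dK written densely (illustrative):
        FakeDesc k_desc{1024L * 72};     // n * padded_k
        FakeDesc kgrad_desc{1024L * 64}; // n * k, packed
        // Sizing the dK buffer from k_desc would overstate its element space;
        // each buffer must use its matching descriptor.
        return kgrad_desc.GetElementSpaceSize() == 1024L * 64 ? 0 : 1;
    }
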