"vscode:/vscode.git/clone" did not exist on "056bea45a171df682a75943d6733a3c6a9d1b173"
Unverified Commit 4033f5df authored by Dan Yao's avatar Dan Yao Committed by GitHub
Browse files

Merge pull request #949 from ROCmSoftwarePlatform/mha-train-mqagqa

Grouped Query Attention/Multi Query Attention
parents 18be6bc9 957d5dee
......@@ -43,6 +43,7 @@ __global__ void
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const index_t h_ratio,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
......@@ -87,13 +88,14 @@ __global__ void
const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
const index_t g_idx = __builtin_amdgcn_readfirstlane(
(block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch);
const index_t gkv_idx = __builtin_amdgcn_readfirstlane(g_idx / h_ratio);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(gkv_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(gkv_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
......@@ -147,6 +149,7 @@ __global__ void
#else
ignore = group_kernel_args;
ignore = group_count;
ignore = h_ratio;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
......@@ -367,7 +370,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
{
return Transform::MakeC0GridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
}
......@@ -376,7 +378,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
{
return Transform::MakeC0GridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
}
......@@ -606,6 +607,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
// for gridwise gemm check
CGridDesc_M_N c_grid_desc_m_n_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
// raw data
std::vector<ck::index_t> d0_n_length_stride_;
......@@ -654,6 +657,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
index_t z_random_matrix_offset = 0;
h_ratio_ = problem_desc_vec[0].a_gs_ms_ks_lengths[NumDimG - 1] /
problem_desc_vec[0].b0_gs_ns_ks_lengths[NumDimG - 1];
for(std::size_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
......@@ -805,6 +811,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
{problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
c_grid_desc_m_n,
b_grid_desc_g_n_k,
c_grid_desc_g_m_n,
d0_n_length_stride});
}
......@@ -830,6 +838,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_;
index_t h_ratio_;
float p_dropout_;
uint8_t p_dropout_in_uint8_t_;
unsigned long long seed_;
......@@ -918,6 +927,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.group_count_,
arg.h_ratio_,
arg.a_element_op_,
arg.b_element_op_,
arg.acc_element_op_,
......@@ -1040,11 +1050,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
const auto& device_arg = arg.group_device_args_[i];
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t b_g = device_arg.b_grid_desc_g_n_k_.GetLength(I0);
const index_t c_m = device_arg.c_grid_desc_m_n_.GetLength(I0);
const index_t c_gemm1n = device_arg.c_grid_desc_m_n_.GetLength(I1);
const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
if(!(c_m == a_m && c_gemm1n == b1_gemm1n))
if(!(c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 &&
c_g / b_g == arg.h_ratio_))
{
return false;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment