Unverified Commit 4033f5df authored by Dan Yao, committed by GitHub

Merge pull request #949 from ROCmSoftwarePlatform/mha-train-mqagqa

Grouped Query Attention/Multi Query Attention
parents 18be6bc9 957d5dee
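
For context: in grouped-query attention (GQA) and multi-query attention (MQA), several query heads share a single key/value head, so the Q tensor carries more heads along the G dimension than the K/V tensors. This change threads a head ratio (number of Q heads divided by number of K/V heads) into the grouped kernel and divides the per-block head index by that ratio before computing the K and V base pointers. Below is a minimal, illustrative sketch of that mapping; it is not code from this commit, and the names num_q_heads, num_kv_heads, and map_to_kv_head are hypothetical.

#include <cassert>

// Illustrative sketch only (not from this commit): query head g_idx reads its
// K/V data from head g_idx / h_ratio. For standard MHA h_ratio == 1; for MQA
// num_kv_heads == 1, so every query head maps to K/V head 0.
int map_to_kv_head(int g_idx, int num_q_heads, int num_kv_heads)
{
    assert(num_kv_heads > 0 && num_q_heads % num_kv_heads == 0);
    const int h_ratio = num_q_heads / num_kv_heads;
    return g_idx / h_ratio;
}
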
@@ -43,6 +43,7 @@ __global__ void
         kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2(
             const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
             const index_t group_count,
+            const index_t h_ratio,
             const AElementwiseOperation a_element_op,
             const BElementwiseOperation b_element_op,
             const AccElementwiseOperation acc_element_op,
@@ -87,13 +88,14 @@ __global__ void
     const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
     const index_t g_idx = __builtin_amdgcn_readfirstlane(
         (block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch);
+    const index_t gkv_idx = __builtin_amdgcn_readfirstlane(g_idx / h_ratio);
     const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx)));
-    const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
-        arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
+    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
+        arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(gkv_idx)));
+    const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
+        arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(gkv_idx)));
     const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
     const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
@@ -147,6 +149,7 @@ __global__ void
 #else
     ignore = group_kernel_args;
     ignore = group_count;
+    ignore = h_ratio;
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = acc_element_op;
@@ -367,7 +370,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
     MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
                              const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
     {
         return Transform::MakeC0GridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
                                                    acc0_biases_gs_ms_ns_strides);
     }
@@ -376,7 +378,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
     MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
                                const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
     {
         return Transform::MakeC0GridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths,
                                                      acc0_biases_gs_ms_ns_strides);
     }
@@ -606,6 +607,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
         // for gridwise gemm check
         CGridDesc_M_N c_grid_desc_m_n_;
+        BGridDesc_G_N_K b_grid_desc_g_n_k_;
+        CGridDesc_G_M_N c_grid_desc_g_m_n_;
         // raw data
         std::vector<ck::index_t> d0_n_length_stride_;
@@ -654,6 +657,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
             index_t z_random_matrix_offset = 0;
+            h_ratio_ = problem_desc_vec[0].a_gs_ms_ks_lengths[NumDimG - 1] /
+                       problem_desc_vec[0].b0_gs_ns_ks_lengths[NumDimG - 1];
             for(std::size_t i = 0; i < group_count_; i++)
             {
                 const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
@@ -805,6 +811,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                     {problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
                      problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
                     c_grid_desc_m_n,
+                    b_grid_desc_g_n_k,
+                    c_grid_desc_g_m_n,
                     d0_n_length_stride});
             }
@@ -830,6 +838,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
         B1ElementwiseOperation b1_element_op_;
         CElementwiseOperation c_element_op_;
+        index_t h_ratio_;
         float p_dropout_;
         uint8_t p_dropout_in_uint8_t_;
         unsigned long long seed_;
@@ -918,6 +927,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                 0,
                 cast_pointer_to_constant_address_space(arg.p_workspace_),
                 arg.group_count_,
+                arg.h_ratio_,
                 arg.a_element_op_,
                 arg.b_element_op_,
                 arg.acc_element_op_,
@@ -1040,11 +1050,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
             const auto& device_arg = arg.group_device_args_[i];
             // Check if C permute dimension matches GEMM + GEMM shape
+            const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
+            const index_t b_g = device_arg.b_grid_desc_g_n_k_.GetLength(I0);
             const index_t c_m = device_arg.c_grid_desc_m_n_.GetLength(I0);
             const index_t c_gemm1n = device_arg.c_grid_desc_m_n_.GetLength(I1);
             const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
             const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
-            if(!(c_m == a_m && c_gemm1n == b1_gemm1n))
+            if(!(c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 &&
+                 c_g / b_g == arg.h_ratio_))
             {
                 return false;
             }
...
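
As a usage note, the device-side Argument derives h_ratio_ from the G-dimension lengths of the Q (a_gs_ms_ks) and K (b0_gs_ns_ks) problem descriptors, and IsSupportedArgument now rejects configurations whose output head count is not an exact h_ratio_ multiple of the K/V head count. The following is a hedged host-side sketch of that same consistency rule; validate_gqa_heads and its parameter names are hypothetical and not part of the library API.

#include <stdexcept>

// Hedged sketch mirroring the check added above: the Q/output head count must
// be an integer multiple of the K/V head count. Returns the resulting ratio.
inline int validate_gqa_heads(int q_heads, int kv_heads)
{
    // q_heads corresponds to the G-length of the Q/output descriptors,
    // kv_heads to the G-length of the K/V descriptors.
    if(kv_heads <= 0 || q_heads % kv_heads != 0)
    {
        throw std::runtime_error("GQA/MQA requires q_heads to be a multiple of kv_heads");
    }
    return q_heads / kv_heads; // the head ratio handed to the kernel
}
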