Commit 896408a5 authored by letaoqin

fix group gemm

parent 95d76f67
@@ -287,8 +287,8 @@ int run(int argc, char* argv[])
         p_c,
         p_z,
         p_lse,
-        nullptr, // p_acc0_biases
-        nullptr, // p_acc1_biases
+        {}, // p_acc0_biases
+        {}, // p_acc1_biases
         problem_descs,
         a_element_op,
         b0_element_op,
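In the example, the bias arguments are now passed as `{}` instead of `nullptr`. Below is a tiny standalone sketch of why an empty braced initializer is a convenient "no bias" value when the argument is a container of pointers rather than a single raw pointer; the parameter type `std::array<const void*, 1>` and the function name are assumptions for illustration, not the library's actual API.

```cpp
#include <array>
#include <cstdio>

// Hypothetical stand-in for a "bias pointers" argument of the device op.
using Acc0BiasPtrs = std::array<const void*, 1>;

void make_argument_sketch(const Acc0BiasPtrs& p_acc0_biases)
{
    for(const void* p : p_acc0_biases)
        std::printf("acc0 bias pointer: %p\n", p); // prints a null pointer when {} is passed
}

int main()
{
    make_argument_sketch({}); // empty braces: every element is value-initialized to nullptr
}
```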
@@ -24,6 +24,7 @@ namespace tensor_operation {
 namespace device {

 template <typename GridwiseGemm,
+          typename D0DataType,
           typename GemmAccDataType,
           typename GroupKernelArg,
           typename AElementwiseOperation,
@@ -99,8 +100,17 @@ __global__ void
             static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetZBasePtr(g_idx)));
         const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
             arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
-        const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
-            arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
+        const D0DataType* tmp_p_d0_grid = nullptr;
+
+        if constexpr(!is_same<D0DataType, void>::value)
+        {
+            const long_index_t d0_batch_offset =
+                __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
+                    arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
+            tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
+        }

         if constexpr(Deterministic)
         {
@@ -109,9 +119,7 @@ __global__ void
             GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
                 arg_ptr[group_id].p_a_grid_ + a_batch_offset,
                 arg_ptr[group_id].p_b_grid_ + b_batch_offset,
-                arg_ptr[group_id].p_d0_grid_ == nullptr
-                    ? nullptr
-                    : arg_ptr[group_id].p_d0_grid_ + d0_batch_offset,
+                tmp_p_d0_grid,
                 arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
                 arg_ptr[group_id].p_c_grid_ + c_batch_offset,
                 arg_ptr[group_id].p_z_grid_ == nullptr
@@ -150,9 +158,7 @@ __global__ void
             GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
                 arg_ptr[group_id].p_a_grid_ + a_batch_offset,
                 arg_ptr[group_id].p_b_grid_ + b_batch_offset,
-                arg_ptr[group_id].p_d0_grid_ == nullptr
-                    ? nullptr
-                    : arg_ptr[group_id].p_d0_grid_ + d0_batch_offset,
+                tmp_p_d0_grid,
                 arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
                 arg_ptr[group_id].p_c_grid_ + c_batch_offset,
                 arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
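Below is a minimal, standalone C++17 sketch of the `if constexpr` gating the kernel now uses. The struct and function names (`KernelArgSketch`, `run_batch`) are hypothetical; only `tmp_p_d0_grid` and the `is_same<D0DataType, void>` test mirror the diff. The point: when the bias type is `void`, the batch-offset computation and the pointer arithmetic on `p_d0_grid_` are never instantiated, and the GEMM simply receives a null bias pointer.

```cpp
#include <cstdio>
#include <type_traits>

// Hypothetical stand-in for the per-group kernel argument.
template <typename D0DataType>
struct KernelArgSketch
{
    const D0DataType* p_d0_grid_ = nullptr; // base pointer of the optional D0 bias
    long d0_batch_stride_        = 0;       // elements between consecutive batches
};

template <typename D0DataType>
void run_batch(const KernelArgSketch<D0DataType>& arg, long g_idx)
{
    const D0DataType* tmp_p_d0_grid = nullptr;

    if constexpr(!std::is_same<D0DataType, void>::value)
    {
        // Only compiled when a real bias type is supplied; for D0DataType == void
        // this pointer arithmetic would be ill-formed, so it must be guarded.
        const long d0_batch_offset = g_idx * arg.d0_batch_stride_;
        tmp_p_d0_grid              = arg.p_d0_grid_ + d0_batch_offset;
    }

    std::printf("bias pointer for batch %ld: %p\n", g_idx, static_cast<const void*>(tmp_p_d0_grid));
}

int main()
{
    static float bias[4 * 128] = {};
    run_batch(KernelArgSketch<float>{bias, 128}, 2); // biased path: pointer advanced per batch
    run_batch(KernelArgSketch<void>{}, 2);           // bias-free path: stays nullptr
}
```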
@@ -716,9 +722,23 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                 problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
             const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
                 problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
-            const D0GridDesc_M_N d0_grid_desc_m_n{
-                DeviceOp::MakeD0GridDescriptor_M_N(problem_desc.acc0_biases_gs_ms_ns_lengths,
-                                                   problem_desc.acc0_biases_gs_ms_ns_strides)};
+
+            std::vector<index_t> tmp_d0_gs_ms_ns_lengths;
+            std::vector<index_t> tmp_d0_gs_ms_ns_strides;
+
+            if constexpr(!is_same<D0DataType, void>::value)
+            {
+                tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_biases_gs_ms_ns_lengths;
+                tmp_d0_gs_ms_ns_strides = problem_desc.acc0_biases_gs_ms_ns_strides;
+            }
+            else
+            {
+                tmp_d0_gs_ms_ns_lengths = {1, 1, 1, 1};
+                tmp_d0_gs_ms_ns_strides = {0, 0, 0, 0};
+            }
+
+            const D0GridDesc_M_N d0_grid_desc_m_n{DeviceOp::MakeD0GridDescriptor_M_N(
+                tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides)};
             const auto d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
                 GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
                     d0_grid_desc_m_n);
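A small host-side sketch of why the placeholder dimensions above are safe; the helper name `pick_d0_dims` is hypothetical, and only the `{1, 1, 1, 1}` / `{0, 0, 0, 0}` fallback mirrors the diff. Lengths of 1 with strides of 0 describe a degenerate bias shape, so the D0 descriptor can still be constructed and the surrounding code compiles, while the kernel never reads through the bias pointer because `D0DataType` is `void`.

```cpp
#include <type_traits>
#include <utility>
#include <vector>

using index_t = int;

// Returns the lengths/strides to feed a D0 descriptor builder.
template <typename D0DataType>
std::pair<std::vector<index_t>, std::vector<index_t>>
pick_d0_dims(const std::vector<index_t>& real_lengths, const std::vector<index_t>& real_strides)
{
    if constexpr(!std::is_same<D0DataType, void>::value)
    {
        // Real bias: use the dimensions from the problem description.
        return {real_lengths, real_strides};
    }
    else
    {
        // No bias: a degenerate placeholder shape keeps the descriptor valid
        // without describing any actual memory.
        return {{1, 1, 1, 1}, {0, 0, 0, 0}};
    }
}

int main()
{
    auto [lengths, strides] = pick_d0_dims<void>({}, {}); // bias-free path
    return static_cast<int>(lengths.size());              // 4 placeholder dims
}
```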
@@ -735,9 +755,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                 problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
             const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K(
                 problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
-            const auto d0_grid_desc_g_m_n =
-                DeviceOp::MakeD0GridDescriptor_G_M_N(problem_desc.acc0_biases_gs_ms_ns_lengths,
-                                                     problem_desc.acc0_biases_gs_ms_ns_strides);
+            const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
+                tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
             const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
                 problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
             const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
@@ -812,10 +831,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
             // for check
             std::vector<ck::index_t> d0_n_length_stride;
-            d0_n_length_stride.push_back(
-                problem_desc.acc0_biases_gs_ms_ns_lengths[NumDimG + NumDimM]);
-            d0_n_length_stride.push_back(
-                problem_desc.acc0_biases_gs_ms_ns_strides[NumDimG + NumDimM]);
+            d0_n_length_stride.push_back(tmp_d0_gs_ms_ns_lengths[NumDimG + NumDimM]);
+            d0_n_length_stride.push_back(tmp_d0_gs_ms_ns_strides[NumDimG + NumDimM]);

             group_device_args_.push_back(
                 {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
@@ -900,6 +917,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
             [&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
                 const auto kernel =
                     kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
+                                                                     D0DataType,
                                                                      GemmAccDataType,
                                                                      GroupKernelArg,
                                                                      AElementwiseOperation,