Commit 896408a5 authored by letaoqin's avatar letaoqin
Browse files

fix group gemm

parent 95d76f67
......@@ -287,8 +287,8 @@ int run(int argc, char* argv[])
p_c,
p_z,
p_lse,
nullptr, // p_acc0_biases
nullptr, // p_acc1_biases
{}, // p_acc0_biases
{}, // p_acc1_biases
problem_descs,
a_element_op,
b0_element_op,
......
......@@ -24,6 +24,7 @@ namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename D0DataType,
typename GemmAccDataType,
typename GroupKernelArg,
typename AElementwiseOperation,
......@@ -99,8 +100,17 @@ __global__ void
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetZBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
const D0DataType* tmp_p_d0_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
}
if constexpr(Deterministic)
{
......@@ -109,9 +119,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
arg_ptr[group_id].p_d0_grid_ == nullptr
? nullptr
: arg_ptr[group_id].p_d0_grid_ + d0_batch_offset,
tmp_p_d0_grid,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_z_grid_ == nullptr
......@@ -150,9 +158,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
arg_ptr[group_id].p_d0_grid_ == nullptr
? nullptr
: arg_ptr[group_id].p_d0_grid_ + d0_batch_offset,
tmp_p_d0_grid,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
......@@ -716,9 +722,23 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
const D0GridDesc_M_N d0_grid_desc_m_n{
DeviceOp::MakeD0GridDescriptor_M_N(problem_desc.acc0_biases_gs_ms_ns_lengths,
problem_desc.acc0_biases_gs_ms_ns_strides)};
std::vector<index_t> tmp_d0_gs_ms_ns_lengths;
std::vector<index_t> tmp_d0_gs_ms_ns_strides;
if constexpr(!is_same<D0DataType, void>::value)
{
tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_biases_gs_ms_ns_lengths;
tmp_d0_gs_ms_ns_strides = problem_desc.acc0_biases_gs_ms_ns_strides;
}
else
{
tmp_d0_gs_ms_ns_lengths = {1, 1, 1, 1};
tmp_d0_gs_ms_ns_strides = {0, 0, 0, 0};
}
const D0GridDesc_M_N d0_grid_desc_m_n{DeviceOp::MakeD0GridDescriptor_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides)};
const auto d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
d0_grid_desc_m_n);
......@@ -735,9 +755,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K(
problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
const auto d0_grid_desc_g_m_n =
DeviceOp::MakeD0GridDescriptor_G_M_N(problem_desc.acc0_biases_gs_ms_ns_lengths,
problem_desc.acc0_biases_gs_ms_ns_strides);
const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
......@@ -812,10 +831,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
// for check
std::vector<ck::index_t> d0_n_length_stride;
d0_n_length_stride.push_back(
problem_desc.acc0_biases_gs_ms_ns_lengths[NumDimG + NumDimM]);
d0_n_length_stride.push_back(
problem_desc.acc0_biases_gs_ms_ns_strides[NumDimG + NumDimM]);
d0_n_length_stride.push_back(tmp_d0_gs_ms_ns_lengths[NumDimG + NumDimM]);
d0_n_length_stride.push_back(tmp_d0_gs_ms_ns_strides[NumDimG + NumDimM]);
group_device_args_.push_back(
{{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
......@@ -900,6 +917,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
[&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
const auto kernel =
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
D0DataType,
GemmAccDataType,
GroupKernelArg,
AElementwiseOperation,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment