Commit b8f08e67 authored by letaoqin

Grouped MHA forward: change the single D bias tensor to multiple D0 tensors

parent fa066d60
@@ -95,13 +95,19 @@ __global__ void
             arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
         const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
             static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
-        const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
-            static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetDBasePtr(g_idx)));
         const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
             static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetZBasePtr(g_idx)));
         const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
             arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
 
+        typename GridwiseGemm::D0sGridPointer p_d0s_grid = arg_ptr[group_id].p_d0s_grid_;
+        static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) {
+            const long_index_t d0_batch_offset =
+                __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
+                    arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx, In)));
+            p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset;
+        });
+
         if constexpr(Deterministic)
         {
             for(index_t i = 0; i < num_blocks_per_batch; i++)
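Two idioms carry this hunk: `__builtin_amdgcn_readfirstlane` makes each computed offset wavefront-uniform so it can live in a scalar register, and `static_for` unrolls the loop over the D0 pointer tuple at compile time, since a heterogeneous tuple cannot be indexed with a runtime variable. A minimal host-side sketch of the `static_for` pattern, using simplified stand-ins (`std::tuple`, `std::integral_constant`) rather than the actual CK types:

    // Simplified model of the ck::static_for idiom: a compile-time loop that
    // visits each element of a pointer tuple and applies its own batch offset.
    #include <cstdio>
    #include <tuple>
    #include <type_traits>
    #include <utility>

    template <int Begin, int End, int Step>
    struct static_for
    {
        template <typename F>
        void operator()(F&& f) const
        {
            if constexpr(Begin < End)
            {
                f(std::integral_constant<int, Begin>{}); // body sees a compile-time index
                static_for<Begin + Step, End, Step>{}(std::forward<F>(f));
            }
        }
    };

    int main()
    {
        float bias0[8] = {};
        float bias1[8] = {};
        // Tuple of D0 base pointers, one per bias tensor (stand-in for D0sGridPointer).
        auto p_d0s = std::make_tuple(bias0 + 0, bias1 + 0);
        const long offsets[] = {2, 4}; // per-tensor batch offsets (stand-in for GetD0BasePtr)

        static_for<0, 2, 1>{}([&](auto In) {
            // In.value is a static constexpr member, so it can index the tuple.
            std::get<In.value>(p_d0s) += offsets[In.value];
        });

        std::printf("offset0=%td offset1=%td\n",
                    std::get<0>(p_d0s) - bias0, std::get<1>(p_d0s) - bias1);
        return 0;
    }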
@@ -109,11 +115,9 @@ __global__ void
                 GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
                     arg_ptr[group_id].p_a_grid_ + a_batch_offset,
                     arg_ptr[group_id].p_b_grid_ + b_batch_offset,
+                    p_d0s_grid,
                     arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
                     arg_ptr[group_id].p_c_grid_ + c_batch_offset,
-                    arg_ptr[group_id].p_d_grid_ == nullptr
-                        ? nullptr
-                        : arg_ptr[group_id].p_d_grid_ + d_batch_offset,
                     arg_ptr[group_id].p_z_grid_ == nullptr
                         ? nullptr
                         : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
@@ -129,9 +133,9 @@ __global__ void
                     c_element_op,
                     arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
                     arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
+                    arg_ptr[group_id].d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
                     arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
                     arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                    arg_ptr[group_id].d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
                     arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
                     arg_ptr[group_id].lse_grid_desc_m_,
                     arg_ptr[group_id].block_2_ctile_map_,
@@ -150,10 +154,9 @@ __global__ void
         GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
             arg_ptr[group_id].p_a_grid_ + a_batch_offset,
             arg_ptr[group_id].p_b_grid_ + b_batch_offset,
+            p_d0s_grid,
             arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
             arg_ptr[group_id].p_c_grid_ + c_batch_offset,
-            arg_ptr[group_id].p_d_grid_ == nullptr ? nullptr
-                                                   : arg_ptr[group_id].p_d_grid_ + d_batch_offset,
             arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                                                    : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
             arg_ptr[group_id].p_lse_grid_ == nullptr
@@ -168,9 +171,9 @@ __global__ void
             c_element_op,
             arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
             arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
+            arg_ptr[group_id].d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
             arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
             arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
-            arg_ptr[group_id].d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
             arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
             arg_ptr[group_id].lse_grid_desc_m_,
             arg_ptr[group_id].block_2_ctile_map_,
@@ -300,12 +303,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
     static constexpr index_t NumD1Tensor = Acc1BiasDataType::Size();
 
     // TODO ANT: implement bias combination
-    static_assert(NumD0Tensor <= 1, "Bias0 addition is only support one bias");
     static_assert(NumD1Tensor == 0, "Bias addition is unimplemented");
-    static_assert(NumD0Tensor == 0
-                      ? true
-                      : std::is_same_v<ADataType, ck::tuple_element_t<0, Acc0BiasDataType>>);
-    using DDataType = ADataType;
 
 #if 0
     // TODO ANT: use alias
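Dropping the `NumD0Tensor <= 1` assert (and the single `DDataType` alias) is what actually enables multiple biases. `NumD0Tensor` comes from `Acc0BiasDataType::Size()`: the bias element types are carried as a compile-time tuple of types, and its arity is the tensor count. A minimal sketch of that type-list convention, with a stand-in for `ck::Tuple` and an assumed two-bias configuration:

    // How NumD0Tensor/NumD1Tensor are derived from a compile-time type list.
    #include <cstddef>

    template <typename... Ts>
    struct TypeTuple
    {
        static constexpr std::size_t Size() { return sizeof...(Ts); }
    };

    using F16 = short; // stand-in for ck::half_t

    using Acc0Bias = TypeTuple<F16, F16>; // two D0 bias tensors
    using Acc1Bias = TypeTuple<>;         // no D1 bias

    static constexpr int NumD0Tensor = Acc0Bias::Size(); // == 2
    static constexpr int NumD1Tensor = Acc1Bias::Size(); // == 0
    static_assert(NumD0Tensor == 2 && NumD1Tensor == 0, "arity comes from the type list");

    int main() { return 0; }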
@@ -424,20 +422,44 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
         }
     }
 
+    static auto MakeD0sGridDescriptor_M_N(
+        const std::vector<std::vector<ck::index_t>>& acc0_biases_gs_ms_ns_lengths,
+        const std::vector<std::vector<ck::index_t>>& acc0_biases_gs_ms_ns_strides)
+    {
+        return generate_tuple(
+            [&](auto i) {
+                return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths[i],
+                                                          acc0_biases_gs_ms_ns_strides[i]);
+            },
+            Number<NumD0Tensor>{});
+    }
+
+    static auto MakeD0sGridDescriptor_G_M_N(
+        const std::vector<std::vector<ck::index_t>>& acc0_biases_gs_ms_ns_lengths,
+        const std::vector<std::vector<ck::index_t>>& acc0_biases_gs_ms_ns_strides)
+    {
+        return generate_tuple(
+            [&](auto i) {
+                return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths[i],
+                                                            acc0_biases_gs_ms_ns_strides[i]);
+            },
+            Number<NumD0Tensor>{});
+    }
+
     using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
     using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
+    using D0sGridDesc_M_N = decltype(MakeD0sGridDescriptor_M_N({}, {}));
     using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
     using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
     using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
-    using DGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
     using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
 
     using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
     using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
+    using D0sGridDesc_G_M_N = decltype(MakeD0sGridDescriptor_G_M_N({}, {}));
     using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
     using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
-    using DGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
     using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
 
     constexpr static auto make_MaskOutPredicate()
     {
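`generate_tuple` invokes the lambda once per compile-time index and packs the results into a tuple, so every D0 tensor gets its own descriptor (and the descriptor types may differ per tensor). A simplified model of that pattern, assuming `std::index_sequence` machinery in place of the CK helpers:

    // Sketch of the generate_tuple pattern behind MakeD0sGridDescriptor_M_N.
    #include <cstdio>
    #include <tuple>
    #include <type_traits>
    #include <utility>

    template <typename F, std::size_t... Is>
    auto generate_tuple_impl(F&& f, std::index_sequence<Is...>)
    {
        // Call f with each compile-time index and collect the results.
        return std::make_tuple(f(std::integral_constant<std::size_t, Is>{})...);
    }

    template <std::size_t N, typename F>
    auto generate_tuple(F&& f)
    {
        return generate_tuple_impl(std::forward<F>(f), std::make_index_sequence<N>{});
    }

    struct Desc { long m, n; }; // stand-in for a grid descriptor

    int main()
    {
        long lengths[2][2] = {{128, 64}, {128, 64}};
        // One descriptor per D0 tensor, built from that tensor's lengths.
        auto d0s_desc = generate_tuple<2>([&](auto i) {
            return Desc{lengths[i.value][0], lengths[i.value][1]};
        });
        std::printf("%ld %ld\n", std::get<0>(d0s_desc).m, std::get<1>(d0s_desc).n);
        return 0;
    }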
@@ -460,16 +482,16 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
     {
         ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
                                      const BGridDesc_G_N_K& b_grid_desc_g_n_k,
+                                     const D0sGridDesc_G_M_N& d0s_grid_desc_g_m_n,
                                      const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
                                      const CGridDesc_G_M_N& c_grid_desc_g_m_n,
-                                     const DGridDesc_G_M_N& d_grid_desc_g_m_n,
                                      const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
                                      index_t BatchStrideLSE)
             : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
               b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
+              d0s_grid_desc_g_m_n_(d0s_grid_desc_g_m_n),
               b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
               c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
-              d_grid_desc_g_m_n_(d_grid_desc_g_m_n),
               z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
               BatchStrideLSE_(BatchStrideLSE)
         {
@@ -495,9 +517,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
             return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
         }
 
-        __host__ __device__ constexpr long_index_t GetDBasePtr(index_t g_idx) const
+        template <index_t I>
+        __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
+                                                                Number<I> d0_idx) const
         {
-            return d_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
+            return d0s_grid_desc_g_m_n_[d0_idx].CalculateOffset(make_multi_index(g_idx, 0, 0));
         }
 
         __host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
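`GetD0BasePtr` takes the tensor index as a `Number<I>` so the descriptor tuple can be indexed at compile time while `g_idx` stays a runtime value; the returned base offset is simply the descriptor's linear offset for index (g_idx, 0, 0). A minimal sketch of what `CalculateOffset` computes, assuming a plain strided layout (the real descriptor may chain coordinate transforms):

    // Base-pointer offset of batch g_idx in a G x M x N view: the dot product
    // of the multi-index (g_idx, 0, 0) with the per-dimension strides.
    #include <array>
    #include <cstdio>

    struct DescGMN
    {
        std::array<long, 3> strides; // (batch, row, col) strides in elements

        long CalculateOffset(const std::array<long, 3>& idx) const
        {
            long off = 0;
            for(int d = 0; d < 3; ++d)
                off += idx[d] * strides[d]; // linear offset = sum(index * stride)
            return off;
        }
    };

    int main()
    {
        DescGMN d0{{128 * 64, 64, 1}}; // batch stride = M*N for a dense tensor
        // Offset of batch 3 = 3 * 8192; row/col components stay at 0.
        std::printf("%ld\n", d0.CalculateOffset({3, 0, 0})); // 24576
        return 0;
    }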
@@ -513,9 +537,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
         private:
         AGridDesc_G_M_K a_grid_desc_g_m_k_;
         BGridDesc_G_N_K b_grid_desc_g_n_k_;
+        D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_;
         B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
         CGridDesc_G_M_N c_grid_desc_g_m_n_;
-        DGridDesc_G_M_N d_grid_desc_g_m_n_;
         ZGridDesc_G_M_N z_grid_desc_g_m_n_;
         index_t BatchStrideLSE_;
@@ -524,6 +548,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
     // GridwiseGemm
     using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2<
         ADataType, // TODO: distinguish A/B datatype
+        Acc0BiasDataType,
         ZDataType,
         GemmDataType,
         GemmAccDataType,
@@ -538,6 +563,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
         InMemoryDataOperationEnum::Set,
         AGridDesc_AK0_M_AK1,
         BGridDesc_BK0_N_BK1,
+        D0sGridDesc_M_N,
         B1GridDesc_BK0_N_BK1,
         CGridDesc_M_N,
         ZGridDesc_M_N,
@@ -599,21 +625,20 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
         // pointers
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
+        typename GridwiseGemm::D0sGridPointer p_d0s_grid_;
         const B1DataType* p_b1_grid_;
         CDataType* p_c_grid_;
-        const DDataType* p_d_grid_;
         ZDataType* p_z_grid_;
         LSEDataType* p_lse_grid_;
 
         // tensor descriptors for block/thread-wise copy
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
+        typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+            d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
         B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock_;
-        DGridDesc_M_N d_grid_desc_m_n_;
-        typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
-            d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
         typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
             z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
         ZGridDesc_M_N z_grid_desc_m_n_;
@@ -650,7 +675,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
         CGridDesc_M_N c_grid_desc_m_n_;
 
         // raw data
-        int raw_d0_n_;
+        std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_nl_ns_lengths_strides_;
     };
 
     // Argument
@@ -700,9 +725,21 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
             const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]);
             const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
             const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]);
-            const auto p_d_grid = NumD0Tensor == 0
-                                      ? nullptr
-                                      : static_cast<const DDataType*>(p_acc0_biases_vec[i][0]);
+
+            const auto& problem_desc = problem_desc_vec[i];
+
+            std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_nl_ns_lengths_strides;
+            typename GridwiseGemm::D0sGridPointer p_d0s_grid;
+            static_for<0, NumD0Tensor, 1>{}([&](auto j) {
+                using D0DataType = remove_cvref_t<tuple_element_t<j.value, Acc0BiasDataType>>;
+                // D0 pointer
+                p_d0s_grid(j) = static_cast<const D0DataType*>(p_acc0_biases_vec[i][j]);
+                // for check
+                d0s_nl_ns_lengths_strides[j].push_back(
+                    problem_desc.acc0_biases_gs_ms_ns_lengths[j][NumDimG + NumDimM]);
+                d0s_nl_ns_lengths_strides[j].push_back(
+                    problem_desc.acc0_biases_gs_ms_ns_strides[j][NumDimG + NumDimM]);
+            });
+
             const auto p_z_grid = static_cast<ZDataType*>(p_z_vec[i]);
             const auto p_lse_grid = static_cast<LSEDataType*>(p_lse_vec[i]);
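The "for check" values index the flattened `[G..., M..., N...]` shape arrays: position `NumDimG + NumDimM` is the first N dimension of bias `j`, whose length and stride are stashed for the later vector-width check. A worked example with illustrative dimension counts and sizes (the numbers are assumptions, not from the source):

    // Picking the first N dimension out of a flattened gs_ms_ns shape array.
    #include <cstdio>
    #include <vector>

    int main()
    {
        constexpr int NumDimG = 2, NumDimM = 1; // e.g. (batch, head) plus one M dim
        // One bias tensor shaped [G0, G1, M, N] = [4, 8, 128, 64], packed strides.
        std::vector<long> lengths = {4, 8, 128, 64};
        std::vector<long> strides = {8 * 128 * 64, 128 * 64, 64, 1};

        const long n_length = lengths[NumDimG + NumDimM]; // 64
        const long n_stride = strides[NumDimG + NumDimM]; // 1 (innermost, contiguous)
        std::printf("N length=%ld stride=%ld\n", n_length, n_stride);
        return 0;
    }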
@@ -711,24 +748,22 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
                 is_lse_storing_ = false;
             }
 
-            const auto& problem_desc = problem_desc_vec[i];
-
             const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
                 problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
             const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
                 problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
+            const D0sGridDesc_M_N d0s_grid_desc_m_n{
+                DeviceOp::MakeD0sGridDescriptor_M_N(problem_desc.acc0_biases_gs_ms_ns_lengths,
+                                                    problem_desc.acc0_biases_gs_ms_ns_strides)};
+            const auto d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
+                GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
+                    d0s_grid_desc_m_n);
             const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1(
                 problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
             const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
                 problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
-            const auto d_grid_desc_m_n =
-                NumD0Tensor == 0
-                    ? DGridDesc_M_N{}
-                    : MakeZGridDescriptor_M_N(problem_desc.acc0_biases_gs_ms_ns_lengths[0],
-                                              problem_desc.acc0_biases_gs_ms_ns_strides[0]);
-            const auto d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
-                GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
-                    d_grid_desc_m_n);
             const auto z_grid_desc_m_n = MakeZGridDescriptor_M_N(
                 problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
             const auto lse_grid_desc_m =
@@ -742,11 +777,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
                 problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
             const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
                 problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
-            const auto d_grid_desc_g_m_n =
-                NumD0Tensor == 0 ? DGridDesc_G_M_N{}
-                                 : Transform::MakeCGridDescriptor_G_M_N(
-                                       problem_desc.acc0_biases_gs_ms_ns_lengths[0],
-                                       problem_desc.acc0_biases_gs_ms_ns_strides[0]);
             const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
                 problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
@@ -771,12 +801,15 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
             const index_t BlockEnd = grid_size_ + grid_size_grp;
 
             // batch stride
+            const auto d0s_grid_desc_g_m_n = DeviceOp::MakeD0sGridDescriptor_G_M_N(
+                problem_desc.acc0_biases_gs_ms_ns_lengths,
+                problem_desc.acc0_biases_gs_ms_ns_strides);
             const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
                 a_grid_desc_g_m_k,
                 b_grid_desc_g_n_k,
+                d0s_grid_desc_g_m_n,
                 b1_grid_desc_g_n_k,
                 c_grid_desc_g_m_n,
-                d_grid_desc_g_m_n,
                 z_grid_desc_g_m_n,
                 type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
@@ -805,17 +838,16 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
             group_kernel_args_.push_back({p_a_grid,
                                           p_b_grid,
+                                          p_d0s_grid,
                                           p_b1_grid,
                                           p_c_grid,
-                                          p_d_grid,
                                           p_z_grid,
                                           p_lse_grid,
                                           a_grid_desc_ak0_m_ak1,
                                           b_grid_desc_bk0_n_bk1,
+                                          d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                                           b1_grid_desc_bk0_n_bk1,
                                           c_grid_desc_mblock_mperblock_nblock_nperblock,
-                                          d_grid_desc_m_n,
-                                          d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                                           z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                                           z_grid_desc_m_n,
                                           lse_grid_desc_m,
@@ -832,11 +864,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
             z_random_matrix_offset =
                 z_random_matrix_offset + raw_m_padded * raw_n_padded * batch_count;
 
-            auto raw_d0_m_n = NumD0Tensor == 0
-                                  ? RawTransform::MakeCGridDescriptor_M_N({}, {})
-                                  : RawTransform::MakeCGridDescriptor_M_N(
-                                        problem_desc.acc0_biases_gs_ms_ns_lengths[0],
-                                        problem_desc.acc0_biases_gs_ms_ns_strides[0]);
             group_device_args_.push_back(
                 {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
                   problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
@@ -851,7 +878,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
                  {problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
                   problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
                  c_grid_desc_m_n,
-                 NumD0Tensor == 0 ? 0 : raw_d0_m_n.GetLength(I1)});
+                 d0s_nl_ns_lengths_strides});
         }
 
         is_dropout_ = p_dropout > 0.0; //
@@ -1067,11 +1094,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
             const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
             const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
 
-            if(device_arg.raw_d0_n_ % Acc0BiasTransferSrcScalarPerVector != 0)
-            {
-                return false;
-            }
-
             if(!(c_m == a_m && c_gemm1n == b1_gemm1n))
             {
                 return false;
...
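The removed `IsSupportedArgument` check validated the single bias N extent against `Acc0BiasTransferSrcScalarPerVector`; its multi-tensor replacement is elided from this diff, but presumably it runs the same divisibility test over the recorded `d0s_nl_ns_lengths_strides_`. A hypothetical sketch of such a per-tensor check (helper name and data layout are assumptions, not from the source):

    // Hypothetical per-tensor vector-width check; not the actual replacement,
    // which lies outside the hunks shown above.
    #include <array>
    #include <vector>

    constexpr int Acc0BiasTransferSrcScalarPerVector = 4;

    template <std::size_t NumD0Tensor>
    bool CheckD0VectorLoads(
        const std::array<std::vector<long>, NumD0Tensor>& d0s_nl_ns_lengths_strides)
    {
        for(const auto& len_stride : d0s_nl_ns_lengths_strides)
        {
            // Element 0 is the N length recorded in the Argument constructor;
            // vectorized loads need it to divide evenly.
            if(len_stride[0] % Acc0BiasTransferSrcScalarPerVector != 0)
                return false;
        }
        return true;
    }

    int main()
    {
        std::array<std::vector<long>, 2> ok{{{64, 1}, {128, 1}}};
        std::array<std::vector<long>, 2> bad{{{64, 1}, {66, 1}}};
        return CheckD0VectorLoads(ok) && !CheckD0VectorLoads(bad) ? 0 : 1;
    }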