Commit b8f08e67 authored by letaoqin

grouped MHA forward: change single D bias to multiple D0 tensors

parent fa066d60
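
This commit generalizes the grouped multi-head-attention forward op from a single optional D bias tensor (`p_d_grid_`, `DGridDesc_M_N`) to a statically sized set of D0 bias tensors (`p_d0s_grid_`, `D0sGridDesc_M_N`, one entry per element of `Acc0BiasDataType`). The sketch below is a minimal stand-in for that pattern, using plain C++17 in place of CK's `Tuple`/`static_for`; all names in it are illustrative, not part of the commit.

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

// One pointer per D0 bias tensor, sized at compile time
// (stand-in for GridwiseGemm::D0sGridPointer).
template <int NumD0, typename T>
struct D0Pointers
{
    std::array<const T*, NumD0> ptrs{};
};

// Shift every D0 pointer by its own per-batch offset. In the kernel this is a
// compile-time static_for, and each offset is made wavefront-uniform with
// __builtin_amdgcn_readfirstlane.
template <int NumD0, typename T>
void shift_to_batch(D0Pointers<NumD0, T>& d0s,
                    const std::array<std::int64_t, NumD0>& batch_offsets)
{
    for(int i = 0; i < NumD0; ++i)
        d0s.ptrs[i] += batch_offsets[i];
}

int main()
{
    constexpr int NumD0 = 2;
    std::array<float, 8> bias0{}, bias1{};
    D0Pointers<NumD0, float> d0s{{bias0.data(), bias1.data()}};
    shift_to_batch(d0s, {4, 2}); // per-tensor batch offsets
    std::printf("%td %td\n", d0s.ptrs[0] - bias0.data(), d0s.ptrs[1] - bias1.data());
}
```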
@@ -95,13 +95,19 @@ __global__ void
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetDBasePtr(g_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetZBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
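// Per-tensor batch offsets for the D0 bias tuple: each offset below is made
// wavefront-uniform with __builtin_amdgcn_readfirstlane and then applied to
// its slot in p_d0s_grid.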
typename GridwiseGemm::D0sGridPointer p_d0s_grid = arg_ptr[group_id].p_d0s_grid_;
static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) {
const long_index_t d0_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx, In)));
p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset;
});
if constexpr(Deterministic)
{
for(index_t i = 0; i < num_blocks_per_batch; i++)
@@ -109,11 +115,9 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
p_d0s_grid,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_d_grid_ == nullptr
? nullptr
: arg_ptr[group_id].p_d_grid_ + d_batch_offset,
arg_ptr[group_id].p_z_grid_ == nullptr
? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset,
@@ -129,9 +133,9 @@ __global__ void
c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_,
@@ -150,10 +154,9 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
p_d0s_grid,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_d_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_d_grid_ + d_batch_offset,
arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset,
arg_ptr[group_id].p_lse_grid_ == nullptr
@@ -168,9 +171,9 @@ __global__ void
c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_,
@@ -300,12 +303,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
static constexpr index_t NumD1Tensor = Acc1BiasDataType::Size();
// TODO ANT: implement bias combination
static_assert(NumD0Tensor <= 1, "Bias0 addition only supports one bias");
static_assert(NumD1Tensor == 0, "Bias addition is unimplemented");
static_assert(NumD0Tensor == 0
? true
: std::is_same_v<ADataType, ck::tuple_element_t<0, Acc0BiasDataType>>);
using DDataType = ADataType;
#if 0
// TODO ANT: use alias
@@ -424,19 +422,43 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
}
}
static auto MakeD0sGridDescriptor_M_N(
const std::vector<std::vector<ck::index_t>>& acc0_biases_gs_ms_ns_lengths,
const std::vector<std::vector<ck::index_t>>& acc0_biases_gs_ms_ns_strides)
{
return generate_tuple(
[&](auto i) {
return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths[i],
acc0_biases_gs_ms_ns_strides[i]);
},
Number<NumD0Tensor>{});
}
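// Same generate_tuple pattern over the batched G_M_N view: one descriptor per
// D0 tensor, used to compute per-batch base offsets.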
static auto MakeD0sGridDescriptor_G_M_N(
const std::vector<std::vector<ck::index_t>>& acc0_biases_gs_ms_ns_lengths,
const std::vector<std::vector<ck::index_t>>& acc0_biases_gs_ms_ns_strides)
{
return generate_tuple(
[&](auto i) {
return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths[i],
acc0_biases_gs_ms_ns_strides[i]);
},
Number<NumD0Tensor>{});
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using D0sGridDesc_M_N = decltype(MakeD0sGridDescriptor_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using DGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using D0sGridDesc_G_M_N = decltype(MakeD0sGridDescriptor_G_M_N({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using DGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
constexpr static auto make_MaskOutPredicate()
@@ -460,16 +482,16 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
{
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const D0sGridDesc_G_M_N& d0s_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n,
const DGridDesc_G_M_N& d_grid_desc_g_m_n,
const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
d0s_grid_desc_g_m_n_(d0s_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
d_grid_desc_g_m_n_(d_grid_desc_g_m_n),
z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
BatchStrideLSE_(BatchStrideLSE)
{
@@ -495,9 +517,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
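// Offset of the I-th D0 bias tensor's base pointer within batch g_idx.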
__host__ __device__ constexpr long_index_t GetDBasePtr(index_t g_idx) const
template <index_t I>
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
Number<I> d0_idx) const
{
return d_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
return d0s_grid_desc_g_m_n_[d0_idx].CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
@@ -513,9 +537,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
private:
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
DGridDesc_G_M_N d_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
index_t BatchStrideLSE_;
@@ -524,6 +548,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2<
ADataType, // TODO: distinguish A/B datatype
Acc0BiasDataType,
ZDataType,
GemmDataType,
GemmAccDataType,
@@ -538,6 +563,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
D0sGridDesc_M_N,
B1GridDesc_BK0_N_BK1,
CGridDesc_M_N,
ZGridDesc_M_N,
@@ -599,21 +625,20 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemm::D0sGridPointer p_d0s_grid_;
const B1DataType* p_b1_grid_;
CDataType* p_c_grid_;
const DDataType* p_d_grid_;
ZDataType* p_z_grid_;
LSEDataType* p_lse_grid_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
DGridDesc_M_N d_grid_desc_m_n_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
ZGridDesc_M_N z_grid_desc_m_n_;
@@ -650,7 +675,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
CGridDesc_M_N c_grid_desc_m_n_;
// raw data
int raw_d0_n_;
std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_nl_ns_lengths_strides_;
};
// Argument
@@ -700,9 +725,21 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]);
const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]);
const auto p_d_grid = NumD0Tensor == 0
? nullptr
: static_cast<const DDataType*>(p_acc0_biases_vec[i][0]);
const auto& problem_desc = problem_desc_vec[i];
std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_nl_ns_lengths_strides;
typename GridwiseGemm::D0sGridPointer p_d0s_grid;
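// For each D0 tensor: cast the raw bias pointer to its element type and record
// the first N-dim length/stride pair for later argument checks.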
static_for<0, NumD0Tensor, 1>{}([&](auto j) {
using D0DataType = remove_cvref_t<tuple_element_t<j.value, Acc0BiasDataType>>;
// D0 pointer
p_d0s_grid(j) = static_cast<const D0DataType*>(p_acc0_biases_vec[i][j]);
// for check
d0s_nl_ns_lengths_strides[j].push_back(
problem_desc.acc0_biases_gs_ms_ns_lengths[j][NumDimG + NumDimM]);
d0s_nl_ns_lengths_strides[j].push_back(
problem_desc.acc0_biases_gs_ms_ns_strides[j][NumDimG + NumDimM]);
});
const auto p_z_grid = static_cast<ZDataType*>(p_z_vec[i]);
const auto p_lse_grid = static_cast<LSEDataType*>(p_lse_vec[i]);
@@ -711,24 +748,22 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
is_lse_storing_ = false;
}
const auto& problem_desc = problem_desc_vec[i];
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
const D0sGridDesc_M_N d0s_grid_desc_m_n{
DeviceOp::MakeD0sGridDescriptor_M_N(problem_desc.acc0_biases_gs_ms_ns_lengths,
problem_desc.acc0_biases_gs_ms_ns_strides)};
const auto d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
d0s_grid_desc_m_n);
const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1(
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
const auto d_grid_desc_m_n =
NumD0Tensor == 0
? DGridDesc_M_N{}
: MakeZGridDescriptor_M_N(problem_desc.acc0_biases_gs_ms_ns_lengths[0],
problem_desc.acc0_biases_gs_ms_ns_strides[0]);
const auto d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
d_grid_desc_m_n);
const auto z_grid_desc_m_n = MakeZGridDescriptor_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto lse_grid_desc_m =
@@ -742,11 +777,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
const auto d_grid_desc_g_m_n =
NumD0Tensor == 0 ? DGridDesc_G_M_N{}
: Transform::MakeCGridDescriptor_G_M_N(
problem_desc.acc0_biases_gs_ms_ns_lengths[0],
problem_desc.acc0_biases_gs_ms_ns_strides[0]);
const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
@@ -771,12 +801,15 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const index_t BlockEnd = grid_size_ + grid_size_grp;
// batch stride
const auto d0s_grid_desc_g_m_n = DeviceOp::MakeD0sGridDescriptor_G_M_N(
problem_desc.acc0_biases_gs_ms_ns_lengths,
problem_desc.acc0_biases_gs_ms_ns_strides);
const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
a_grid_desc_g_m_k,
b_grid_desc_g_n_k,
d0s_grid_desc_g_m_n,
b1_grid_desc_g_n_k,
c_grid_desc_g_m_n,
d_grid_desc_g_m_n,
z_grid_desc_g_m_n,
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
@@ -805,17 +838,16 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
group_kernel_args_.push_back({p_a_grid,
p_b_grid,
p_d0s_grid,
p_b1_grid,
p_c_grid,
p_d_grid,
p_z_grid,
p_lse_grid,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_m_n,
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_desc_m_n,
lse_grid_desc_m,
@@ -832,11 +864,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
z_random_matrix_offset =
z_random_matrix_offset + raw_m_padded * raw_n_padded * batch_count;
auto raw_d0_m_n = NumD0Tensor == 0
? RawTransform::MakeCGridDescriptor_M_N({}, {})
: RawTransform::MakeCGridDescriptor_M_N(
problem_desc.acc0_biases_gs_ms_ns_lengths[0],
problem_desc.acc0_biases_gs_ms_ns_strides[0]);
group_device_args_.push_back(
{{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
@@ -851,7 +878,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
{problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
c_grid_desc_m_n,
NumD0Tensor == 0 ? 0 : raw_d0_m_n.GetLength(I1)});
d0s_nl_ns_lengths_strides});
}
is_dropout_ = p_dropout > 0.0;
@@ -1067,11 +1094,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
if(device_arg.raw_d0_n_ % Acc0BiasTransferSrcScalarPerVector != 0)
{
return false;
}
if(!(c_m == a_m && c_gemm1n == b1_gemm1n))
{
return false;
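
Note on the removed check: the old `raw_d0_n_ % Acc0BiasTransferSrcScalarPerVector` test above is dropped in this commit, while the argument constructor now records an N-length/N-stride pair per D0 tensor. A per-tensor replacement check would presumably look like the sketch below; this is an assumption about the validation logic, not code from the commit.

```cpp
#include <array>
#include <vector>

constexpr int NumD0Tensor                        = 2; // illustrative
constexpr int Acc0BiasTransferSrcScalarPerVector = 4; // illustrative

// Each entry holds {N length, N stride} as pushed in the argument constructor.
bool d0s_vector_access_ok(
    const std::array<std::vector<int>, NumD0Tensor>& d0s_nl_ns_lengths_strides)
{
    for(const auto& ls : d0s_nl_ns_lengths_strides)
    {
        if(ls.size() < 2)
            return false;
        // Mirrors the removed scalar test: vector loads along N require the N
        // length to be a multiple of the per-vector scalar count. ls[1] (the N
        // stride) is available if a contiguity check is also wanted.
        if(ls[0] % Acc0BiasTransferSrcScalarPerVector != 0)
            return false;
    }
    return true;
}
```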