Commit fa066d60 authored by letaoqin

gridwise change to multiple D
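The commit replaces the single optional bias tensor D with a tuple of D0 bias tensors (one per entry of Acc0BiasDataType) that are added to the Gemm0 accumulator before softmax. As a rough illustration of the bookkeeping the kernel entry point now does, one pointer per D0 tensor, each shifted by its own per-batch offset, here is a minimal standalone sketch in plain C++ (hypothetical names and types, not the actual CK D0sGridPointer machinery):

```cpp
#include <array>
#include <cstddef>
#include <cstdio>

// Hypothetical sketch: one pointer per D0 bias tensor plus a per-tensor batch
// stride, loosely mirroring the D0sGridPointer tuple and the per-batch
// GetD0BasePtr(g_idx, i) offsets introduced by this commit.
template <typename T, std::size_t NumD0Tensor>
struct D0sGridPointerSketch
{
    std::array<const T*, NumD0Tensor> ptrs;
    std::array<std::size_t, NumD0Tensor> batch_stride;

    // Shift every D0 pointer to batch g_idx, like the static_for loop that adds
    // d0_batch_offset to each element of p_d0s_grid in the kernel entry point.
    void MoveToBatch(std::size_t g_idx)
    {
        for(std::size_t i = 0; i < NumD0Tensor; ++i)
            ptrs[i] += g_idx * batch_stride[i];
    }
};

int main()
{
    float bias0[8] = {};
    float bias1[8] = {};
    D0sGridPointerSketch<float, 2> d0s{{bias0, bias1}, {4, 4}};
    d0s.MoveToBatch(1); // each pointer now addresses its batch-1 slice
    std::printf("offset applied: %d\n", d0s.ptrs[0] == bias0 + 4);
    return 0;
}
```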

parent ee275d4d
@@ -25,8 +25,8 @@ namespace device {
template <typename GridwiseGemm,
typename FloatAB,
+ typename D0sPointer,
typename FloatC,
- typename DDataType,
typename ZDataType,
typename FloatLSE,
typename GemmAccDataType,
@@ -37,9 +37,9 @@ template <typename GridwiseGemm,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
+ typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename B1GridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
- typename DGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename LSEGridDescriptor_M,
typename Block2CTileMap,
@@ -56,9 +56,9 @@ __global__ void
kernel_batched_multiheadattention_forward_xdl_cshuffle_v2(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
+ D0sPointer p_d0s_grid,
const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
- const DDataType* __restrict__ p_d_grid,
ZDataType* __restrict__ p_z_grid,
FloatLSE* __restrict__ p_lse_grid,
const AElementwiseOperation a_element_op,
@@ -68,11 +68,11 @@ __global__ void
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+ const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+ d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
- const DGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
- d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const LSEGridDescriptor_M lse_grid_desc_m,
@@ -103,13 +103,15 @@ __global__ void
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
- const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
- static_cast<long_index_t>(compute_base_ptr_of_batch.GetDBasePtr(g_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
+ static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) {
+ const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
+ static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In)));
+ p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset;
+ });
// const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, 0, offset);
@@ -122,9 +124,9 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
+ p_d0s_grid,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
- p_d_grid == nullptr ? nullptr : p_d_grid + d_batch_offset,
p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
p_lse_grid == nullptr ? nullptr : p_lse_grid + lse_batch_offset,
p_shared,
@@ -135,9 +137,9 @@ __global__ void
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
+ d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
- d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
lse_grid_desc_m,
block_2_ctile_map,
@@ -155,9 +157,9 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
+ p_d0s_grid,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
- p_d_grid == nullptr ? nullptr : p_d_grid + d_batch_offset,
p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
p_lse_grid == nullptr ? nullptr : p_lse_grid + lse_batch_offset,
p_shared,
@@ -168,9 +170,9 @@ __global__ void
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
+ d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
- d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
lse_grid_desc_m,
block_2_ctile_map,
@@ -318,12 +320,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
static constexpr index_t NumD1Tensor = Acc1BiasDataType::Size();
// TODO ANT: implement bias combination
- static_assert(NumD0Tensor <= 1, "Acc0 Bias addition is max support one bias");
static_assert(NumD1Tensor == 0, "Acc1 Bias addition is unimplemented");
- static_assert(NumD1Tensor == 0
- ? true
- : std::is_same_v<ADataType, ck::tuple_element_t<0, Acc0BiasDataType>>);
- using DDataType = ADataType;
#if 0
// TODO ANT: use alias
@@ -415,19 +412,43 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
}
}
+ static auto MakeD0sGridDescriptor_M_N(
+ const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
+ const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides)
+ {
+ return generate_tuple(
+ [&](auto i) {
+ return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths[i],
+ acc0_biases_gs_ms_ns_strides[i]);
+ },
+ Number<NumD0Tensor>{});
+ }
+ static auto MakeD0sGridDescriptor_G_M_N(
+ const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
+ const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides)
+ {
+ return generate_tuple(
+ [&](auto i) {
+ return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths[i],
+ acc0_biases_gs_ms_ns_strides[i]);
+ },
+ Number<NumD0Tensor>{});
+ }
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
- using DGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
- using DGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
+ using D0sGridDesc_M_N = decltype(MakeD0sGridDescriptor_M_N({}, {}));
+ using D0sGridDesc_G_M_N = decltype(MakeD0sGridDescriptor_G_M_N({}, {}));
constexpr static auto make_MaskOutPredicate()
{
@@ -450,16 +471,16 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
{
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k,
+ const D0sGridDesc_G_M_N& d0s_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n,
- const DGridDesc_G_M_N& d_grid_desc_g_m_n,
const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
+ d0s_grid_desc_g_m_n_(d0s_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
- d_grid_desc_g_m_n_(d_grid_desc_g_m_n),
z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
BatchStrideLSE_(BatchStrideLSE)
{
@@ -485,9 +506,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
- __host__ __device__ constexpr long_index_t GetDBasePtr(index_t g_idx) const
+ template <index_t I>
+ __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
+ Number<I> d0_idx) const
{
- return d_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
+ return d0s_grid_desc_g_m_n_[d0_idx].CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
@@ -503,9 +526,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
private:
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
+ D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
- DGridDesc_G_M_N d_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
index_t BatchStrideLSE_;
};
@@ -513,6 +536,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2<
ADataType, // TODO: distinguish A/B datatype
+ Acc0BiasDataType,
ZDataType,
GemmDataType,
GemmAccDataType,
@@ -527,6 +551,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
+ D0sGridDesc_M_N,
B1GridDesc_BK0_N_BK1,
CGridDesc_M_N,
ZGridDesc_M_N,
@@ -622,8 +647,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
p_b_grid_{p_b_grid},
p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid},
- p_d_grid_{NumD0Tensor == 0 ? nullptr
- : static_cast<const DDataType*>(p_acc0_biases[0])},
p_z_grid_{p_z_grid},
p_lse_grid_{p_lse_grid},
a_grid_desc_ak0_m_ak1_{
@@ -634,24 +657,18 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
- d_grid_desc_m_n_{NumD0Tensor == 0
- ? DGridDesc_M_N{}
- : MakeZGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths[0],
- acc0_biases_gs_ms_ns_strides[0])},
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
a_grid_desc_g_m_k_{
Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_g_n_k_{
Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
+ d0s_grid_desc_g_m_n_{DeviceOp::MakeD0sGridDescriptor_G_M_N(
+ acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)},
b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
- d_grid_desc_g_m_n_{NumD0Tensor == 0 ? DGridDesc_G_M_N{}
- : Transform::MakeCGridDescriptor_G_M_N(
- acc0_biases_gs_ms_ns_lengths[0],
- acc0_biases_gs_ms_ns_strides[0])},
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
@@ -678,12 +695,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
compute_base_ptr_of_batch_{
a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_,
+ d0s_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
- d_grid_desc_g_m_n_,
z_grid_desc_g_m_n_,
- type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())},
- raw_d0_n_(0)
+ type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
{
// TODO ANT: implement bias addition
ignore = p_acc1_biases;
@@ -699,8 +715,25 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
+ D0sGridDesc_M_N d0s_grid_desc_m_n{DeviceOp::MakeD0sGridDescriptor_M_N(
+ acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)};
+ d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
+ GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
+ d0s_grid_desc_m_n);
}
+ static_for<0, NumD0Tensor, 1>{}([&](auto i) {
+ using D0DataType = remove_cvref_t<tuple_element_t<i.value, Acc0BiasDataType>>;
+ // D0 pointer
+ p_d0s_grid_(i) = static_cast<const D0DataType*>(p_acc0_biases[i]);
+ // for check
+ d0s_nl_ns_lengths_strides_[i].push_back(
+ acc0_biases_gs_ms_ns_lengths[i][NumDimG + NumDimM]);
+ d0s_nl_ns_lengths_strides_[i].push_back(
+ acc0_biases_gs_ms_ns_strides[i][NumDimG + NumDimM]);
+ });
is_dropout_ = p_dropout > 0.0; //
p_dropout_ = 1.f - p_dropout;
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
@@ -710,8 +743,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
- d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
- GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(d_grid_desc_m_n_);
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
@@ -722,12 +753,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
{
is_lse_storing_ = false;
}
- if constexpr(NumD0Tensor)
- {
- const auto d0_grid_desc_m_n = RawTransform::MakeCGridDescriptor_M_N(
- acc0_biases_gs_ms_ns_lengths[0], acc0_biases_gs_ms_ns_strides[0]);
- raw_d0_n_ = d0_grid_desc_m_n.GetLength(I1);
- }
}
void Print() const
@@ -751,7 +776,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const BDataType* p_b_grid_;
const B1DataType* p_b1_grid_;
CDataType* p_c_grid_;
- const DDataType* p_d_grid_;
+ typename GridwiseGemm::D0sGridPointer p_d0s_grid_;
ZDataType* p_z_grid_;
LSEDataType* p_lse_grid_;
@@ -760,21 +785,20 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
- DGridDesc_M_N d_grid_desc_m_n_;
ZGridDesc_M_N z_grid_desc_m_n_;
LSEGridDesc_M lse_grid_desc_m_;
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
+ D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
- DGridDesc_G_M_N d_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
- typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
- d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
+ typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+ d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
@@ -815,7 +839,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
index_t n_raw_padded_;
// raw data
- int raw_d0_n_;
+ std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_nl_ns_lengths_strides_;
};
// Invoker
@@ -846,8 +870,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle_v2<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
+ typename GridwiseGemm::D0sGridPointer,
CDataType,
- DDataType,
ZDataType,
LSEDataType,
GemmAccDataType,
@@ -858,10 +882,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
+ typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
- typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
DeviceOp::LSEGridDesc_M,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
@@ -879,9 +903,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
0,
arg.p_a_grid_,
arg.p_b_grid_,
+ arg.p_d0s_grid_,
arg.p_b1_grid_,
arg.p_c_grid_,
- arg.p_d_grid_,
arg.p_z_grid_,
arg.p_lse_grid_,
arg.a_element_op_,
@@ -891,9 +915,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
+ arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
- arg.d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.lse_grid_desc_m_,
arg.block_2_ctile_map_,
@@ -1022,10 +1046,18 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
return false;
}
- if(arg.raw_d0_n_ % Acc0BiasTransferSrcScalarPerVector != 0)
+ for(int i = 0; i < NumD0Tensor; i++)
+ {
+ if(arg.d0s_nl_ns_lengths_strides_[i][1] == 1 &&
+ arg.d0s_nl_ns_lengths_strides_[i][0] % Acc0BiasTransferSrcScalarPerVector != 0)
{
return false;
}
+ if(arg.d0s_nl_ns_lengths_strides_[i][1] != 1 && Acc0BiasTransferSrcScalarPerVector != 1)
+ {
+ return false;
+ }
+ }
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
...
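The reworked IsSupportedArgument check above stores, per D0 tensor, the innermost N length and N stride (d0s_nl_ns_lengths_strides_) and only allows vectorized bias loads when N is contiguous and its length is a multiple of Acc0BiasTransferSrcScalarPerVector; otherwise the vector width must be 1. A standalone sketch of that rule, assuming illustrative names rather than the CK API:

```cpp
#include <array>
#include <cstdio>
#include <vector>

// Hedged sketch of the per-D0 vector-load constraint: each entry holds
// {N length, N stride} for one D0 bias tensor.
bool D0sVectorLoadSupported(const std::vector<std::array<int, 2>>& d0s_n_length_stride,
                            int scalar_per_vector)
{
    for(const auto& ls : d0s_n_length_stride)
    {
        const int n_length = ls[0];
        const int n_stride = ls[1];
        // Contiguous N: the length must divide evenly into whole vectors.
        if(n_stride == 1 && n_length % scalar_per_vector != 0)
            return false;
        // Strided N: only scalar (one element per load) access is supported.
        if(n_stride != 1 && scalar_per_vector != 1)
            return false;
    }
    return true;
}

int main()
{
    const std::vector<std::array<int, 2>> ok{{128, 1}, {64, 1}};
    const std::vector<std::array<int, 2>> bad{{130, 1}};
    std::printf("%d %d\n", D0sVectorLoadSupported(ok, 4), D0sVectorLoadSupported(bad, 4)); // 1 0
    return 0;
}
```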
@@ -25,6 +25,7 @@ namespace ck {
*
*/
template <typename FloatAB,
+ typename D0sDataType,
typename ZDataType,
typename FloatGemm,
typename FloatGemmAcc,
@@ -39,6 +40,7 @@ template <typename FloatAB,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
+ typename D0sGridDesc_M_N,
typename B1GridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
typename ZGridDesc_M_N,
@@ -99,7 +101,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
D0BlockTransferSrcScalarPerVector == 2 ||
D0BlockTransferSrcScalarPerVector == 4,
"D0BlockTransferSrcScalarPerVector must be 1 or 2 or 4");
- using DDataType = FloatAB;
+ static constexpr index_t NumD0Tensor = D0sDataType::Size();
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
@@ -414,6 +416,53 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
c_grid_desc_m_n);
}
+ static constexpr auto MakeD0sGridPointer()
+ {
+ return generate_tuple(
+ [&](auto i) {
+ using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
+ return static_cast<const D0DataType*>(nullptr);
+ },
+ Number<NumD0Tensor>{});
+ }
+ // D0 desc for source in blockwise copy
+ template <typename D0GridDesc_M_N>
+ __host__ __device__ static constexpr auto
+ MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0GridDesc_M_N& d0_grid_desc_m_n)
+ {
+ const auto M = d0_grid_desc_m_n.GetLength(I0);
+ const auto N = d0_grid_desc_m_n.GetLength(I1);
+ constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
+ constexpr auto N3 = mfma.num_groups_per_blk;
+ constexpr auto N4 = mfma.num_input_blks;
+ constexpr auto N5 = mfma.group_size;
+ return transform_tensor_descriptor(
+ d0_grid_desc_m_n,
+ make_tuple(make_unmerge_transform(
+ make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
+ make_unmerge_transform(
+ make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
+ make_tuple(Sequence<0>{}, Sequence<1>{}),
+ make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
+ }
+ // D0s desc for source in blockwise copy
+ __host__ __device__ static constexpr auto
+ MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0sGridDesc_M_N& ds_grid_desc_m_n)
+ {
+ return generate_tuple(
+ [&](auto i) {
+ return MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ds_grid_desc_m_n[i]);
+ },
+ Number<NumD0Tensor>{});
+ }
+ using D0sGridPointer = decltype(MakeD0sGridPointer());
+ using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
+ MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0sGridDesc_M_N{}))>;
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
@@ -475,9 +524,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
typename C0MatrixMask>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
+ D0sGridPointer p_d0s_grid,
const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
- const DDataType* __restrict__ p_d_grid,
ZDataType* __restrict__ p_z_grid,
FloatLSE* __restrict__ p_lse_grid,
void* __restrict__ p_shared,
@@ -488,11 +537,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
+ const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
+ d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
- const DGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
- d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const LSEGridDesc_M& lse_grid_desc_m,
@@ -907,7 +956,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
// bias (d matrix)
- constexpr auto d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
+ constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockId
m0, // MRepeat
@@ -919,11 +968,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
n3, // NInputNum
n4)); // RegisterNum
- auto d_threadwise_copy_globla_vgpr =
- ThreadwiseTensorSliceTransfer_v2<DDataType,
- DDataType,
- decltype(d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
- decltype(d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
+ auto d0s_threadwise_copy = generate_tuple(
+ [&](auto i) {
+ using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
+ return ThreadwiseTensorSliceTransfer_v2<
+ D0DataType,
+ D0DataType,
+ decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]),
+ decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
@@ -938,7 +990,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
9,
D0BlockTransferSrcScalarPerVector,
1,
- false>(d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+ false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
@@ -949,6 +1001,16 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
0, // group
wave_m_n_id[I0], // NInputIndex
0)); // register number
+ },
+ Number<NumD0Tensor>{});
+ const auto d0s_grid_buf = generate_tuple(
+ [&](auto i) {
+ return make_dynamic_buffer<AddressSpaceEnum::Global>(
+ p_d0s_grid[i],
+ d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i].GetElementSpaceSize());
+ },
+ Number<NumD0Tensor>{});
// z is random number matrix for dropout verify
//
@@ -1219,33 +1281,30 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// add bias
- if(p_d_grid)
- {
- auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
- p_d_grid, d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
+ static_for<0, NumD0Tensor, 1>{}([&](auto i) {
+ // get register
+ using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
StaticBuffer<AddressSpaceEnum::Vgpr,
- DDataType,
- d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
+ D0DataType,
+ d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>
- d_thread_buf;
- d_threadwise_copy_globla_vgpr.Run(
- d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
- d_grid_buf,
- d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+ d0_thread_buf;
+ // load data from global
+ d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
+ d0s_grid_buf[i],
+ d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
- d_thread_buf);
- static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) {
+ d0_thread_buf);
// acc add bias
- acc_thread_buf(i) += d_thread_buf[i];
- });
- d_threadwise_copy_globla_vgpr.MoveSrcSliceWindow(
- d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+ static_for<0, m0 * n0 * n2 * n4, 1>{}(
+ [&](auto j) { acc_thread_buf(j) += d0_thread_buf[j]; });
+ d0s_threadwise_copy(i).MoveSrcSliceWindow(
+ d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
- }
+ });
// softmax
SoftmaxBuf& max = blockwise_softmax.max_value_buf;
...
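Putting the gridwise pieces together: for each D0 tensor the kernel loads a register tile through its threadwise copy, accumulates it into the Gemm0 accumulator before softmax, then advances the source window. A simplified, self-contained model of that per-D0 accumulation (plain C++ with hypothetical names; the real code uses CK's StaticBuffer, static_for, and ThreadwiseTensorSliceTransfer_v2):

```cpp
#include <array>
#include <cstdio>

// Hypothetical model of the per-D0 bias accumulation done with static_for
// over the D0 tuple: every bias tile is added element-wise to the accumulator.
template <std::size_t NumD0Tensor, std::size_t TileSize>
void AddD0Biases(std::array<float, TileSize>& acc_thread_buf,
                 const std::array<std::array<float, TileSize>, NumD0Tensor>& d0_thread_bufs)
{
    for(std::size_t i = 0; i < NumD0Tensor; ++i)
        for(std::size_t j = 0; j < TileSize; ++j)
            acc_thread_buf[j] += d0_thread_bufs[i][j];
}

int main()
{
    std::array<float, 4> acc{0.f, 0.f, 0.f, 0.f};
    std::array<std::array<float, 4>, 2> biases{{{1, 1, 1, 1}, {2, 2, 2, 2}}};
    AddD0Biases(acc, biases);
    std::printf("%g\n", acc[0]); // 3
    return 0;
}
```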