Commit 825f7f02 authored by Anthony Chang

refactor Gemm1

parent c798cff9
@@ -108,140 +108,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{};
// VGrad Gemm
template <index_t Sum_M_ = MPerXdl * 2>
struct VGradGemmTile_N_O_M_
{
static constexpr index_t Free0_N = NPerBlock;
static constexpr index_t Free1_O = Gemm1NPerBlock;
static constexpr index_t Sum_M = Sum_M_;
static constexpr index_t P_M1 = 8; // P will be row-major
static constexpr index_t P_M0 = Sum_M / P_M1;
static constexpr index_t P_LdsPad = 0; // how many multiples of M1 per N * M1 elements
static constexpr index_t YGrad_M1 = 2; // dY assumed row-major, typically =2 for fp16
static constexpr index_t YGrad_M0 = Sum_M / YGrad_M1;
static constexpr index_t YGrad_LdsPad = 0; // how many multiples of M1 per N * M1 elements
static_assert(Sum_M % MPerXdl == 0, "");
static constexpr index_t YGrad_SrcVectorDim = 1; // Free1_O dimension
static constexpr index_t YGrad_SrcScalarPerVector = 4;
static constexpr index_t GemmNWave = 2;
static constexpr index_t GemmOWave = BlockSize / get_warp_size() / GemmNWave;
static constexpr index_t GemmNRepeat = Free0_N / GemmNWave / MPerXdl;
static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
static constexpr index_t GemmMPack =
math::max(math::lcm(P_M1, YGrad_M1),
MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
using YGrad_BlockSliceLengths = Sequence<YGrad_M0, Free1_O, YGrad_M1>;
using YGrad_ThreadClusterLengths =
Sequence<BlockSize / (Free1_O / YGrad_SrcScalarPerVector),
Free1_O / YGrad_SrcScalarPerVector,
1>;
using YGrad_ThreadClusterArrangeOrder = Sequence<0, 2, 1>;
__host__ __device__ static constexpr auto GetPBlockDescriptor_M0_N_M1()
{
constexpr index_t P_M0 = Sum_M / P_M1;
return make_naive_tensor_descriptor(
make_tuple(Number<P_M0>{}, Number<Free0_N>{}, Number<P_M1>{}),
make_tuple(Number<Free0_N + P_LdsPad>{} * Number<P_M1>{}, Number<P_M1>{}, I1));
}
__host__ __device__ static constexpr auto GetYGradBlockDescriptor_M0_O_M1()
{
constexpr index_t YGrad_M0 = Sum_M / YGrad_M1;
return make_naive_tensor_descriptor(
make_tuple(Number<YGrad_M0>{}, Number<Free1_O>{}, Number<YGrad_M1>{}),
make_tuple(
Number<Free1_O + YGrad_LdsPad>{} * Number<YGrad_M1>{}, Number<YGrad_M1>{}, I1));
}
__host__ __device__ static constexpr auto GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2()
{
// perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
constexpr index_t m = Sum_M - 1;
constexpr index_t m2 = m % MPerXdl;
constexpr index_t m1 = m / MPerXdl % Gemm0MWaves;
constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave;
// perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
constexpr index_t n = Free0_N - 1;
constexpr index_t n2 = n % NPerXdl;
constexpr index_t n1 = n / NPerXdl % Gemm0NWaves;
constexpr index_t n0 = n / NPerXdl / Gemm0NWaves % NXdlPerWave;
// assume 256 decomposed into 2 x 4 x 32
// 1d idx ( 32 - 1) -> 3d idx 0, 0, 31 -> 3d dim 1 x 1 x 32
// 1d idx (256 - 1) -> 3d idx 1, 3, 31 -> 3d dim 2 x 4 x 32
return Sequence<m0, n0, m1, n1, m2, n2>{} + Sequence<1, 1, 1, 1, 1, 1>{};
}
__host__ __device__ static constexpr auto GetPBlockSliceLengths_M0_N0_M1_N1()
{
return generate_sequence_v2(
[](auto I) { return GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2().At(I); },
Number<4>{});
}
};
using VGradGemmTile_N_O_M = VGradGemmTile_N_O_M_<>; // tune later
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
static constexpr index_t SrcScalarPerVetor = 16 / sizeof(DataType);
static constexpr auto ThreadClusterLength_O =
Number<BlockSliceLength_O_ / SrcScalarPerVetor>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
static constexpr auto ThreadSliceLength_O = Number<SrcScalarPerVetor>{};
static constexpr auto ThreadSliceLength_M =
Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};
static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
ThreadSliceLength_M * ThreadSliceLength_O,
true>;
using DstBufType =
StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, ThreadSliceLength_M, true>;
};
using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;
// QGrad Gemm
// KGrad Gemm
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
@@ -274,6 +145,33 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
template <typename AccThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4>
__host__ __device__ static constexpr auto GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
const AccThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4& acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4)
{
// acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1
// n0_n1_n2_n3 -> k0
// m0_m1_m2 -> m
// n4 -> k1
// NOTE: had to use merge_v3 or will spit out compilation errors
const auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
const auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
const auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
const auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
const auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
const auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
const auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
const auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
return transform_tensor_descriptor(
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
make_pass_through_transform(n4)),
make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
__host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B1 matrix in LDS memory, dst of blockwise copy
@@ -345,11 +243,22 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto Gemm1N = v_grid_desc_n0_o_n1.GetLength(I1);
// This assumption reduces implementation complexity by categorizing the 6 separate GEMMs into 3
// types of GEMM operations, so that some of the code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if(Gemm1N != K)
{
std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false;
}
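// Illustrative summary (read off the naming convention above; an editorial sketch, not part of
// the kernel logic):
//   Gemm0 (rcr): P  = Q * K^T    (M x N, reduce over K)   and   dP = dY * V^T  (M x N, reduce over O)
//   Gemm1 (rrr): Y  = P * V      (M x O, reduce over N)   and   dQ = dS * K    (M x K, reduce over N)
//   Gemm2 (crr): dV = P^T * dY   (N x O, reduce over M)   and   dK = dS^T * Q  (N x K, reduce over M)
// Requiring Gemm1N == K (head size of V equal to head size of Q/K) is what lets each pair share
// a single code path.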
if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
{
return false;
@@ -446,137 +355,329 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
struct SharedMemTrait // P / dP Gemm (type 1 rcr)
struct Gemm0
{ {
// LDS allocation for A and B: be careful of alignment // A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto b1_block_desc_bk0_n_bk1 =
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
template <typename GridDesc_K0_M_K1>
using ABlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
DataType,
DataType,
GridDesc_K0_M_K1,
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
template <typename GridDesc_K0_N_K1>
using BBlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
DataType,
DataType,
GridDesc_K0_N_K1,
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
static constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
// Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
DataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
true>; // TransposeC
static constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
static constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
};
// Y / dQ Gemm (type 2 rrr)
template <typename ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4,
typename ASrcBlockDesc_M0_N0_M1_N1_M2_N2_N3_N4>
struct Gemm1
{
private:
static constexpr auto m0 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I0);
static constexpr auto n0 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I1);
static constexpr auto m1 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I2);
static constexpr auto n1 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I3);
static constexpr auto m2 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I4);
static constexpr auto n2 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I5);
static constexpr auto n3 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I6);
static constexpr auto n4 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I7);
// N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
static constexpr auto N3 = ASrcBlockDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I6);
public:
static constexpr auto AThreadSliceLength_K0 = Number<Gemm1KPerBlock / n4 / N3>{};
static constexpr auto AThreadSliceLength_M = Number<m0 * m1 * m2>{};
static constexpr auto AThreadSliceLength_K1 = Number<n4>{};
static constexpr auto acc_thread_desc_k0_m_k1 =
GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{});
static constexpr auto a_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed(
make_tuple(AThreadSliceLength_K0, AThreadSliceLength_M, AThreadSliceLength_K1));
static constexpr auto b_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto p_block_desc_m0_n_m1 =
VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
static constexpr auto ygrad_block_desc_m0_o_m1 =
VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();
static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{}; static constexpr auto ASrcScalarPerVector = n4;
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple( using AThreadSliceLengths_K0_M_K1 = decltype(a_thread_desc_k0_m_k1.GetLengths());
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b_block_space_size_aligned = math::integer_least_multiple( using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); FloatGemmAcc,
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple( DataType,
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); decltype(acc_thread_desc_k0_m_k1),
static constexpr auto p_block_space_size_aligned = decltype(a_thread_desc_k0_m_k1),
math::integer_least_multiple(p_block_desc_m0_n_m1.GetElementSpaceSize(), max_lds_align); tensor_operation::element_wise::PassThrough,
static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple( AThreadSliceLengths_K0_M_K1,
ygrad_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align); Sequence<1, 0, 2>,
2,
ASrcScalarPerVector>;
template <typename GridDesc_K0_N_K1>
using BBlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<B1K0, Gemm1NPerBlock, B1K1>,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
DataType,
DataType,
GridDesc_K0_N_K1,
decltype(b_block_desc_bk0_n_bk1),
B1BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
B1BlockTransferSrcVectorDim,
2,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
1,
1,
B1ThreadTransferSrcResetCoordinateAfterRun,
true, // DstResetCoord
NumGemmKPrefetchStage>;
// for a_block_slice_copy_step to be able to address static buffers, it MUST be a
// tuple-based container as well as containing ONLY integral constants
static constexpr auto a_block_slice_copy_step = make_tuple(AThreadSliceLength_K0, I0, I0);
static constexpr auto b_block_slice_copy_step =
make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
// selected_mfma.k_per_blk <= Gemm1KPack
//
// Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
// multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
// Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack =
MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
DataType,
FloatGemmAcc,
decltype(a_thread_desc_k0_m_k1),
decltype(b_block_desc_bk0_n_bk1),
decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a_thread_desc_k0_m_k1)),
decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
MPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
Gemm1NXdlPerWave,
GemmKPack,
true, // TransposeC
GemmKPack, // AMmaKStride
GemmKPack * XdlopsGemm<DataType, MPerXdl, NPerXdl, GemmKPack, false>{}
.K0PerXdlops /* BMmaKStride */>;
};
// dV / dK Gemm (type 3 crr)
template <index_t Sum_M_ = MPerXdl * 2>
struct VGradGemmTile_N_O_M_
{
static constexpr index_t Free0_N = NPerBlock;
static constexpr index_t Free1_O = Gemm1NPerBlock;
static constexpr index_t Sum_M = Sum_M_;
static constexpr index_t P_M1 = 8; // P will be row-major
static constexpr index_t P_M0 = Sum_M / P_M1;
static constexpr index_t P_LdsPad = 0; // how many multiples of M1 per N * M1 elements
static constexpr index_t YGrad_M1 = 2; // dY assumed row-major, typically =2 for fp16
static constexpr index_t YGrad_M0 = Sum_M / YGrad_M1;
static constexpr index_t YGrad_LdsPad = 0; // how many multiples of M1 per N * M1 elements
static_assert(Sum_M % MPerXdl == 0, "");
static constexpr index_t YGrad_SrcVectorDim = 1; // Free1_O dimension
static constexpr index_t YGrad_SrcScalarPerVector = 4;
static constexpr index_t GemmNWave = 2;
static constexpr index_t GemmOWave = BlockSize / get_warp_size() / GemmNWave;
static constexpr index_t GemmNRepeat = Free0_N / GemmNWave / MPerXdl;
static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
static constexpr index_t GemmMPack =
math::max(math::lcm(P_M1, YGrad_M1),
MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
using YGrad_BlockSliceLengths = Sequence<YGrad_M0, Free1_O, YGrad_M1>;
using YGrad_ThreadClusterLengths =
Sequence<BlockSize / (Free1_O / YGrad_SrcScalarPerVector),
Free1_O / YGrad_SrcScalarPerVector,
1>;
using YGrad_ThreadClusterArrangeOrder = Sequence<0, 2, 1>;
static constexpr auto a_block_space_offset = 0; __host__ __device__ static constexpr auto GetPBlockDescriptor_M0_N_M1()
static constexpr auto b_block_space_offset = a_block_space_size_aligned.value; {
static constexpr auto b1_block_space_offset = 0; constexpr index_t P_M0 = Sum_M / P_M1;
static constexpr auto p_block_space_offset = 0; return make_naive_tensor_descriptor(
static constexpr auto ygrad_block_space_offset = p_block_space_size_aligned.value; make_tuple(Number<P_M0>{}, Number<Free0_N>{}, Number<P_M1>{}),
make_tuple(Number<Free0_N + P_LdsPad>{} * Number<P_M1>{}, Number<P_M1>{}, I1));
}
__host__ __device__ static constexpr auto GetYGradBlockDescriptor_M0_O_M1()
{
constexpr index_t YGrad_M0 = Sum_M / YGrad_M1;
return make_naive_tensor_descriptor(
make_tuple(Number<YGrad_M0>{}, Number<Free1_O>{}, Number<YGrad_M1>{}),
make_tuple(
Number<Free1_O + YGrad_LdsPad>{} * Number<YGrad_M1>{}, Number<YGrad_M1>{}, I1));
}
// LDS allocation for reduction __host__ __device__ static constexpr auto GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2()
static constexpr index_t reduction_space_size_aligned = {
math::integer_least_multiple(BlockSize, max_lds_align); // perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
constexpr index_t m = Sum_M - 1;
constexpr index_t m2 = m % MPerXdl;
constexpr index_t m1 = m / MPerXdl % Gemm0MWaves;
constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave;
static constexpr auto reduction_space_offset = 0; // perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
constexpr index_t n = Free0_N - 1;
constexpr index_t n2 = n % NPerXdl;
constexpr index_t n1 = n / NPerXdl % Gemm0NWaves;
constexpr index_t n0 = n / NPerXdl / Gemm0NWaves % NXdlPerWave;
// LDS allocation for C shuffle in LDS // assume 256 decomposed into 2 x 4 x 32
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = // 1d idx ( 32 - 1) -> 3d idx 0, 0, 31 -> 3d dim 1 x 1 x 32
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); // 1d idx (256 - 1) -> 3d idx 1, 3, 31 -> 3d dim 2 x 4 x 32
static constexpr auto c_block_space_size = return Sequence<m0, n0, m1, n1, m2, n2>{} + Sequence<1, 1, 1, 1, 1, 1>{};
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); }
};
// P / dP Gemm (type 1 rcr) __host__ __device__ static constexpr auto GetPBlockSliceLengths_M0_N0_M1_N1()
struct Gemm0 {
{ return generate_sequence_v2(
private: [](auto I) { return GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2().At(I); },
static constexpr auto a_block_desc = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); Number<4>{});
static constexpr auto b_block_desc = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); }
};
public: using VGradGemmTile_N_O_M = VGradGemmTile_N_O_M_<>; // tune later
template <typename GridDesc_K0_M_K1>
using ABlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
AElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
DataType,
DataType,
GridDesc_K0_M_K1,
decltype(a_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
template <typename GridDesc_K0_N_K1> template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1< struct YDotYGrad_M_O_
ThisThreadBlock, {
BElementwiseOperation, static constexpr index_t SrcScalarPerVetor = 16 / sizeof(DataType);
tensor_operation::element_wise::PassThrough, static constexpr auto ThreadClusterLength_O =
InMemoryDataOperationEnum::Set, Number<BlockSliceLength_O_ / SrcScalarPerVetor>{};
Sequence<BK0, NPerBlock, BK1>, static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
BBlockTransferThreadClusterLengths_BK0_N_BK1, static constexpr auto ThreadSliceLength_O = Number<SrcScalarPerVetor>{};
BBlockTransferThreadClusterArrangeOrder, static constexpr auto ThreadSliceLength_M =
DataType, Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};
DataType,
GridDesc_K0_N_K1,
decltype(b_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
static constexpr index_t KPack = math::max( static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
math::lcm(AK1, BK1), MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
// Blockwise gemm with transposed XDL output using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
using BlockwiseGemm = DataType,
BlockwiseGemmXdlops_v2<BlockSize, ThreadSliceLength_M * ThreadSliceLength_O,
DataType, true>;
FloatGemmAcc,
decltype(a_block_desc),
decltype(b_block_desc),
decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(
a_block_desc)),
decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(
b_block_desc)),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
true>; // TransposeC
static constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); using DstBufType =
static constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, ThreadSliceLength_M, true>;
}; };
using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;
// PGrad Gemm has the same layout as P = Q * K^T Gemm (A row-major B col-major)
struct PGradGemmTile_M_N_O
@@ -631,13 +732,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
struct QGradGemmTile_M_K_N
{
template <typename QGridDesc_K0_M_K1_>
__device__ static const auto MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock(
const QGridDesc_K0_M_K1_& q_grid_desc_k0_m_k1)
{
const auto K0 = q_grid_desc_k0_m_k1.GetLength(I0);
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto K1 = q_grid_desc_k0_m_k1.GetLength(I2);
const auto K = K0 * K1;
const auto MBlock = M / MPerBlock;
const auto KBlock = K / Gemm1NPerBlock; // NOTE: QGrad gemm is similar to Y gemm
@@ -659,7 +760,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
template <typename SGradThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4_>
__device__ static const auto
MakeSGradThreadDesc_N0_M_N1(const SGradThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4_&
sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4)
{
constexpr auto m0 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto n0 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
@@ -673,8 +775,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
constexpr auto sgrad_thread_desc_n0_m_n1 = transform_tensor_descriptor(
sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
make_pass_through_transform(n4)),
make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
@@ -703,6 +805,52 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
};
struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto b1_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto p_block_desc_m0_n_m1 =
VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
static constexpr auto ygrad_block_desc_m0_o_m1 =
VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();
static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto p_block_space_size_aligned =
math::integer_least_multiple(p_block_desc_m0_n_m1.GetElementSpaceSize(), max_lds_align);
static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
ygrad_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_offset = 0;
static constexpr auto b_block_space_offset = a_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = 0;
static constexpr auto p_block_space_offset = 0;
static constexpr auto ygrad_block_space_offset = p_block_space_size_aligned.value;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset = 0;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
template <bool HasMainKBlockLoop,
typename Block2CTileMap,
typename C0MatrixMask,
@@ -774,19 +922,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// set up P / dP Gemm (type 1 rcr)
//
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
typename Gemm0::template ABlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
q_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1, Gemm0::a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
@@ -796,7 +938,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
k_grid_desc_k0_n_k1,
make_multi_index(0, 0, 0), // will loop over GemmN dimension
b_element_op,
b_block_desc_bk0_n_bk1, Gemm0::b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
@@ -807,14 +949,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// LDS allocation for A and B: be careful of alignment
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
a_block_desc_ak0_m_ak1.GetElementSpaceSize()); Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); Gemm0::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = Gemm0::a_block_slice_copy_step;
constexpr auto b_block_slice_copy_step = Gemm0::b_block_slice_copy_step;
const auto a_block_reset_copy_step =
make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
@@ -828,95 +967,31 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
LoopScheduler::Default>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2)) / KPerBlock);
//
// set up Y / dQ Gemm (type 2 rrr)
//
using Gemm1 =
Gemm1<decltype(s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()),
decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;
// Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = constexpr auto acc_thread_desc_k0_m_k1 = Gemm1::acc_thread_desc_k0_m_k1;
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);
// acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1 // A1 matrix in accumulator VGPR, dst of blockwise copy
// n0_n1_n2_n3 -> k0 constexpr auto a1_thread_desc_k0_m_k1 = Gemm1::a_thread_desc_k0_m_k1;
// m0_m1_m2 -> m
// n4 -> k1
// NOTE: had to use merge_v3 or will spit out compilation errors
constexpr auto acc_thread_desc_k0_m_k1 = transform_tensor_descriptor(
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
make_pass_through_transform(n4)),
make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// A1 matrix in AccVGPR
// N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
constexpr auto AccN3 =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6);
constexpr auto A1ThreadSlice_K0_M_K1 =
make_tuple(Number<Gemm1KPerBlock / n4 / AccN3>{}, Number<m0 * m1 * m2>{}, Number<n4>{});
constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0];
constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1];
constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2];
constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor(
A1ThreadSlice_K0_M_K1,
make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1));
// B1 matrix in LDS memory, dst of blockwise copy
constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); constexpr auto b1_block_desc_bk0_n_bk1 = Gemm1::b_block_desc_bk0_n_bk1;
// A1 matrix blockwise copy
auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< auto a1_blockwise_copy =
FloatGemmAcc, typename Gemm1::ABlockwiseCopy{tensor_operation::element_wise::PassThrough{}};
DataType,
decltype(acc_thread_desc_k0_m_k1),
decltype(a1_thread_desc_k0_m_k1),
tensor_operation::element_wise::PassThrough,
Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
Sequence<1, 0, 2>,
2,
n4>{tensor_operation::element_wise::PassThrough{}};
// B1 matrix blockwise copy
auto b1_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock, typename Gemm1::template BBlockwiseCopy<decltype(v_grid_desc_n0_o_n1)>(
BElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<B1K0, Gemm1NPerBlock, B1K1>,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
DataType,
DataType,
decltype(v_grid_desc_n0_o_n1),
decltype(b1_block_desc_bk0_n_bk1),
B1BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
B1BlockTransferSrcVectorDim,
2,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
1,
1,
B1ThreadTransferSrcResetCoordinateAfterRun,
true, // DstResetCoord
NumGemmKPrefetchStage>(
v_grid_desc_n0_o_n1,
make_multi_index(0, o_block_data_idx_on_grid, 0),
b1_element_op,
@@ -927,44 +1002,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
a1_thread_desc_k0_m_k1.GetElementSpaceSize());
// reuse LDS space for gemm0's b_block_buf
auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size constexpr index_t Gemm1KPack = Gemm1::GemmKPack;
// selected_mfma.k_per_blk <= Gemm1KPack
//
// Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
// multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
// Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
constexpr index_t Gemm1KPack =
MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2< auto gemm1_blockwise_gemm =
BlockSize, typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0)}; // A_origin
DataType,
FloatGemmAcc,
decltype(a1_thread_desc_k0_m_k1),
decltype(b1_block_desc_bk0_n_bk1),
decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)),
decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)),
MPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
Gemm1NXdlPerWave,
Gemm1KPack,
true, // TransposeC
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<DataType, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
// BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
@@ -1003,6 +1048,17 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
chain_tensor_adaptors(m0_n_m1_to_m_n_adaptor, threadid_to_m0_n_m1_adaptor);
// get acc0 2D thread cluster & 2D thread slice
constexpr auto thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto m0 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto n0 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto m1 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto n1 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto m2 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto n2 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto n3 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto n4 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(tm0 * tm1 * tm2, tn0 * tn1 * tn2 * tn3 * tn4));
constexpr auto thread_slice_desc_m_n =
@@ -1422,7 +1478,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
ygrad_grid_desc_o0_m_o1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
tensor_operation::element_wise::PassThrough{},
a_block_desc_ak0_m_ak1, Gemm0::a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
@@ -1432,7 +1488,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
v_grid_desc_o0_n_o1,
make_multi_index(0, 0, 0), // will loop over GemmN dimension
tensor_operation::element_wise::PassThrough{},
b_block_desc_bk0_n_bk1, Gemm0::b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
@@ -1461,6 +1517,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
//
// set up dQ Gemm (type 2 rrr)
//
// transform input and output tensor descriptors
const auto k_grid_desc_n0_k_n1 =
QGradGemmTile_M_K_N::MakeKGridDesc_N0_K_N1(k_grid_desc_k0_n_k1);
auto qgrad_grid_desc_mblock_mperblock_kblock_kperblock =
@@ -1468,41 +1526,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
q_grid_desc_k0_m_k1);
// dQ Gemm A matrix blockwise copy
auto qgrad_gemm_tile_sgrad_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< auto qgrad_gemm_tile_sgrad_blockwise_copy =
FloatGemmAcc, typename Gemm1::ABlockwiseCopy{tensor_operation::element_wise::PassThrough{}};
DataType,
decltype(acc_thread_desc_k0_m_k1), // reuse desc
decltype(a1_thread_desc_k0_m_k1), // reuse desc
tensor_operation::element_wise::PassThrough,
Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
Sequence<1, 0, 2>,
2,
n4>{tensor_operation::element_wise::PassThrough{}};
// dQ Gemm B matrix blockwise copy
auto qgrad_gemm_tile_k_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock, typename Gemm1::template BBlockwiseCopy<decltype(k_grid_desc_n0_k_n1)>(
BElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<B1K0, Gemm1NPerBlock, B1K1>, // reuse from V
B1BlockTransferThreadClusterLengths_BK0_N_BK1, // reuse from V
B1BlockTransferThreadClusterArrangeOrder, // reuse from V
DataType,
DataType,
decltype(k_grid_desc_n0_k_n1),
decltype(b1_block_desc_bk0_n_bk1), // reuse from V
B1BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
B1BlockTransferSrcVectorDim,
2,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
1,
1,
B1ThreadTransferSrcResetCoordinateAfterRun,
true, // DstResetCoord
NumGemmKPrefetchStage>(
k_grid_desc_n0_k_n1,
make_multi_index(0, o_block_data_idx_on_grid, 0),
b1_element_op,
@@ -1510,32 +1539,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto qgrad_blockwise_gemm = BlockwiseGemmXdlops_v2< auto qgrad_blockwise_gemm =
BlockSize, typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0)}; // A_origin
DataType,
FloatGemmAcc,
decltype(a1_thread_desc_k0_m_k1),
decltype(b1_block_desc_bk0_n_bk1),
decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)),
decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)),
MPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
Gemm1NXdlPerWave,
Gemm1KPack,
true, // TransposeC
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<DataType, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
// BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
auto qgrad_thread_buf = qgrad_blockwise_gemm.GetCThreadBuffer();
//
// calculate y dot ygrad // calculate Y dot dY
//
// clear accum buffers
@@ -1632,8 +1642,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_tuple(I0, I0, I0, I0),
lse_thread_buf);
const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
// Initialize dQ
@@ -1652,17 +1661,17 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
// P = Q * K^T
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(q_grid_desc_k0_m_k1,
a_block_desc_ak0_m_ak1, Gemm0::a_block_desc_ak0_m_ak1,
a_blockwise_copy,
q_grid_buf,
a_block_buf,
a_block_slice_copy_step, Gemm0::a_block_slice_copy_step,
k_grid_desc_k0_n_k1,
b_block_desc_bk0_n_bk1, Gemm0::b_block_desc_bk0_n_bk1,
b_blockwise_copy,
k_grid_buf,
b_block_buf,
b_block_slice_copy_step, Gemm0::b_block_slice_copy_step,
s_blockwise_gemm,
s_slash_p_thread_buf,
num_k_block_main_loop);
@@ -1857,17 +1866,17 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
block_sync_lds();
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
ygrad_grid_desc_o0_m_o1,
a_block_desc_ak0_m_ak1, // reuse Gemm0::a_block_desc_ak0_m_ak1, // reuse
pgrad_gemm_tile_ygrad_blockwise_copy,
ygrad_grid_buf,
a_block_buf, // reuse
a_block_slice_copy_step, // reuse Gemm0::a_block_slice_copy_step, // reuse
v_grid_desc_o0_n_o1,
b_block_desc_bk0_n_bk1, // reuse Gemm0::b_block_desc_bk0_n_bk1, // reuse
pgrad_gemm_tile_v_blockwise_copy,
v_grid_buf,
b_block_buf, // reuse
b_block_slice_copy_step, // reuse Gemm0::b_block_slice_copy_step, // reuse
pgrad_blockwise_gemm,
pgrad_thread_buf,
num_o_block_main_loop);
@@ -1897,8 +1906,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
constexpr auto n =
pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I1];
// dS and P has same thread buf layout
sgrad_thread_buf(i) = s_slash_p_thread_buf[i] * (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
});
#if 0 #if 0
...@@ -1927,7 +1936,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -1927,7 +1936,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_n0_k_n1,
b1_block_slice_copy_step); Gemm1::b_block_slice_copy_step);
block_sync_lds(); // wait for previous LDS read
@@ -1944,13 +1953,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
if constexpr(num_gemm1_k_block_inner_loop > 1)
{
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
qgrad_gemm_tile_sgrad_blockwise_copy.Run( qgrad_gemm_tile_sgrad_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
acc_thread_desc_k0_m_k1, Gemm1::a_block_slice_copy_step * i,
make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0), sgrad_thread_buf,
sgrad_thread_buf, a1_thread_desc_k0_m_k1,
a1_thread_desc_k0_m_k1, make_tuple(I0, I0, I0),
make_tuple(I0, I0, I0), a1_thread_buf);
a1_thread_buf);
#if 0
if(hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
@@ -1971,18 +1979,18 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
block_sync_lds();
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_n0_k_n1, qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
b1_block_slice_copy_step); k_grid_desc_n0_k_n1, Gemm1::b_block_slice_copy_step);
qgrad_gemm_tile_k_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); qgrad_gemm_tile_k_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1,
b1_block_buf);
});
}
// tail
{
qgrad_gemm_tile_sgrad_blockwise_copy.Run(
acc_thread_desc_k0_m_k1,
make_tuple( Gemm1::a_block_slice_copy_step * Number<num_gemm1_k_block_inner_loop - 1>{},
Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0),
sgrad_thread_buf,
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
@@ -2011,8 +2019,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
// TODO ANT:
// shuffle dQ and write
#if 0
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
@@ -100,6 +100,17 @@ __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
return r;
}
template <typename... Xs, index_t N>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Number<N>& y)
{
constexpr index_t NSize = sizeof...(Xs);
// Tuple<Xs...> r;
// static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] * y; });
// return r;
return generate_tuple([&](auto i) { return x[i] * y; }, Number<NSize>{});
}
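// The overload above lets a compile-time slice-copy step, i.e. a Tuple of integral constants
// such as Gemm1::a_block_slice_copy_step = make_tuple(AThreadSliceLength_K0, I0, I0), be scaled
// by a Number<> loop index while remaining a tuple of integral constants, as the static-buffer
// addressing in the dQ gemm inner loop requires. A minimal usage sketch with hypothetical
// values (assuming the Number/Tuple utilities used elsewhere in this header):
//
//   constexpr auto step   = make_tuple(Number<4>{}, Number<0>{}, Number<0>{});
//   constexpr auto offset = step * Number<3>{}; // Tuple<Number<12>, Number<0>, Number<0>>
//   static_assert(offset.At(Number<0>{}) == 12, "");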
// MultiIndex = scalar * MultiIndex
template <typename... Xs,
typename Y,