"tests/vscode:/vscode.git/clone" did not exist on "2ea1da89ab42bde69caad909048925ca99873400"
Commit 2d55c14c authored by Anthony Chang

refactor Gemm2

parent 383211ef
@@ -196,38 +196,30 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
     }

-    __host__ __device__ static constexpr auto
-    GetPBlockDescriptor_NBlock_NPerBlock_MBlock_MPerBlock()
+    template <typename Gemm2Param>
+    __host__ __device__ static constexpr auto GetA2BlockDescriptor_M0_N_M1()
     {
-        constexpr auto ptrans_block_desc = make_naive_tensor_descriptor_packed(make_tuple(
-            I1, Number<VGradGemmTile_N_O_M::Free0_N>{}, I1, Number<VGradGemmTile_N_O_M::Sum_M>{}));
-
-        return ptrans_block_desc;
+        return make_naive_tensor_descriptor(
+            make_tuple(Number<Gemm2Param::A_M0>{},
+                       Number<Gemm2Param::Free0_N>{},
+                       Number<Gemm2Param::A_M1>{}),
+            make_tuple(Number<Gemm2Param::Free0_N + Gemm2Param::A_LdsPad>{} *
+                           Number<Gemm2Param::A_M1>{},
+                       Number<Gemm2Param::A_M1>{},
+                       I1));
     }

-    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
+    template <typename Gemm2Param>
+    __host__ __device__ static constexpr auto GetB2BlockDescriptor_M0_O_M1()
     {
-        const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
-                                         SharedMemTrait::b_block_space_size_aligned) *
-                                        sizeof(DataType);
-        const index_t gemm1_bytes_end =
-            (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
-            sizeof(DataType);
-        const index_t vgrad_gemm_bytes_end = (SharedMemTrait::p_block_space_size_aligned +
-                                              SharedMemTrait::ygrad_block_space_size_aligned) *
-                                             sizeof(DataType);
-        const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
-                                           SharedMemTrait::reduction_space_size_aligned) *
-                                          sizeof(FloatGemmAcc);
-        const index_t c_block_bytes_end =
-            SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
-
-        return math::max(gemm0_bytes_end,
-                         gemm1_bytes_end,
-                         vgrad_gemm_bytes_end,
-                         softmax_bytes_end,
-                         c_block_bytes_end);
+        return make_naive_tensor_descriptor(
+            make_tuple(Number<Gemm2Param::B_M0>{},
+                       Number<Gemm2Param::Free1_O>{},
+                       Number<Gemm2Param::B_M1>{}),
+            make_tuple(Number<Gemm2Param::Free1_O + Gemm2Param::B_LdsPad>{} *
+                           Number<Gemm2Param::B_M1>{},
+                       Number<Gemm2Param::B_M1>{},
+                       I1));
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
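Note on the new descriptors: both lay out a tile in LDS as [M0, N, M1] (respectively [M0, O, M1]) with any LDS pad folded into the leading stride, i.e. strides {(N + LdsPad) * M1, M1, 1}. A minimal standalone sketch of that offset arithmetic, with invented values standing in for the Gemm2Param constants (not the library's real tuning):

```cpp
#include <cstdio>

// Hypothetical tile sizes for illustration only.
constexpr int M0 = 8, N = 128, M1 = 8, LdsPad = 0;

// Offset of element (m0, n, m1) in an [M0, N, M1] naive descriptor with
// strides {(N + LdsPad) * M1, M1, 1}, mirroring GetA2BlockDescriptor_M0_N_M1.
constexpr int offset(int m0, int n, int m1)
{
    return m0 * (N + LdsPad) * M1 + n * M1 + m1;
}

int main()
{
    static_assert(offset(1, 0, 0) == N * M1, "adjacent M0 slices are one padded row apart");
    std::printf("offset(2, 3, 4) = %d\n", offset(2, 3, 4));
}
```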
@@ -485,11 +477,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        static constexpr auto AThreadSliceLength_M  = Number<m0 * m1 * m2>{};
        static constexpr auto AThreadSliceLength_K1 = Number<n4>{};

+       // A source matrix layout in AccVGPR
        static constexpr auto a_src_thread_desc_k0_m_k1 =
            GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
                ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{});

+       // A matrix in VGPR memory, dst of AccVGPR-to-VGPR copy
        static constexpr auto a_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed(
            make_tuple(AThreadSliceLength_K0, AThreadSliceLength_M, AThreadSliceLength_K1));

+       // B matrix in LDS memory, dst of blockwise copy
        static constexpr auto b_block_desc_bk0_n_bk1 =
            GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
@@ -574,68 +571,50 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
    };

    // dV / dK Gemm (type 3 crr)
-   // TODO ANT: refactor into Gemm2
+   // Describes tuning parameter for C2_n_o = A2_n_m * B2_m_o
    template <index_t Sum_M_ = MPerXdl * 2>
-   struct VGradGemmTile_N_O_M_
+   struct Gemm2Params_N_O_M_
    {
        static constexpr index_t Free0_N = NPerBlock;
        static constexpr index_t Free1_O = Gemm1NPerBlock;
        static constexpr index_t Sum_M   = Sum_M_;

-       static constexpr index_t P_M1     = 8; // P will be row-major
-       static constexpr index_t P_M0     = Sum_M / P_M1;
-       static constexpr index_t P_LdsPad = 0; // how many multiples of M1 per N * M1 elements
+       static constexpr index_t A_M1     = 8; // P will be row-major
+       static constexpr index_t A_M0     = Sum_M / A_M1;
+       static constexpr index_t A_LdsPad = 0; // how many multiples of M1 per N * M1 elements

-       static constexpr index_t YGrad_M1     = 2; // dY assumed row-major, typically =2 for fp16
-       static constexpr index_t YGrad_M0     = Sum_M / YGrad_M1;
-       static constexpr index_t YGrad_LdsPad = 0; // how many multiples of M1 per N * M1 elements
+       static constexpr index_t B_M1     = 2; // dY assumed row-major, typically =2 for fp16
+       static constexpr index_t B_M0     = Sum_M / B_M1;
+       static constexpr index_t B_LdsPad = 0; // how many multiples of M1 per N * M1 elements

        static_assert(Sum_M % MPerXdl == 0, "");

-       static constexpr index_t YGrad_SrcVectorDim       = 1; // Free1_O dimension
-       static constexpr index_t YGrad_SrcScalarPerVector = 4;
+       static constexpr index_t BSrcVectorDim       = 1; // Free1_O dimension
+       static constexpr index_t BSrcScalarPerVector = 4;

        static constexpr index_t GemmNWave   = 2;
        static constexpr index_t GemmOWave   = BlockSize / get_warp_size() / GemmNWave;
        static constexpr index_t GemmNRepeat = Free0_N / GemmNWave / MPerXdl;
        static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
        static constexpr index_t GemmMPack =
-           math::max(math::lcm(P_M1, YGrad_M1),
+           math::max(math::lcm(A_M1, B_M1),
                      MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

-       using YGrad_BlockSliceLengths = Sequence<YGrad_M0, Free1_O, YGrad_M1>;
-       using YGrad_ThreadClusterLengths =
-           Sequence<BlockSize / (Free1_O / YGrad_SrcScalarPerVector),
-                    Free1_O / YGrad_SrcScalarPerVector,
-                    1>;
-       using YGrad_ThreadClusterArrangeOrder = Sequence<0, 2, 1>;
-
-       __host__ __device__ static constexpr auto GetPBlockDescriptor_M0_N_M1()
-       {
-           constexpr index_t P_M0 = Sum_M / P_M1;
-           return make_naive_tensor_descriptor(
-               make_tuple(Number<P_M0>{}, Number<Free0_N>{}, Number<P_M1>{}),
-               make_tuple(Number<Free0_N + P_LdsPad>{} * Number<P_M1>{}, Number<P_M1>{}, I1));
-       }
-
-       __host__ __device__ static constexpr auto GetYGradBlockDescriptor_M0_O_M1()
-       {
-           constexpr index_t YGrad_M0 = Sum_M / YGrad_M1;
-           return make_naive_tensor_descriptor(
-               make_tuple(Number<YGrad_M0>{}, Number<Free1_O>{}, Number<YGrad_M1>{}),
-               make_tuple(
-                   Number<Free1_O + YGrad_LdsPad>{} * Number<YGrad_M1>{}, Number<YGrad_M1>{}, I1));
-       }
+       using BBlockSliceLengths = Sequence<B_M0, Free1_O, B_M1>;
+       using BThreadClusterLengths =
+           Sequence<BlockSize / (Free1_O / BSrcScalarPerVector), Free1_O / BSrcScalarPerVector, 1>;
+       using BThreadClusterArrangeOrder = Sequence<0, 2, 1>;

-       __host__ __device__ static constexpr auto GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2()
+       __host__ __device__ static constexpr auto GetABlockSliceLengths_M0_N0_M1_N1_M2_N2()
        {
            // perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
-           constexpr index_t m  = Sum_M - 1;
+           constexpr index_t m  = Gemm2Params_N_O_M::Sum_M - 1;
            constexpr index_t m2 = m % MPerXdl;
            constexpr index_t m1 = m / MPerXdl % Gemm0MWaves;
            constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave;

            // perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
-           constexpr index_t n  = Free0_N - 1;
+           constexpr index_t n  = Gemm2Params_N_O_M::Free0_N - 1;
            constexpr index_t n2 = n % NPerXdl;
            constexpr index_t n1 = n / NPerXdl % Gemm0NWaves;
            constexpr index_t n0 = n / NPerXdl / Gemm0NWaves % NXdlPerWave;
@@ -646,14 +625,212 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            return Sequence<m0, n0, m1, n1, m2, n2>{} + Sequence<1, 1, 1, 1, 1, 1>{};
        }

-       __host__ __device__ static constexpr auto GetPBlockSliceLengths_M0_N0_M1_N1()
+       __host__ __device__ static constexpr auto GetABlockSliceLengths_M0_N0_M1_N1()
        {
            return generate_sequence_v2(
-               [](auto I) { return GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2().At(I); },
+               [](auto I) { return GetABlockSliceLengths_M0_N0_M1_N1_M2_N2().At(I); },
                Number<4>{});
        }
+
+       using ABlockSliceLengths_M0_N0_M1_N1 = decltype(GetABlockSliceLengths_M0_N0_M1_N1());
+   };
+   using Gemm2Params_N_O_M = Gemm2Params_N_O_M_<>; // tune later
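GetABlockSliceLengths_M0_N0_M1_N1_M2_N2 above recovers per-dimension slice lengths by decomposing the largest flat index (Sum_M - 1, Free0_N - 1) into repeat/wave/xdl coordinates and then adding 1 to each component. A small standalone sketch of that manual unmerge along M, with made-up values in place of MPerXdl, Gemm0MWaves and MXdlPerWave:

```cpp
#include <array>
#include <cstdio>

// Illustrative stand-ins, not the kernel's real configuration.
constexpr int MPerXdl = 32, Gemm0MWaves = 2, MXdlPerWave = 2;

// Decompose a flat m index into (m0, m1, m2) = (repeat, wave, lane-within-xdl),
// mirroring the "manual unmerge" in GetABlockSliceLengths_M0_N0_M1_N1_M2_N2.
constexpr std::array<int, 3> unmerge_m(int m)
{
    return {m / MPerXdl / Gemm0MWaves % MXdlPerWave, // m0: repeat
            m / MPerXdl % Gemm0MWaves,               // m1: wave
            m % MPerXdl};                            // m2: per-xdl
}

int main()
{
    constexpr int Sum_M = MPerXdl * 2;          // default Sum_M_ in Gemm2Params_N_O_M_
    constexpr auto len  = unmerge_m(Sum_M - 1); // largest index -> lengths - 1
    std::printf("slice lengths: m0=%d m1=%d m2=%d\n", len[0] + 1, len[1] + 1, len[2] + 1);
}
```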
// dV / dK Gemm (type 3 crr)
template <typename Gemm2Params_N_O_M, typename ASrcBlockwiseGemm>
struct Gemm2
{
private:
static constexpr auto a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
ASrcBlockwiseGemm::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
static constexpr auto M0 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); // repeat
static constexpr auto N0 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
static constexpr auto M1 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); // wave
static constexpr auto N1 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
static constexpr auto M2 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4); // xdl
static constexpr auto N2 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
static constexpr auto N3 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
static constexpr auto N4 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
public:
// A source matrix layout in VGPR, src of VGPR-to-LDS copy
static constexpr auto a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
ASrcBlockwiseGemm::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_m0_n_m1 =
GetA2BlockDescriptor_M0_N_M1<Gemm2Params_N_O_M>();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_m0_o_m1 =
GetB2BlockDescriptor_M0_O_M1<Gemm2Params_N_O_M>();
__host__ __device__ static constexpr auto MakeABlockDesc_M0_N0_M1_N1_M2_N2_N3_N4()
{
const auto M0_ = a_block_desc_m0_n_m1.GetLength(I0);
const auto N_ = a_block_desc_m0_n_m1.GetLength(I1);
const auto M1_ = a_block_desc_m0_n_m1.GetLength(I2);
const auto a_block_desc_m_n = transform_tensor_descriptor(
a_block_desc_m0_n_m1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(M0_, M1_)),
make_pass_through_transform(N_)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
// variable I1 there
return transform_tensor_descriptor(
a_block_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(I1, M1, M2)),
make_unmerge_transform(make_tuple(I1, N1, N2, N3, N4))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
}
// Note: we will perform sub-workgroup VGPR-to-LDS copy to save LDS space, therefore the
// destination coordinate can overlap between wavefronts in a workgroup as seen in the mod
// operation before returning the values
__host__ __device__ static auto MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4()
{
const auto a_thread_origin_on_block_idx =
ASrcBlockwiseGemm::CalculateCThreadOriginDataIndex8D(I0, I0, I0, I0);
constexpr auto c_block_slice_lengths_m0_n0_m1_n1 =
typename Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1{}; // mrepeat, nrepeat,
// mwaves, nwaves,
return make_tuple(
a_thread_origin_on_block_idx[I0], // mrepeat
a_thread_origin_on_block_idx[I1], // nrepeat
a_thread_origin_on_block_idx[I2] % c_block_slice_lengths_m0_n0_m1_n1[I2], // mwave
a_thread_origin_on_block_idx[I3] % c_block_slice_lengths_m0_n0_m1_n1[I3], // nwave
a_thread_origin_on_block_idx[I4], // xdlops
a_thread_origin_on_block_idx[I5],
a_thread_origin_on_block_idx[I6],
a_thread_origin_on_block_idx[I7]);
}
static constexpr auto a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
MakeABlockDesc_M0_N0_M1_N1_M2_N2_N3_N4();
using ASrcBlockSliceWindowIterator =
SpaceFillingCurve<Sequence<M0, N0, M1, N1>,
Sequence<0, 1, 2, 3>,
typename Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1,
false>;
using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
DataType,
decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
tensor_operation::element_wise::PassThrough,
Sequence<Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(
I0), // ThreadSliceLengths
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I1),
I1,
I1,
I1,
N2,
I1,
N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>;
template <typename GridDesc_M0_O_M1>
using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
typename Gemm2Params_N_O_M::BBlockSliceLengths,
typename Gemm2Params_N_O_M::BThreadClusterLengths,
typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder,
DataType,
DataType,
GridDesc_M0_O_M1,
decltype(b_block_desc_m0_o_m1),
typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder, // access order == thread order
Sequence<1, 0, 2>,
Gemm2Params_N_O_M::BSrcVectorDim,
2, // DstVectorDim
Gemm2Params_N_O_M::BSrcScalarPerVector,
Gemm2Params_N_O_M::B_M1,
1,
1,
true,
true,
1>;
using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
DataType,
FloatGemmAcc,
decltype(a_block_desc_m0_n_m1),
decltype(b_block_desc_m0_o_m1),
MPerXdl,
NPerXdl,
Gemm2Params_N_O_M::GemmNRepeat,
Gemm2Params_N_O_M::GemmORepeat,
Gemm2Params_N_O_M::GemmMPack,
true>; // TransposeC
static constexpr auto b_block_slice_copy_step =
make_multi_index(Gemm2Params_N_O_M::B_M0, 0, 0);
static constexpr auto c_block_slice_copy_step =
make_multi_index(Gemm2Params_N_O_M::GemmNRepeat, 0, 0, 0, 0, 0, 0, 0);
static constexpr auto b_block_reset_copy_step =
make_multi_index(-MPerBlock / Gemm2Params_N_O_M::B_M1, 0, 0);
template <typename CGradDesc_N_O>
__host__ __device__ static const auto
MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(CGradDesc_N_O c_grid_desc_n_o)
{
// HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
// variable I1 there
const auto c_grid_desc_n0_o0_n1_o1_n2_o2 = transform_tensor_descriptor(
c_grid_desc_n_o,
make_tuple(
make_unmerge_transform(make_tuple(I1, Gemm2Params_N_O_M::GemmNWave, MPerXdl)),
make_unmerge_transform(make_tuple(I1, Gemm2Params_N_O_M::GemmOWave, NPerXdl))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
const auto c_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
BlockwiseGemm{}.xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(
c_grid_desc_n0_o0_n1_o1_n2_o2);
return c_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4;
}
static constexpr auto c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
BlockwiseGemm::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
__host__ __device__ static const auto GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4()
{
return to_multi_index(BlockwiseGemm::CalculateCThreadOriginDataIndex8D(I0, I0, I0, I0));
}
template <typename CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4>
using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
DataType,
decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4),
CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4,
tensor_operation::element_wise::PassThrough, // CElementwiseOperation
decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLengths()), // SliceLengths
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // AccessOrder
7, // VectorDim
2, // ScalarPerVector
InMemoryDataOperationEnum::AtomicAdd, // GlobalMemoryDataOperation
1, // DstScalarStrideInVector
true>;
    };
-   using VGradGemmTile_N_O_M = VGradGemmTile_N_O_M_<>; // tune later
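The Gemm2 helper above copies the P tile into LDS with a sub-workgroup copy: MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4 wraps the wave coordinate with a mod so that waves outside the copied sub-tile would land on the same LDS destination as waves inside it, and a guard such as SubThreadBlock::IsBelong then decides which waves actually issue the store. A toy sketch of that wrapping, with an invented wave grid:

```cpp
#include <cstdio>

// Illustrative wave grid, not the kernel's real launch configuration.
constexpr int MWaves = 2, NWaves = 2;           // waves per workgroup in M and N
constexpr int SliceWavesM = 1, SliceWavesN = 2; // waves covered by one Gemm2 A sub-tile

// Wrap a wave's coordinate into the sub-tile, as the mod in
// MakeAThreadOriginOnBlock_* does for the LDS destination origin.
void print_origin(int mwave, int nwave)
{
    std::printf("wave (%d,%d) -> LDS origin wave (%d,%d)\n",
                mwave, nwave, mwave % SliceWavesM, nwave % SliceWavesN);
}

int main()
{
    for(int mw = 0; mw < MWaves; ++mw)
        for(int nw = 0; nw < NWaves; ++nw)
            print_origin(mw, nw);
}
```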
    template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
    struct YDotYGrad_M_O_
@@ -789,10 +966,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
        static constexpr auto b1_block_desc_bk0_n_bk1 =
            GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
-       static constexpr auto p_block_desc_m0_n_m1 =
-           VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
-       static constexpr auto ygrad_block_desc_m0_o_m1 =
-           VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();
+       static constexpr auto a2_block_desc_m0_n_m1 =
+           GetA2BlockDescriptor_M0_N_M1<Gemm2Params_N_O_M>();
+       static constexpr auto b2_block_desc_m0_o_m1 =
+           GetB2BlockDescriptor_M0_O_M1<Gemm2Params_N_O_M>();

        static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};
@@ -802,16 +979,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
            b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
-       static constexpr auto p_block_space_size_aligned =
-           math::integer_least_multiple(p_block_desc_m0_n_m1.GetElementSpaceSize(), max_lds_align);
+       static constexpr auto p_block_space_size_aligned = math::integer_least_multiple(
+           a2_block_desc_m0_n_m1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
-           ygrad_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);
+           b2_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);

        static constexpr auto a_block_space_offset  = 0;
        static constexpr auto b_block_space_offset  = a_block_space_size_aligned.value;
        static constexpr auto b1_block_space_offset = 0;
-       static constexpr auto p_block_space_offset     = 0;
-       static constexpr auto ygrad_block_space_offset = p_block_space_size_aligned.value;
+       static constexpr auto a2_block_space_offset = 0;
+       static constexpr auto b2_block_space_offset = p_block_space_size_aligned.value;

        // LDS allocation for reduction
        static constexpr index_t reduction_space_size_aligned =
@@ -826,6 +1003,31 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
    };
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
SharedMemTrait::b_block_space_size_aligned) *
sizeof(DataType);
const index_t gemm1_bytes_end =
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
sizeof(DataType);
const index_t vgrad_gemm_bytes_end = (SharedMemTrait::p_block_space_size_aligned +
SharedMemTrait::ygrad_block_space_size_aligned) *
sizeof(DataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end,
gemm1_bytes_end,
vgrad_gemm_bytes_end,
softmax_bytes_end,
c_block_bytes_end);
}
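GetSharedMemoryNumberOfByte sizes a single LDS allocation that the pipeline stages reuse, so the budget is the maximum end offset over the per-stage regions rather than their sum. A small sketch of that sizing rule with invented region sizes (the kernel derives them from the block descriptors' element-space sizes):

```cpp
#include <algorithm>
#include <cstdio>

int main()
{
    // Invented byte counts, for illustration only.
    const int gemm0_bytes_end      = (4096 + 4096) * 2; // a_block + b_block, 2-byte elements
    const int gemm1_bytes_end      = (0 + 4096) * 2;    // b1_block region starts at offset 0
    const int vgrad_gemm_bytes_end = (2048 + 2048) * 2; // a2_block + b2_block
    const int softmax_bytes_end    = (0 + 1024) * 4;    // reduction workspace, 4-byte acc
    const int c_block_bytes_end    = 8192 * 2;          // c-shuffle region

    // Stages run one after another in the same LDS, so take the max, not the sum.
    const int lds_bytes = std::max({gemm0_bytes_end,
                                    gemm1_bytes_end,
                                    vgrad_gemm_bytes_end,
                                    softmax_bytes_end,
                                    c_block_bytes_end});
    std::printf("LDS bytes required: %d\n", lds_bytes);
}
```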
    template <bool HasMainKBlockLoop,
              typename Block2CTileMap,
              typename C0MatrixMask,
@@ -1118,276 +1320,75 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        //
        // set up dV / dK Gemm (type 3 crr)
        //
using Gemm2 = Gemm2<Gemm2Params_N_O_M, decltype(s_blockwise_gemm)>;
+       // Gemm2: LDS allocation for A and B: be careful of alignment
+       auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+           static_cast<DataType*>(p_shared) + SharedMemTrait::a2_block_space_offset,
+           Gemm2::a_block_desc_m0_n_m1.GetElementSpaceSize());

-       // P vgpr to lds: writes vgprs of a subgroup to LDS and transform into m0_n_m1
-       // m0, n0 are m/n repeat per wave
-       // m1, n1 are number of waves
-       constexpr auto p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
-           s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto p_block_desc_m0_n_m1 = VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
constexpr auto p_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
constexpr auto P_M0 = p_block_lengths[I0]; // repeats
constexpr auto P_N0 = p_block_lengths[I1];
constexpr auto P_M1 = p_block_lengths[I2]; // waves
constexpr auto P_N1 = p_block_lengths[I3];
constexpr auto P_M2 = p_block_lengths[I4]; // xdl
constexpr auto P_N2 = p_block_lengths[I5];
constexpr auto P_N3 = p_block_lengths[I6];
constexpr auto P_N4 = p_block_lengths[I7];
constexpr auto p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = [&]() constexpr
{
constexpr auto p_block_desc_m_n = transform_tensor_descriptor(
p_block_desc_m0_n_m1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(VGradGemmTile_N_O_M::P_M0, VGradGemmTile_N_O_M::P_M1)),
make_pass_through_transform(VGradGemmTile_N_O_M::Free0_N)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
+       auto gemm2_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+           static_cast<DataType*>(p_shared) + SharedMemTrait::b2_block_space_offset,
+           Gemm2::b_block_desc_m0_o_m1.GetElementSpaceSize());

-           // HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
-           // variable I1 there
-           return transform_tensor_descriptor(
p_block_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(I1, P_M1, P_M2)),
make_unmerge_transform(make_tuple(I1, P_N1, P_N2, P_N3, P_N4))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
}
();
const auto p_thread_origin_nd_idx_on_block = [&]() {
const auto c_thread_mtx_on_block =
s_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
-           const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
-               make_single_stage_tensor_adaptor(
-                   make_tuple(make_merge_transform(make_tuple(P_M0, P_M1, P_M2))),
-                   make_tuple(Sequence<0, 1, 2>{}),
-                   make_tuple(Sequence<0>{}));
-
-           const auto m_thread_data_on_block_idx =
-               m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
-                   make_multi_index(m_thread_data_on_block));
-
-           const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
-               make_single_stage_tensor_adaptor(
-                   make_tuple(make_merge_transform(make_tuple(P_N0, P_N1, P_N2, P_N3, P_N4))),
-                   make_tuple(Sequence<0, 1, 2, 3, 4>{}),
-                   make_tuple(Sequence<0>{}));
-
-           const auto n_thread_data_on_block_idx =
-               n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
-                   make_multi_index(n_thread_data_on_block));
-
-           return make_tuple(m_thread_data_on_block_idx[I0], // mrepeat
-                             n_thread_data_on_block_idx[I0], // nrepeat
-                             m_thread_data_on_block_idx[I1], // mwave
-                             n_thread_data_on_block_idx[I1], // nwave
-                             m_thread_data_on_block_idx[I2], // xdlops
-                             n_thread_data_on_block_idx[I2],
-                             n_thread_data_on_block_idx[I3],
-                             n_thread_data_on_block_idx[I4]);
-       }();

+       // dV: transform input and output tensor descriptors
+       const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2 =
+           Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(vgrad_grid_desc_n_o);
+
+       // dV: A matrix VGPR-to-LDS blockwise copy
+       auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy{
+           Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+           Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
+           tensor_operation::element_wise::PassThrough{}};
+
+       constexpr auto vgrad_gemm_tile_p_block_slice_window_iterator =
+           typename Gemm2::ASrcBlockSliceWindowIterator{};
+
+       // dV: B matrix global-to-LDS blockwise copy
+       auto vgrad_gemm_tile_ygrad_blockwise_copy =
+           typename Gemm2::template BBlockwiseCopy<decltype(ygrad_grid_desc_m0_o_m1)>(
+               ygrad_grid_desc_m0_o_m1,
+               make_multi_index(m_block_data_idx_on_grid / Gemm2Params_N_O_M::B_M1,
+                                o_block_data_idx_on_grid,
+                                0),
+               tensor_operation::element_wise::PassThrough{},
+               Gemm2::b_block_desc_m0_o_m1,
+               make_multi_index(0, 0, 0),
+               tensor_operation::element_wise::PassThrough{});
constexpr auto p_block_slice_lengths_m0_n0_m1_n1 =
VGradGemmTile_N_O_M::GetPBlockSliceLengths_M0_N0_M1_N1(); // mrepeat, nrepeat,
// mwaves, nwaves,
// how to properly perform copy for a sub-workgroup?
auto p_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
DataType,
decltype(p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
tensor_operation::element_wise::PassThrough,
Sequence<p_block_slice_lengths_m0_n0_m1_n1[I0], // ThreadSliceLengths
p_block_slice_lengths_m0_n0_m1_n1[I1],
I1,
I1,
I1,
P_N2,
I1,
P_N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(
p_thread_origin_nd_idx_on_block[I0],
p_thread_origin_nd_idx_on_block[I1],
p_thread_origin_nd_idx_on_block[I2] % p_block_slice_lengths_m0_n0_m1_n1[I2],
p_thread_origin_nd_idx_on_block[I3] % p_block_slice_lengths_m0_n0_m1_n1[I3],
p_thread_origin_nd_idx_on_block[I4],
p_thread_origin_nd_idx_on_block[I5],
p_thread_origin_nd_idx_on_block[I6],
p_thread_origin_nd_idx_on_block[I7]),
tensor_operation::element_wise::PassThrough{}};
constexpr auto sfc_p_m0_n0_m1_n1_m2_n2 =
SpaceFillingCurve<Sequence<P_M0, P_N0, P_M1, P_N1>,
Sequence<0, 1, 2, 3>,
decltype(p_block_slice_lengths_m0_n0_m1_n1),
false>{};
-       constexpr auto ygrad_block_desc_m0_o_m1 =
-           VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();
+       // dV: blockwise gemm
+       auto vgrad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};
auto ygrad_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
typename VGradGemmTile_N_O_M::YGrad_BlockSliceLengths,
typename VGradGemmTile_N_O_M::YGrad_ThreadClusterLengths,
typename VGradGemmTile_N_O_M::YGrad_ThreadClusterArrangeOrder,
DataType,
DataType,
decltype(ygrad_grid_desc_m0_o_m1),
decltype(ygrad_block_desc_m0_o_m1),
typename VGradGemmTile_N_O_M::YGrad_ThreadClusterArrangeOrder, // access order == thread
// order
Sequence<1, 0, 2>,
VGradGemmTile_N_O_M::YGrad_SrcVectorDim,
2, // DstVectorDim
VGradGemmTile_N_O_M::YGrad_SrcScalarPerVector,
VGradGemmTile_N_O_M::YGrad_M1,
1,
1,
true,
true,
1>(ygrad_grid_desc_m0_o_m1,
make_multi_index(m_block_data_idx_on_grid / VGradGemmTile_N_O_M::YGrad_M1,
o_block_data_idx_on_grid,
0),
tensor_operation::element_wise::PassThrough{},
ygrad_block_desc_m0_o_m1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto p_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::p_block_space_offset,
p_block_desc_m0_n_m1.GetElementSpaceSize());
auto ygrad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::ygrad_block_space_offset,
ygrad_block_desc_m0_o_m1.GetElementSpaceSize());
auto vgrad_blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
DataType,
FloatGemmAcc,
decltype(p_block_desc_m0_n_m1),
decltype(ygrad_block_desc_m0_o_m1),
MPerXdl,
NPerXdl,
VGradGemmTile_N_O_M::GemmNRepeat,
VGradGemmTile_N_O_M::GemmORepeat,
VGradGemmTile_N_O_M::GemmMPack,
true>{}; // TransposeC
        auto vgrad_thread_buf = vgrad_blockwise_gemm.GetCThreadBuffer();

+       // dV: C VGPR-to-global copy
-       // HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
-       // variable I1 there
const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2 = transform_tensor_descriptor(
vgrad_grid_desc_n_o,
make_tuple(
make_unmerge_transform(make_tuple(I1, VGradGemmTile_N_O_M::GemmNWave, MPerXdl)),
make_unmerge_transform(make_tuple(I1, VGradGemmTile_N_O_M::GemmOWave, NPerXdl))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
constexpr auto vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
vgrad_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
        const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
-           vgrad_blockwise_gemm.xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(
-               vgrad_grid_desc_n0_o0_n1_o1_n2_o2);
+           Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(vgrad_grid_desc_n_o);
const auto vgrad_thread_mtx_on_block_n_o =
vgrad_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
constexpr auto vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
decltype(vgrad_blockwise_gemm)::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto VGrad_N0 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I0);
constexpr auto VGrad_O0 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I1);
constexpr auto VGrad_N1 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I2);
constexpr auto VGrad_O1 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I3);
constexpr auto VGrad_N2 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I4);
constexpr auto VGrad_O2 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I5);
constexpr auto VGrad_O3 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I6);
constexpr auto VGrad_O4 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I7);
const index_t n_thread_data_idx_on_grid = vgrad_thread_mtx_on_block_n_o[I0];
const index_t o_thread_data_idx_on_grid =
vgrad_thread_mtx_on_block_n_o[I1] + o_block_data_idx_on_grid;
const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(VGrad_N0, VGrad_N1, VGrad_N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
-       const auto n_thread_data_nd_idx_on_grid =
-           n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
-               make_multi_index(n_thread_data_idx_on_grid));
-
-       const auto o_thread_data_on_grid_to_o0_o1_o2_o3_o4_adaptor =
-           make_single_stage_tensor_adaptor(
-               make_tuple(make_merge_transform(
-                   make_tuple(VGrad_O0, VGrad_O1, VGrad_O2, VGrad_O3, VGrad_O4))),
-               make_tuple(Sequence<0, 1, 2, 3, 4>{}),
-               make_tuple(Sequence<0>{}));
-
-       const auto o_thread_data_nd_idx_on_grid =
-           o_thread_data_on_grid_to_o0_o1_o2_o3_o4_adaptor.CalculateBottomIndex(
-               make_multi_index(o_thread_data_idx_on_grid));

+       const auto vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4 =
+           Gemm2::GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4() +
+           make_multi_index(I0, block_work_idx[I1], I0, I0, I0, I0, I0, I0);
+
+       auto vgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<decltype(
+           vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4)>(
+           vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+           vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
+           tensor_operation::element_wise::PassThrough{});
+
+       // dK: transform input and output tensor descriptors
auto vgrad_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
DataType,
decltype(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4),
decltype(vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
tensor_operation::element_wise::PassThrough, // CElementwiseOperation
decltype(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLengths()), // SliceLengths
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // AccessOrder
7, // VectorDim
2, // ScalarPerVector
InMemoryDataOperationEnum::AtomicAdd, // GlobalMemoryDataOperation
1, // DstScalarStrideInVector
true>(vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
make_multi_index(n_thread_data_nd_idx_on_grid[I0],
o_thread_data_nd_idx_on_grid[I0],
n_thread_data_nd_idx_on_grid[I1],
o_thread_data_nd_idx_on_grid[I1],
n_thread_data_nd_idx_on_grid[I2],
o_thread_data_nd_idx_on_grid[I2],
o_thread_data_nd_idx_on_grid[I3],
o_thread_data_nd_idx_on_grid[I4]),
tensor_operation::element_wise::PassThrough{});
// p_thread_slice_copy_step will be in for loop
constexpr auto ygrad_block_slice_copy_step =
make_multi_index(VGradGemmTile_N_O_M::YGrad_M0, 0, 0);
constexpr auto ygrad_block_reset_copy_step =
make_multi_index(-MPerBlock / VGradGemmTile_N_O_M::YGrad_M1, 0, 0);
// vgrad gemm output tile
const auto vgrad_block_slice_copy_step =
make_multi_index(VGradGemmTile_N_O_M::GemmNRepeat, 0, 0, 0, 0, 0, 0, 0);
        //
        // set up Y dot dY
        //
// m0, n0 are m/n repeat per wave
// m1, n1 are number of waves
constexpr auto p_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
constexpr auto P_M0 = p_block_lengths[I0]; // repeats
constexpr auto P_M1 = p_block_lengths[I2]; // waves
constexpr auto P_M2 = p_block_lengths[I4]; // xdl
        constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(make_tuple(
            I1, YDotYGrad_M_O::ThreadSliceLength_M, I1, YDotYGrad_M_O::ThreadSliceLength_O));

        constexpr auto y_thread_cluster_desc =
@@ -1617,23 +1618,28 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            block_sync_lds(); // wait for gemm1 LDS read

+           constexpr auto p_block_slice_lengths_m0_n0_m1_n1 =
+               typename Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1{};
            SubThreadBlock<BlockSize> p_thread_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
                                                             s_blockwise_gemm.GetWaveIdx()[I1]);

-           constexpr index_t num_vgrad_gemm_loop = MPerBlock / VGradGemmTile_N_O_M::Sum_M;
-           static_assert(sfc_p_m0_n0_m1_n1_m2_n2.GetNumOfAccess() == num_vgrad_gemm_loop, "");
+           constexpr index_t num_vgrad_gemm_loop = MPerBlock / Gemm2Params_N_O_M::Sum_M;
+           static_assert(vgrad_gemm_tile_p_block_slice_window_iterator.GetNumOfAccess() ==
+                             num_vgrad_gemm_loop,
+                         "");

-           vgrad_thread_buf.Clear();
-           // TODO: tune pipeline
+           // TODO: tune gemm2 pipeline
            // dV = P^T * dY
+           vgrad_thread_buf.Clear();
            static_for<0, num_vgrad_gemm_loop, 1>{}([&](auto vgrad_gemm_loop_idx) { // gemm dV
                // load VGrad Gemm B
-               ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1, ygrad_grid_buf);
+               vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
+                                                            ygrad_grid_buf);

                // load VGrad Gemm A
                const auto p_nd_idx =
-                   sfc_p_m0_n0_m1_n1_m2_n2.GetIndexTupleOfNumber(vgrad_gemm_loop_idx);
+                   vgrad_gemm_tile_p_block_slice_window_iterator.GetIndexTupleOfNumber(
+                       vgrad_gemm_loop_idx);

                constexpr auto mwave_range =
                    make_tuple(p_nd_idx[I2], p_nd_idx[I2] + p_block_slice_lengths_m0_n0_m1_n1[I2]);
                constexpr auto nwave_range =
@@ -1641,28 +1647,29 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                if(p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range))
                {
-                   p_thread_copy_vgpr_to_lds.Run(
-                       p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                   vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run(
+                       Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                        make_tuple(p_nd_idx[I0], p_nd_idx[I1], I0, I0, I0, I0, I0, I0),
                        s_slash_p_thread_buf,
-                       p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-                       p_block_buf);
+                       Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                       gemm2_a_block_buf);
                }

                // ygrad slice window is moved with MoveSrcSliceWindow() since it is dynamic buffer
                // p slice window is moved by loop index
-               ygrad_blockwise_copy.MoveSrcSliceWindow(ygrad_grid_desc_m0_o_m1,
-                                                       ygrad_block_slice_copy_step);
+               vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
+                   ygrad_grid_desc_m0_o_m1, Gemm2::b_block_slice_copy_step);

                block_sync_lds(); // sync before write
-               ygrad_blockwise_copy.RunWrite(ygrad_block_desc_m0_o_m1, ygrad_block_buf);
+               vgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
+                                                             gemm2_b_block_buf);

                block_sync_lds(); // sync before read
-               vgrad_blockwise_gemm.Run(p_block_buf, ygrad_block_buf, vgrad_thread_buf);
+               vgrad_blockwise_gemm.Run(gemm2_a_block_buf, gemm2_b_block_buf, vgrad_thread_buf);
            }); // end gemm dV

            // atomic_add dV
-           vgrad_thread_copy_vgpr_to_global.Run(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+           vgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
                                                 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                                 vgrad_thread_buf,
                                                 vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
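The dV store above runs with InMemoryDataOperationEnum::AtomicAdd because several passes along the reduced M dimension contribute partial sums to the same dV output tile, so blocks must accumulate into global memory rather than overwrite it. A toy illustration of accumulate-instead-of-overwrite, with std::atomic standing in for the hardware atomic add and invented sizes:

```cpp
#include <atomic>
#include <cstdio>

int main()
{
    const int num_m_tiles = 4;   // illustrative count of M tiles feeding one dV tile
    std::atomic<int> dv_elem{0}; // one dV output element living in "global memory"

    for(int m = 0; m < num_m_tiles; ++m)
        dv_elem.fetch_add(m + 1); // each pass adds its partial product, never overwrites

    std::printf("accumulated dV element = %d\n", dv_elem.load()); // 1+2+3+4 = 10
}
```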
@@ -1777,10 +1784,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            s_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
                k_grid_desc_k0_n_k1,
                s_gemm_tile_b_block_reset_copy_step); // rewind K and step N
-           ygrad_blockwise_copy.MoveSrcSliceWindow(ygrad_grid_desc_m0_o_m1,
-                                                   ygrad_block_reset_copy_step); // rewind M
+           vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
+               ygrad_grid_desc_m0_o_m1,
+               Gemm2::b_block_reset_copy_step); // rewind M
            vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-               vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, vgrad_block_slice_copy_step); // step N
+               vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
            pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
                ygrad_grid_desc_o0_m_o1, pgrad_gemm_tile_ygrad_block_reset_copy_step); // rewind O
            pgrad_gemm_tile_v_blockwise_copy.MoveSrcSliceWindow(
...