Commit 383211ef authored by Anthony Chang

rearrange gemm0/gemm1

parent d13c92bd
@@ -149,7 +149,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
__host__ __device__ static constexpr auto GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
const AccThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4& acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4)
{
- // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1
+ // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to a_src_thread_desc_k0_m_k1
// n0_n1_n2_n3 -> k0
// m0_m1_m2 -> m
// n4 -> k1
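As a reading aid for the mapping described in the comment above (a restatement, not part of the change): the 8-D accumulator thread descriptor is viewed as a 3-D (k0, m, k1) source descriptor for the second GEMM via

```math
k_0 = \mathrm{merge}(n_0, n_1, n_2, n_3), \qquad m = \mathrm{merge}(m_0, m_1, m_2), \qquad k_1 = n_4
```

The same merge/pass-through pattern appears in the MakeSGradThreadDesc_N0_M_N1 helper removed later in this diff.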
@@ -248,7 +248,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto Gemm1N = v_grid_desc_n0_o_n1.GetLength(I1);
- // This assumption redues implemention complexity by categorizing 6 separate GEMMs into 3
+ // This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
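The categorization referenced in the comment above can be summarized as follows; this restates the comments in this file together with the standard attention backward formulas and is not new behavior introduced by the commit. With P = softmax(S) and Y = P·V from the forward pass:

```math
\text{type 1 (rcr): } S = Q K^T,\; dP = dY\, V^T \qquad
\text{type 2 (rrr): } Y = P V,\; dQ = dS\, K \qquad
\text{type 3 (crr): } dV = P^T dY,\; dK = dS^T Q
```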
@@ -355,7 +355,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
- // P / dP Gemm (type 1 rcr)
+ // S / dP Gemm (type 1 rcr)
struct Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
@@ -485,7 +485,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
static constexpr auto AThreadSliceLength_M = Number<m0 * m1 * m2>{};
static constexpr auto AThreadSliceLength_K1 = Number<n4>{};
- static constexpr auto acc_thread_desc_k0_m_k1 =
+ static constexpr auto a_src_thread_desc_k0_m_k1 =
GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{});
static constexpr auto a_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed(
@@ -500,7 +500,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
FloatGemmAcc,
DataType,
- decltype(acc_thread_desc_k0_m_k1),
+ decltype(a_src_thread_desc_k0_m_k1),
decltype(a_thread_desc_k0_m_k1),
tensor_operation::element_wise::PassThrough,
AThreadSliceLengths_K0_M_K1,
@@ -574,6 +574,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
};
// dV / dK Gemm (type 3 crr)
+ // TODO ANT: refactor into Gemm2
template <index_t Sum_M_ = MPerXdl * 2>
struct VGradGemmTile_N_O_M_
{
@@ -652,17 +653,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
Number<4>{});
}
};
using VGradGemmTile_N_O_M = VGradGemmTile_N_O_M_<>; // tune later
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
- static constexpr index_t SrcScalarPerVetor = 16 / sizeof(DataType);
+ static constexpr index_t SrcScalarPerVector = 16 / sizeof(DataType);
static constexpr auto ThreadClusterLength_O =
- Number<BlockSliceLength_O_ / SrcScalarPerVetor>{};
+ Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
- static constexpr auto ThreadSliceLength_O = Number<SrcScalarPerVetor>{};
+ static constexpr auto ThreadSliceLength_O = Number<SrcScalarPerVector>{};
static constexpr auto ThreadSliceLength_M =
Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};
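To make the constants in YDotYGrad_M_O_ above concrete, here is a worked example with assumed numbers (a 2-byte DataType, BlockSize_ = 256, BlockSliceLength_M_ = BlockSliceLength_O_ = 128; none of these values are taken from this commit):

```math
16 / 2 = 8 \text{ elements per vector},\quad 128 / 8 = 16 \text{ threads along O},\quad 256 / 16 = 16 \text{ threads along M},\quad 128 \cdot 16 / 256 = 8 \text{ rows per thread}
```

So the 16×16 thread cluster covers the 128×128 block slice with an 8×8 tile per thread: 8 rows along M and one 8-element vector load along O.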
@@ -683,8 +683,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
struct PGradGemmTile_M_N_O
{
// TODO ANT:
- // Should have made all input tensors 2D and transform them into appropriate 3D form in
- // kernel to make things more concise - if we can get the compiler to behave
+ // Make all input tensors 2D and transform them into appropriate 3D form in kernel to make
+ // things more concise
template <typename YGradGridDesc_M0_O_M1_>
__device__ static const auto
MakeYGradGridDesc_O0_M_O1(const YGradGridDesc_M0_O_M1_& ygrad_grid_desc_m0_o_m1)
@@ -758,31 +758,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
}
- template <typename SGradThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4_>
- __device__ static const auto
- MakeSGradThreadDesc_N0_M_N1(const SGradThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4_&
- sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4)
- {
- constexpr auto m0 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
- constexpr auto n0 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
- constexpr auto m1 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
- constexpr auto n1 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
- constexpr auto m2 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
- constexpr auto n2 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
- constexpr auto n3 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
- constexpr auto n4 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
- constexpr auto sgrad_thread_desc_n0_m_n1 = transform_tensor_descriptor(
- sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
- make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
- make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
- make_pass_through_transform(n4)),
- make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
- make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
- return sgrad_thread_desc_n0_m_n1;
- }
template <typename KGridDesc_K0_N_K1_>
__device__ static const auto
MakeKGridDesc_N0_K_N1(const KGridDesc_K0_N_K1_& k_grid_desc_k0_n_k1)
@@ -919,11 +894,26 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
//
- // set up P / dP Gemm (type 1 rcr)
+ // set up S / dP Gemm (type 1 rcr)
//
- // A matrix blockwise copy
- auto a_blockwise_copy =
+ // Gemm0: LDS allocation for A and B: be careful of alignment
+ auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+ static_cast<DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
+ Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
+ auto gemm0_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+ static_cast<DataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
+ Gemm0::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
+ // Gemm0: gridwise GEMM pipeline
+ // Only supports LoopScheduler::Default
+ const auto gemm0_gridwise_gemm_pipeline =
+ GridwiseGemmPipeline_Selector<PipelineVer,
+ NumGemmKPrefetchStage,
+ LoopScheduler::Default>();
+ // S: A matrix blockwise copy
+ auto s_gemm_tile_q_blockwise_copy =
typename Gemm0::template ABlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
q_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
@@ -932,8 +922,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
- // B matrix blockwise copy
- auto b_blockwise_copy =
+ // S: B matrix blockwise copy
+ auto s_gemm_tile_k_blockwise_copy =
typename Gemm0::template BBlockwiseCopy<decltype(k_grid_desc_k0_n_k1)>(
k_grid_desc_k0_n_k1,
make_multi_index(0, 0, 0), // will loop over GemmN dimension
@@ -942,76 +932,102 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
+ // S: blockwise gemm
auto s_blockwise_gemm = typename Gemm0::BlockwiseGemm{}; // TransposeC
auto s_slash_p_thread_buf = s_blockwise_gemm.GetCThreadBuffer();
- // LDS allocation for A and B: be careful of alignment
- auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
- static_cast<DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
- Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
- auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
- static_cast<DataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
- Gemm0::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
- const auto a_block_reset_copy_step =
+ const auto s_gemm_tile_a_block_reset_copy_step =
make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
- const auto b_block_reset_copy_step =
+ const auto s_gemm_tile_b_block_reset_copy_step =
make_multi_index(-k_grid_desc_k0_n_k1.GetLength(I0), NPerBlock, 0);
- // gridwise GEMM pipeline
- // Only supports LoopScheduler::Default
- const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector<PipelineVer,
- NumGemmKPrefetchStage,
- LoopScheduler::Default>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2)) / KPerBlock);
+ // dP: transform input and output tensor descriptors
+ const auto ygrad_grid_desc_o0_m_o1 =
+ PGradGemmTile_M_N_O::MakeYGradGridDesc_O0_M_O1(ygrad_grid_desc_m0_o_m1);
+ const auto v_grid_desc_o0_n_o1 =
+ PGradGemmTile_M_N_O::MakeVGridDesc_O0_N_O1(v_grid_desc_n0_o_n1);
+ // dP: Gemm A position blockwise copy
+ auto pgrad_gemm_tile_ygrad_blockwise_copy =
+ typename Gemm0::template ABlockwiseCopy<decltype(ygrad_grid_desc_o0_m_o1)>(
+ ygrad_grid_desc_o0_m_o1,
+ make_multi_index(0, m_block_data_idx_on_grid, 0),
+ tensor_operation::element_wise::PassThrough{},
+ Gemm0::a_block_desc_ak0_m_ak1,
+ make_multi_index(0, 0, 0),
+ tensor_operation::element_wise::PassThrough{});
+ // dP: Gemm B position blockwise copy
+ auto pgrad_gemm_tile_v_blockwise_copy =
+ typename Gemm0::template BBlockwiseCopy<decltype(v_grid_desc_o0_n_o1)>(
+ v_grid_desc_o0_n_o1,
+ make_multi_index(0, 0, 0), // will loop over GemmN dimension
+ tensor_operation::element_wise::PassThrough{},
+ Gemm0::b_block_desc_bk0_n_bk1,
+ make_multi_index(0, 0, 0),
+ tensor_operation::element_wise::PassThrough{});
+ // dP: blockwise gemm
+ // we need separate blockwise gemm object because we need separate thread buffer
+ auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
+ auto pgrad_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer();
+ const auto pgrad_gemm_tile_ygrad_block_reset_copy_step =
+ make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), 0, 0);
+ const auto pgrad_gemm_tile_v_block_reset_copy_step =
+ make_multi_index(-v_grid_desc_o0_n_o1.GetLength(I0), NPerBlock, 0);
+ const index_t num_o_block_main_loop = __builtin_amdgcn_readfirstlane(
+ (ygrad_grid_desc_o0_m_o1.GetLength(I0) * ygrad_grid_desc_o0_m_o1.GetLength(I2)) /
+ KPerBlock);
//
// set up Y / dQ Gemm (type 2 rrr)
//
+ // Note: Y is pre-calculated in forward pass and loaded to backward pass kernel
using Gemm1 =
Gemm1<decltype(s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()),
decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;
- // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
- constexpr auto acc_thread_desc_k0_m_k1 = Gemm1::acc_thread_desc_k0_m_k1;
- // A1 matrix in accumulator VGPR, dst of blockwise copy
- constexpr auto a1_thread_desc_k0_m_k1 = Gemm1::a_thread_desc_k0_m_k1;
- // B1 matrix in LDS memory, dst of blockwise copy
- constexpr auto b1_block_desc_bk0_n_bk1 = Gemm1::b_block_desc_bk0_n_bk1;
+ // Gemm1: VGPR allocation for A and LDS allocation for B
+ auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
+ Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize());
+ auto gemm1_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+ static_cast<DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
+ Gemm1::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
+ // dQ: transform input and output tensor descriptors
+ const auto k_grid_desc_n0_k_n1 =
+ QGradGemmTile_M_K_N::MakeKGridDesc_N0_K_N1(k_grid_desc_k0_n_k1);
+ auto qgrad_grid_desc_mblock_mperblock_kblock_kperblock =
+ QGradGemmTile_M_K_N::MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock(
+ q_grid_desc_k0_m_k1);
- // A1 matrix blockwise copy
- auto a1_blockwise_copy =
+ // dQ: Gemm A matrix blockwise copy
+ auto qgrad_gemm_tile_sgrad_blockwise_copy =
typename Gemm1::ABlockwiseCopy{tensor_operation::element_wise::PassThrough{}};
- // B1 matrix blockwise copy
- auto b1_blockwise_copy =
- typename Gemm1::template BBlockwiseCopy<decltype(v_grid_desc_n0_o_n1)>(
- v_grid_desc_n0_o_n1,
+ // dQ: Gemm B matrix blockwise copy
+ auto qgrad_gemm_tile_k_blockwise_copy =
+ typename Gemm1::template BBlockwiseCopy<decltype(k_grid_desc_n0_k_n1)>(
+ k_grid_desc_n0_k_n1,
make_multi_index(0, o_block_data_idx_on_grid, 0),
b1_element_op,
- b1_block_desc_bk0_n_bk1,
+ Gemm1::b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
- auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
- a1_thread_desc_k0_m_k1.GetElementSpaceSize());
- auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
- static_cast<DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
- b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
- constexpr index_t Gemm1KPack = Gemm1::GemmKPack;
- auto gemm1_blockwise_gemm =
+ // dQ: blockwise gemm
+ auto qgrad_blockwise_gemm =
typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0)}; // A_origin
- auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
+ auto qgrad_thread_buf = qgrad_blockwise_gemm.GetCThreadBuffer();
//
// Blockwise softmax
@@ -1391,19 +1407,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
y_thread_data_on_block_idx;
// performs double duty for both y and ygrad
- auto yygrad_threadwise_copy =
- ThreadwiseTensorSliceTransfer_v2<DataType,
- DataType,
- YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
- decltype(y_thread_desc_m0_m1_o0_o1),
- decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
- Sequence<0, 1, 2, 3>,
- 3, // SrcVectorDim
- YDotYGrad_M_O::SrcScalarPerVetor, // SrcScalarPerVector
- 1, // SrcScalarStrideInVector
- true /* ResetCoordAfterRun */,
- true /* InvalidElementAsNaN */>(
- y_grid_desc_mblock_mperblock_oblock_operblock, y_thread_data_on_grid_idx);
+ auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
+ DataType,
+ DataType,
+ YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
+ decltype(y_thread_desc_m0_m1_o0_o1),
+ decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
+ Sequence<0, 1, 2, 3>,
+ 3, // SrcVectorDim
+ YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
+ 1, // SrcScalarStrideInVector
+ true /* ResetCoordAfterRun */,
+ true /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
+ y_thread_data_on_grid_idx);
auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
@@ -1435,80 +1451,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4])}; // mperxdl
- //
- // set up dP Gemm (type 1 rcr)
- //
- // transform input and output tensor descriptors
- const auto ygrad_grid_desc_o0_m_o1 =
- PGradGemmTile_M_N_O::MakeYGradGridDesc_O0_M_O1(ygrad_grid_desc_m0_o_m1);
- const auto v_grid_desc_o0_n_o1 =
- PGradGemmTile_M_N_O::MakeVGridDesc_O0_N_O1(v_grid_desc_n0_o_n1);
- // dP Gemm A position blockwise copy
- auto pgrad_gemm_tile_ygrad_blockwise_copy =
- typename Gemm0::template ABlockwiseCopy<decltype(ygrad_grid_desc_o0_m_o1)>(
- ygrad_grid_desc_o0_m_o1,
- make_multi_index(0, m_block_data_idx_on_grid, 0),
- tensor_operation::element_wise::PassThrough{},
- Gemm0::a_block_desc_ak0_m_ak1,
- make_multi_index(0, 0, 0),
- tensor_operation::element_wise::PassThrough{});
- // dP Gemm B position blockwise copy
- auto pgrad_gemm_tile_v_blockwise_copy =
- typename Gemm0::template BBlockwiseCopy<decltype(v_grid_desc_o0_n_o1)>(
- v_grid_desc_o0_n_o1,
- make_multi_index(0, 0, 0), // will loop over GemmN dimension
- tensor_operation::element_wise::PassThrough{},
- Gemm0::b_block_desc_bk0_n_bk1,
- make_multi_index(0, 0, 0),
- tensor_operation::element_wise::PassThrough{});
- auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
- auto pgrad_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer();
- const auto pgrad_gemm_tile_ygrad_block_reset_copy_step =
- make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), 0, 0);
- const auto pgrad_gemm_tile_v_block_reset_copy_step =
- make_multi_index(-v_grid_desc_o0_n_o1.GetLength(I0), NPerBlock, 0);
- const index_t num_o_block_main_loop = __builtin_amdgcn_readfirstlane(
- (ygrad_grid_desc_o0_m_o1.GetLength(I0) * ygrad_grid_desc_o0_m_o1.GetLength(I2)) /
- KPerBlock);
auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());
- //
- // set up dQ Gemm (type 2 rrr)
- //
- // transform input and output tensor descriptors
- const auto k_grid_desc_n0_k_n1 =
- QGradGemmTile_M_K_N::MakeKGridDesc_N0_K_N1(k_grid_desc_k0_n_k1);
- auto qgrad_grid_desc_mblock_mperblock_kblock_kperblock =
- QGradGemmTile_M_K_N::MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock(
- q_grid_desc_k0_m_k1);
- // dQ Gemm A matrix blockwise copy
- auto qgrad_gemm_tile_sgrad_blockwise_copy =
- typename Gemm1::ABlockwiseCopy{tensor_operation::element_wise::PassThrough{}};
- // dQ Gemm B matrix blockwise copy
- auto qgrad_gemm_tile_k_blockwise_copy =
- typename Gemm1::template BBlockwiseCopy<decltype(k_grid_desc_n0_k_n1)>(
- k_grid_desc_n0_k_n1,
- make_multi_index(0, o_block_data_idx_on_grid, 0),
- b1_element_op,
- b1_block_desc_bk0_n_bk1,
- make_multi_index(0, 0, 0),
- tensor_operation::element_wise::PassThrough{});
- auto qgrad_blockwise_gemm =
- typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0)}; // A_origin
- auto qgrad_thread_buf = qgrad_blockwise_gemm.GetCThreadBuffer();
//
// calculate Y dot dY
//
@@ -1586,22 +1531,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
continue;
}
- // P = Q * K^T
- gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(q_grid_desc_k0_m_k1,
- Gemm0::a_block_desc_ak0_m_ak1,
- a_blockwise_copy,
- q_grid_buf,
- a_block_buf,
- Gemm0::a_block_slice_copy_step,
- k_grid_desc_k0_n_k1,
- Gemm0::b_block_desc_bk0_n_bk1,
- b_blockwise_copy,
- k_grid_buf,
- b_block_buf,
- Gemm0::b_block_slice_copy_step,
- s_blockwise_gemm,
- s_slash_p_thread_buf,
- num_k_block_main_loop);
+ // S = Q * K^T
+ gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
+ q_grid_desc_k0_m_k1,
+ Gemm0::a_block_desc_ak0_m_ak1,
+ s_gemm_tile_q_blockwise_copy,
+ q_grid_buf,
+ gemm0_a_block_buf,
+ Gemm0::a_block_slice_copy_step,
+ k_grid_desc_k0_n_k1,
+ Gemm0::b_block_desc_bk0_n_bk1,
+ s_gemm_tile_k_blockwise_copy,
+ k_grid_buf,
+ gemm0_b_block_buf,
+ Gemm0::b_block_slice_copy_step,
+ s_blockwise_gemm,
+ s_slash_p_thread_buf,
+ num_k_block_main_loop);
// do MNK padding or upper triangular masking
if constexpr(MaskOutUpperTriangle || PadN)
@@ -1679,7 +1625,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
vgrad_thread_buf.Clear();
- // TODO ANT: single buffer prefetch pipeline
+ // TODO: tune pipeline
+ // dV = P^T * dY
static_for<0, num_vgrad_gemm_loop, 1>{}([&](auto vgrad_gemm_loop_idx) { // gemm dV
// load VGrad Gemm B
ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1, ygrad_grid_buf);
@@ -1714,7 +1661,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
vgrad_blockwise_gemm.Run(p_block_buf, ygrad_block_buf, vgrad_thread_buf);
}); // end gemm dV
// atomic_add dV
vgrad_thread_copy_vgpr_to_global.Run(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
@@ -1723,26 +1669,27 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
vgrad_grid_buf);
// gemm dP
- // assume size K == size O so HasMainKBlockLoop is the same
block_sync_lds();
- gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
+ // dP = dY * V^T
+ // assume size K == size O so HasMainKBlockLoop is the same
+ gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
ygrad_grid_desc_o0_m_o1,
Gemm0::a_block_desc_ak0_m_ak1, // reuse
pgrad_gemm_tile_ygrad_blockwise_copy,
ygrad_grid_buf,
- a_block_buf, // reuse
+ gemm0_a_block_buf, // reuse
Gemm0::a_block_slice_copy_step, // reuse
v_grid_desc_o0_n_o1,
Gemm0::b_block_desc_bk0_n_bk1, // reuse
pgrad_gemm_tile_v_blockwise_copy,
v_grid_buf,
- b_block_buf, // reuse
+ gemm0_b_block_buf, // reuse
Gemm0::b_block_slice_copy_step, // reuse
pgrad_blockwise_gemm,
pgrad_thread_buf,
num_o_block_main_loop);
- // calculate dS from dP
+ // dS = P * (dP - Y_dot_dY)
auto& sgrad_thread_buf = pgrad_thread_buf;
constexpr auto pgrad_thread_tile_iterator =
pgrad_blockwise_gemm.MakeCThreadTileIterator();
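For context on the `dS = P * (dP - Y_dot_dY)` comment above, this is the usual row-wise softmax backward identity (a sketch for orientation, not code from this diff). With P = softmax(S) applied row-wise and dP = dY·V^T:

```math
dS_{mn} = P_{mn}\left(dP_{mn} - D_m\right), \qquad D_m = \sum_o Y_{mo}\, dY_{mo} = \sum_n P_{mn}\, dP_{mn}
```

This is why the kernel computes the row-wise Y·dY once per row block ("calculate Y dot dY") and subtracts it from each dP tile instead of re-reducing P·dP for every N tile.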
@@ -1760,6 +1707,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
});
// gemm dQ
+ // dQ = dS * K
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
@@ -1776,54 +1724,59 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
block_sync_lds(); // wait for previous LDS read
- qgrad_gemm_tile_k_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
+ qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
+ gemm1_b_block_buf);
// main body
if constexpr(num_gemm1_k_block_inner_loop > 1)
{
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
- qgrad_gemm_tile_sgrad_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
+ qgrad_gemm_tile_sgrad_blockwise_copy.Run(Gemm1::a_src_thread_desc_k0_m_k1,
Gemm1::a_block_slice_copy_step * i,
sgrad_thread_buf,
- a1_thread_desc_k0_m_k1,
+ Gemm1::a_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
- a1_thread_buf);
+ gemm1_a_thread_buf);
qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);
block_sync_lds();
- qgrad_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, qgrad_thread_buf);
+ qgrad_blockwise_gemm.Run(
+ gemm1_a_thread_buf, gemm1_b_block_buf, qgrad_thread_buf);
block_sync_lds();
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
k_grid_desc_n0_k_n1, Gemm1::b_block_slice_copy_step);
- qgrad_gemm_tile_k_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1,
- b1_block_buf);
+ qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
+ gemm1_b_block_buf);
});
}
// tail
{
qgrad_gemm_tile_sgrad_blockwise_copy.Run(
- acc_thread_desc_k0_m_k1,
+ Gemm1::a_src_thread_desc_k0_m_k1,
Gemm1::a_block_slice_copy_step * Number<num_gemm1_k_block_inner_loop - 1>{},
sgrad_thread_buf,
- a1_thread_desc_k0_m_k1,
+ Gemm1::a_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
- a1_thread_buf);
+ gemm1_a_thread_buf);
block_sync_lds();
- qgrad_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, qgrad_thread_buf);
+ qgrad_blockwise_gemm.Run(
+ gemm1_a_thread_buf, gemm1_b_block_buf, qgrad_thread_buf);
}
} // end gemm dQ
// move slice window
- a_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_k0_m_k1,
- a_block_reset_copy_step); // rewind K
- b_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_k0_n_k1,
- b_block_reset_copy_step); // rewind K and step N
+ s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
+ q_grid_desc_k0_m_k1,
+ s_gemm_tile_a_block_reset_copy_step); // rewind K
+ s_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
+ k_grid_desc_k0_n_k1,
+ s_gemm_tile_b_block_reset_copy_step); // rewind K and step N
ygrad_blockwise_copy.MoveSrcSliceWindow(ygrad_grid_desc_m0_o_m1,
ygrad_block_reset_copy_step); // rewind M
vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
@@ -1836,7 +1789,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
- // TODO ANT:
// shuffle dQ and write
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
@@ -1848,12 +1800,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
- gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+ qgrad_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
- gemm1_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+ qgrad_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
@@ -1893,7 +1845,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
- gemm1_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
+ qgrad_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];