Commit 383211ef authored by Anthony Chang

rearrange gemm0/gemm1

parent d13c92bd
@@ -149,7 +149,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
__host__ __device__ static constexpr auto GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
const AccThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4& acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4)
{
// acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1
// acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to a_src_thread_desc_k0_m_k1
// n0_n1_n2_n3 -> k0
// m0_m1_m2 -> m
// n4 -> k1
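The body of this helper falls outside the visible hunk. A minimal sketch of the remap described by the comment above, modeled on the MakeSGradThreadDesc_N0_M_N1 helper that this commit removes further down; the variable names and Sequence indices are carried over from that helper as an assumption, not the actual body:

    // Sketch only, not the actual function body; it mirrors the removed
    // MakeSGradThreadDesc_N0_M_N1 helper shown later in this diff.
    constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
    constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
    constexpr auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
    constexpr auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
    constexpr auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
    constexpr auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
    constexpr auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
    constexpr auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
    return transform_tensor_descriptor(
        acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
        make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)), // -> k0
                   make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),     // -> m
                   make_pass_through_transform(n4)),                                 // -> k1
        make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));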
@@ -248,7 +248,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto Gemm1N = v_grid_desc_n0_o_n1.GetLength(I1);
// This assumption redues implemention complexity by categorizing 6 separate GEMMs into 3
// This assumption reduces implementation complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, so parts of the code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
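For reference, the six attention GEMMs behind this grouping, as spelled out by the per-GEMM comments later in this diff (the dK expression is inferred from the "dV / dK" grouping and standard attention backward algebra rather than stated explicitly):

    \begin{aligned}
    \text{type 1 (rcr):}\quad & S = Q K^{\mathsf T},      & dP &= dY\, V^{\mathsf T} \\
    \text{type 2 (rrr):}\quad & Y = P V,                  & dQ &= dS\, K \\
    \text{type 3 (crr):}\quad & dV = P^{\mathsf T} dY,    & dK &= dS^{\mathsf T} Q
    \end{aligned}

with $P = \operatorname{softmax}(S)$ applied rowwise and $dS = P \odot (dP - D)$, where $D_m = \sum_o Y_{mo}\, dY_{mo}$.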
@@ -355,7 +355,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
// P / dP Gemm (type 1 rcr)
// S / dP Gemm (type 1 rcr)
struct Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
@@ -485,7 +485,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
static constexpr auto AThreadSliceLength_M = Number<m0 * m1 * m2>{};
static constexpr auto AThreadSliceLength_K1 = Number<n4>{};
static constexpr auto acc_thread_desc_k0_m_k1 =
static constexpr auto a_src_thread_desc_k0_m_k1 =
GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{});
static constexpr auto a_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed(
@@ -500,7 +500,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
FloatGemmAcc,
DataType,
decltype(acc_thread_desc_k0_m_k1),
decltype(a_src_thread_desc_k0_m_k1),
decltype(a_thread_desc_k0_m_k1),
tensor_operation::element_wise::PassThrough,
AThreadSliceLengths_K0_M_K1,
@@ -574,6 +574,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
};
// dV / dK Gemm (type 3 crr)
// TODO ANT: refactor into Gemm2
template <index_t Sum_M_ = MPerXdl * 2>
struct VGradGemmTile_N_O_M_
{
@@ -652,17 +653,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
Number<4>{});
}
};
using VGradGemmTile_N_O_M = VGradGemmTile_N_O_M_<>; // tune later
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
static constexpr index_t SrcScalarPerVetor = 16 / sizeof(DataType);
static constexpr index_t SrcScalarPerVector = 16 / sizeof(DataType);
static constexpr auto ThreadClusterLength_O =
Number<BlockSliceLength_O_ / SrcScalarPerVetor>{};
Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
static constexpr auto ThreadSliceLength_O = Number<SrcScalarPerVetor>{};
static constexpr auto ThreadSliceLength_O = Number<SrcScalarPerVector>{};
static constexpr auto ThreadSliceLength_M =
Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};
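As a sanity check on this thread-to-data mapping, here is a standalone arithmetic sketch; the concrete values (2-byte elements, a 256-thread block, a 128 x 128 M x O block slice) are illustrative assumptions, not this kernel's fixed configuration:

    // Standalone check of the YDotYGrad_M_O_ thread mapping with assumed values.
    constexpr int kBytesPerElem   = 2;   // assume fp16 DataType
    constexpr int kBlockSize      = 256; // assume BlockSize_
    constexpr int kBlockSliceLenM = 128; // assume BlockSliceLength_M_
    constexpr int kBlockSliceLenO = 128; // assume BlockSliceLength_O_

    constexpr int kSrcScalarPerVector = 16 / kBytesPerElem;                                // 8
    constexpr int kThreadClusterLenO  = kBlockSliceLenO / kSrcScalarPerVector;             // 16
    constexpr int kThreadClusterLenM  = kBlockSize / kThreadClusterLenO;                   // 16
    constexpr int kThreadSliceLenO    = kSrcScalarPerVector;                               // 8
    constexpr int kThreadSliceLenM    = kBlockSliceLenM * kThreadClusterLenO / kBlockSize; // 8

    // 16 x 16 threads, each owning an 8 x 8 (M x O) slice, tile the whole
    // 128 x 128 block with one 8-wide vector load per thread row.
    static_assert(kThreadClusterLenM * kThreadClusterLenO == kBlockSize, "");
    static_assert(kThreadClusterLenM * kThreadSliceLenM == kBlockSliceLenM, "");
    static_assert(kThreadClusterLenO * kThreadSliceLenO == kBlockSliceLenO, "");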
@@ -683,8 +683,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
struct PGradGemmTile_M_N_O
{
// TODO ANT:
// Should have made all input tensors 2D and transform them into appropriate 3D form in
// kernel to make things more concise - if we can get the compiler to behave
// Make all input tensors 2D and transform them into appropriate 3D form in kernel to make
// things more concise
template <typename YGradGridDesc_M0_O_M1_>
__device__ static const auto
MakeYGradGridDesc_O0_M_O1(const YGradGridDesc_M0_O_M1_& ygrad_grid_desc_m0_o_m1)
@@ -758,31 +758,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
}
template <typename SGradThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4_>
__device__ static const auto
MakeSGradThreadDesc_N0_M_N1(const SGradThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4_&
sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4)
{
constexpr auto m0 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto n0 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto m1 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto n1 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto m2 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto n2 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto n3 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto n4 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto sgrad_thread_desc_n0_m_n1 = transform_tensor_descriptor(
sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
make_pass_through_transform(n4)),
make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return sgrad_thread_desc_n0_m_n1;
}
template <typename KGridDesc_K0_N_K1_>
__device__ static const auto
MakeKGridDesc_N0_K_N1(const KGridDesc_K0_N_K1_& k_grid_desc_k0_n_k1)
@@ -919,11 +894,26 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
//
// set up P / dP Gemm (type 1 rcr)
// set up S / dP Gemm (type 1 rcr)
//
// A matrix blockwise copy
auto a_blockwise_copy =
// Gemm0: LDS allocation for A and B: be careful of alignment
auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto gemm0_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
Gemm0::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
// Gemm0: gridwise GEMM pipeline
// Only supports LoopScheduler::Default
const auto gemm0_gridwise_gemm_pipeline =
GridwiseGemmPipeline_Selector<PipelineVer,
NumGemmKPrefetchStage,
LoopScheduler::Default>();
// S: A matrix blockwise copy
auto s_gemm_tile_q_blockwise_copy =
typename Gemm0::template ABlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
q_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
@@ -932,8 +922,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
// S: B matrix blockwise copy
auto s_gemm_tile_k_blockwise_copy =
typename Gemm0::template BBlockwiseCopy<decltype(k_grid_desc_k0_n_k1)>(
k_grid_desc_k0_n_k1,
make_multi_index(0, 0, 0), // will loop over GemmN dimension
@@ -942,76 +932,102 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// S: blockwise gemm
auto s_blockwise_gemm = typename Gemm0::BlockwiseGemm{}; // TransposeC
auto s_slash_p_thread_buf = s_blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
Gemm0::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
const auto a_block_reset_copy_step =
const auto s_gemm_tile_a_block_reset_copy_step =
make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
const auto b_block_reset_copy_step =
const auto s_gemm_tile_b_block_reset_copy_step =
make_multi_index(-k_grid_desc_k0_n_k1.GetLength(I0), NPerBlock, 0);
// gridwise GEMM pipeline
// Only supports LoopScheduler::Default
const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector<PipelineVer,
NumGemmKPrefetchStage,
LoopScheduler::Default>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2)) / KPerBlock);
// dP: transform input and output tensor descriptors
const auto ygrad_grid_desc_o0_m_o1 =
PGradGemmTile_M_N_O::MakeYGradGridDesc_O0_M_O1(ygrad_grid_desc_m0_o_m1);
const auto v_grid_desc_o0_n_o1 =
PGradGemmTile_M_N_O::MakeVGridDesc_O0_N_O1(v_grid_desc_n0_o_n1);
// dP: Gemm A position blockwise copy
auto pgrad_gemm_tile_ygrad_blockwise_copy =
typename Gemm0::template ABlockwiseCopy<decltype(ygrad_grid_desc_o0_m_o1)>(
ygrad_grid_desc_o0_m_o1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
tensor_operation::element_wise::PassThrough{},
Gemm0::a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// dP: Gemm B position blockwise copy
auto pgrad_gemm_tile_v_blockwise_copy =
typename Gemm0::template BBlockwiseCopy<decltype(v_grid_desc_o0_n_o1)>(
v_grid_desc_o0_n_o1,
make_multi_index(0, 0, 0), // will loop over GemmN dimension
tensor_operation::element_wise::PassThrough{},
Gemm0::b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// dP: blockwise gemm
// we need a separate blockwise gemm object because it needs its own thread buffer
auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
auto pgrad_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer();
const auto pgrad_gemm_tile_ygrad_block_reset_copy_step =
make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), 0, 0);
const auto pgrad_gemm_tile_v_block_reset_copy_step =
make_multi_index(-v_grid_desc_o0_n_o1.GetLength(I0), NPerBlock, 0);
const index_t num_o_block_main_loop = __builtin_amdgcn_readfirstlane(
(ygrad_grid_desc_o0_m_o1.GetLength(I0) * ygrad_grid_desc_o0_m_o1.GetLength(I2)) /
KPerBlock);
//
// set up Y / dQ Gemm (type 2 rrr)
//
// Note: Y is pre-calculated in the forward pass and loaded by the backward-pass kernel
using Gemm1 =
Gemm1<decltype(s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()),
decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;
// Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
constexpr auto acc_thread_desc_k0_m_k1 = Gemm1::acc_thread_desc_k0_m_k1;
// Gemm1: VGPR allocation for A and LDS allocation for B
auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize());
// A1 matrix in accumulator VGPR, dst of blockwise copy
constexpr auto a1_thread_desc_k0_m_k1 = Gemm1::a_thread_desc_k0_m_k1;
auto gemm1_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
Gemm1::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
// B1 matrix in LDS memory, dst of blockwise copy
constexpr auto b1_block_desc_bk0_n_bk1 = Gemm1::b_block_desc_bk0_n_bk1;
// dQ: transform input and output tensor descriptors
const auto k_grid_desc_n0_k_n1 =
QGradGemmTile_M_K_N::MakeKGridDesc_N0_K_N1(k_grid_desc_k0_n_k1);
auto qgrad_grid_desc_mblock_mperblock_kblock_kperblock =
QGradGemmTile_M_K_N::MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock(
q_grid_desc_k0_m_k1);
// A1 matrix blockwise copy
auto a1_blockwise_copy =
// dQ: Gemm A matrix blockwise copy
auto qgrad_gemm_tile_sgrad_blockwise_copy =
typename Gemm1::ABlockwiseCopy{tensor_operation::element_wise::PassThrough{}};
// B1 matrix blockwise copy
auto b1_blockwise_copy =
typename Gemm1::template BBlockwiseCopy<decltype(v_grid_desc_n0_o_n1)>(
v_grid_desc_n0_o_n1,
// dQ: Gemm B matrix blockwise copy
auto qgrad_gemm_tile_k_blockwise_copy =
typename Gemm1::template BBlockwiseCopy<decltype(k_grid_desc_n0_k_n1)>(
k_grid_desc_n0_k_n1,
make_multi_index(0, o_block_data_idx_on_grid, 0),
b1_element_op,
b1_block_desc_bk0_n_bk1,
Gemm1::b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
a1_thread_desc_k0_m_k1.GetElementSpaceSize());
auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr index_t Gemm1KPack = Gemm1::GemmKPack;
auto gemm1_blockwise_gemm =
// dQ: blockwise gemm
auto qgrad_blockwise_gemm =
typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0)}; // A_origin
auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
auto qgrad_thread_buf = qgrad_blockwise_gemm.GetCThreadBuffer();
//
// Blockwise softmax
@@ -1391,19 +1407,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
y_thread_data_on_block_idx;
// performs double duty for both y and ygrad
auto yygrad_threadwise_copy =
ThreadwiseTensorSliceTransfer_v2<DataType,
auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
DataType,
DataType,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1),
decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
Sequence<0, 1, 2, 3>,
3, // SrcVectorDim
YDotYGrad_M_O::SrcScalarPerVetor, // SrcScalarPerVector
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */,
true /* InvalidElementAsNaN */>(
y_grid_desc_mblock_mperblock_oblock_operblock, y_thread_data_on_grid_idx);
true /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
y_thread_data_on_grid_idx);
auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
@@ -1435,80 +1451,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4])}; // mperxdl
//
// set up dP Gemm (type 1 rcr)
//
// transform input and output tensor descriptors
const auto ygrad_grid_desc_o0_m_o1 =
PGradGemmTile_M_N_O::MakeYGradGridDesc_O0_M_O1(ygrad_grid_desc_m0_o_m1);
const auto v_grid_desc_o0_n_o1 =
PGradGemmTile_M_N_O::MakeVGridDesc_O0_N_O1(v_grid_desc_n0_o_n1);
// dP Gemm A position blockwise copy
auto pgrad_gemm_tile_ygrad_blockwise_copy =
typename Gemm0::template ABlockwiseCopy<decltype(ygrad_grid_desc_o0_m_o1)>(
ygrad_grid_desc_o0_m_o1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
tensor_operation::element_wise::PassThrough{},
Gemm0::a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// dP Gemm B position blockwise copy
auto pgrad_gemm_tile_v_blockwise_copy =
typename Gemm0::template BBlockwiseCopy<decltype(v_grid_desc_o0_n_o1)>(
v_grid_desc_o0_n_o1,
make_multi_index(0, 0, 0), // will loop over GemmN dimension
tensor_operation::element_wise::PassThrough{},
Gemm0::b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
auto pgrad_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer();
const auto pgrad_gemm_tile_ygrad_block_reset_copy_step =
make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), 0, 0);
const auto pgrad_gemm_tile_v_block_reset_copy_step =
make_multi_index(-v_grid_desc_o0_n_o1.GetLength(I0), NPerBlock, 0);
const index_t num_o_block_main_loop = __builtin_amdgcn_readfirstlane(
(ygrad_grid_desc_o0_m_o1.GetLength(I0) * ygrad_grid_desc_o0_m_o1.GetLength(I2)) /
KPerBlock);
auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());
//
// set up dQ Gemm (type 2 rrr)
//
// transform input and output tensor descriptors
const auto k_grid_desc_n0_k_n1 =
QGradGemmTile_M_K_N::MakeKGridDesc_N0_K_N1(k_grid_desc_k0_n_k1);
auto qgrad_grid_desc_mblock_mperblock_kblock_kperblock =
QGradGemmTile_M_K_N::MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock(
q_grid_desc_k0_m_k1);
// dQ Gemm A matrix blockwise copy
auto qgrad_gemm_tile_sgrad_blockwise_copy =
typename Gemm1::ABlockwiseCopy{tensor_operation::element_wise::PassThrough{}};
// dQ Gemm B matrix blockwise copy
auto qgrad_gemm_tile_k_blockwise_copy =
typename Gemm1::template BBlockwiseCopy<decltype(k_grid_desc_n0_k_n1)>(
k_grid_desc_n0_k_n1,
make_multi_index(0, o_block_data_idx_on_grid, 0),
b1_element_op,
b1_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto qgrad_blockwise_gemm =
typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0)}; // A_origin
auto qgrad_thread_buf = qgrad_blockwise_gemm.GetCThreadBuffer();
//
// calculate Y dot dY
//
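As a plain scalar reference for what the elided Y dot dY reduction computes (the function, types, and row-major layout here are illustrative assumptions; the kernel performs this blockwise using the thread buffers prepared above):

    // Reference only: rowwise dot product D[m] = sum_o Y[m][o] * dY[m][o],
    // the quantity later subtracted from dP when forming dS.
    void y_dot_ygrad_reference(const float* y, const float* ygrad, float* d, int M, int O)
    {
        for(int m = 0; m < M; ++m)
        {
            float acc = 0.f;
            for(int o = 0; o < O; ++o)
            {
                acc += y[m * O + o] * ygrad[m * O + o];
            }
            d[m] = acc;
        }
    }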
@@ -1586,18 +1531,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
continue;
}
// P = Q * K^T
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(q_grid_desc_k0_m_k1,
// S = Q * K^T
gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
q_grid_desc_k0_m_k1,
Gemm0::a_block_desc_ak0_m_ak1,
a_blockwise_copy,
s_gemm_tile_q_blockwise_copy,
q_grid_buf,
a_block_buf,
gemm0_a_block_buf,
Gemm0::a_block_slice_copy_step,
k_grid_desc_k0_n_k1,
Gemm0::b_block_desc_bk0_n_bk1,
b_blockwise_copy,
s_gemm_tile_k_blockwise_copy,
k_grid_buf,
b_block_buf,
gemm0_b_block_buf,
Gemm0::b_block_slice_copy_step,
s_blockwise_gemm,
s_slash_p_thread_buf,
@@ -1679,7 +1625,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
vgrad_thread_buf.Clear();
// TODO ANT: single buffer prefetch pipeline
// TODO: tune pipeline
// dV = P^T * dY
static_for<0, num_vgrad_gemm_loop, 1>{}([&](auto vgrad_gemm_loop_idx) { // gemm dV
// load VGrad Gemm B
ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1, ygrad_grid_buf);
@@ -1714,7 +1661,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
vgrad_blockwise_gemm.Run(p_block_buf, ygrad_block_buf, vgrad_thread_buf);
}); // end gemm dV
// atomic_add dV
vgrad_thread_copy_vgpr_to_global.Run(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
@@ -1723,26 +1669,27 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
vgrad_grid_buf);
// gemm dP
// assume size K == size O so HasMainKBlockLoop is the same
block_sync_lds();
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
// dP = dY * V^T
// assume size K == size O so HasMainKBlockLoop is the same
gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
ygrad_grid_desc_o0_m_o1,
Gemm0::a_block_desc_ak0_m_ak1, // reuse
pgrad_gemm_tile_ygrad_blockwise_copy,
ygrad_grid_buf,
a_block_buf, // reuse
gemm0_a_block_buf, // reuse
Gemm0::a_block_slice_copy_step, // reuse
v_grid_desc_o0_n_o1,
Gemm0::b_block_desc_bk0_n_bk1, // reuse
pgrad_gemm_tile_v_blockwise_copy,
v_grid_buf,
b_block_buf, // reuse
gemm0_b_block_buf, // reuse
Gemm0::b_block_slice_copy_step, // reuse
pgrad_blockwise_gemm,
pgrad_thread_buf,
num_o_block_main_loop);
// calculate dS from dP
// dS = P * (dP - Y_dot_dY)
auto& sgrad_thread_buf = pgrad_thread_buf;
constexpr auto pgrad_thread_tile_iterator =
pgrad_blockwise_gemm.MakeCThreadTileIterator();
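Why subtracting Y_dot_dY is the correct term here, as standard softmax backward algebra (this derivation is editorial, not taken from the source):

    dS_{mn} = P_{mn}\Big(dP_{mn} - \sum_k P_{mk}\, dP_{mk}\Big)
            = P_{mn}\big(dP_{mn} - D_m\big),
    \qquad
    D_m = \sum_k P_{mk}\, dP_{mk} = \sum_o Y_{mo}\, dY_{mo},

since $dP = dY\, V^{\mathsf T}$ and $Y = P V$; this is exactly the elementwise `P * (dP - Y_dot_dY)` update applied in the loop below.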
@@ -1760,6 +1707,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
});
// gemm dQ
// dQ = dS * K
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
@@ -1776,54 +1724,59 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
block_sync_lds(); // wait for previous LDS read
qgrad_gemm_tile_k_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
gemm1_b_block_buf);
// main body
if constexpr(num_gemm1_k_block_inner_loop > 1)
{
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
qgrad_gemm_tile_sgrad_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
qgrad_gemm_tile_sgrad_blockwise_copy.Run(Gemm1::a_src_thread_desc_k0_m_k1,
Gemm1::a_block_slice_copy_step * i,
sgrad_thread_buf,
a1_thread_desc_k0_m_k1,
Gemm1::a_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
gemm1_a_thread_buf);
qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);
block_sync_lds();
qgrad_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, qgrad_thread_buf);
qgrad_blockwise_gemm.Run(
gemm1_a_thread_buf, gemm1_b_block_buf, qgrad_thread_buf);
block_sync_lds();
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
k_grid_desc_n0_k_n1, Gemm1::b_block_slice_copy_step);
qgrad_gemm_tile_k_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1,
b1_block_buf);
qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
gemm1_b_block_buf);
});
}
// tail
{
qgrad_gemm_tile_sgrad_blockwise_copy.Run(
acc_thread_desc_k0_m_k1,
Gemm1::a_src_thread_desc_k0_m_k1,
Gemm1::a_block_slice_copy_step * Number<num_gemm1_k_block_inner_loop - 1>{},
sgrad_thread_buf,
a1_thread_desc_k0_m_k1,
Gemm1::a_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
gemm1_a_thread_buf);
block_sync_lds();
qgrad_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, qgrad_thread_buf);
qgrad_blockwise_gemm.Run(
gemm1_a_thread_buf, gemm1_b_block_buf, qgrad_thread_buf);
}
} // end gemm dQ
// move slice window
a_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_k0_m_k1,
a_block_reset_copy_step); // rewind K
b_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_k0_n_k1,
b_block_reset_copy_step); // rewind K and step N
s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_k0_m_k1,
s_gemm_tile_a_block_reset_copy_step); // rewind K
s_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
k_grid_desc_k0_n_k1,
s_gemm_tile_b_block_reset_copy_step); // rewind K and step N
ygrad_blockwise_copy.MoveSrcSliceWindow(ygrad_grid_desc_m0_o_m1,
ygrad_block_reset_copy_step); // rewind M
vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
@@ -1836,7 +1789,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
// TODO ANT:
// shuffle dQ and write
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
@@ -1848,12 +1800,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
qgrad_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
gemm1_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
qgrad_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
@@ -1893,7 +1845,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
gemm1_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
qgrad_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];