Commit c798cff9 authored by Anthony Chang

refactor gemm0

parent db7f7bed
@@ -641,7 +641,10 @@ int run(int argc, char* argv[])
         std::cout << "Checking qgrad:\n";
         pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
-                                     qgrad_gs_ms_ks_host_result.mData);
+                                     qgrad_gs_ms_ks_host_result.mData,
+                                     "error",
+                                     1e-2,
+                                     1e-2);
         std::cout << "Checking kgrad:\n";
         pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
                                      kgrad_gs_ns_ks_host_result.mData);
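The hunk above loosens the qgrad comparison from check_err's defaults to an explicit relative and absolute tolerance of 1e-2 each (with "error" as the message string), a reasonable bound for half-precision attention gradients accumulated in fp32. For reference, a minimal sketch of the combined-tolerance rule such a check is assumed to apply (hypothetical helper, not CK's actual check_err implementation):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Assumed rule: an element passes if |out - ref| <= atol + rtol * |ref|.
    bool check_close(const std::vector<float>& out,
                     const std::vector<float>& ref,
                     double rtol = 1e-2,
                     double atol = 1e-2)
    {
        if(out.size() != ref.size())
            return false;
        for(std::size_t i = 0; i < out.size(); ++i)
        {
            const double err = std::abs(double(out[i]) - double(ref[i]));
            if(err > atol + rtol * std::abs(double(ref[i])))
                return false; // element i exceeds the combined tolerance
        }
        return true;
    }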
...
@@ -446,26 +446,124 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
 
-    // PGrad Gemm has the same layout as P = Q * K^T Gemm (A row-major B col-major)
-    struct PGradGemmTile_M_N_O
-    {
-        private:
-        static constexpr auto ygrad_block_desc_o0_m_o1 =
-            GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
-        static constexpr auto v_block_desc_o0_n_o1 =
-            GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
+    struct SharedMemTrait
+    {
+        // LDS allocation for A and B: be careful of alignment
+        static constexpr auto a_block_desc_ak0_m_ak1 =
+            GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
+        static constexpr auto b_block_desc_bk0_n_bk1 =
+            GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
+        static constexpr auto b1_block_desc_bk0_n_bk1 =
+            GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
+        static constexpr auto p_block_desc_m0_n_m1 =
+            VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
+        static constexpr auto ygrad_block_desc_m0_o_m1 =
+            VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();
+
+        static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};
+
+        static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
+            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
+        static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
+            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
+        static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
+            b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
+        static constexpr auto p_block_space_size_aligned =
+            math::integer_least_multiple(p_block_desc_m0_n_m1.GetElementSpaceSize(), max_lds_align);
+        static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
+            ygrad_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);
+
+        static constexpr auto a_block_space_offset     = 0;
+        static constexpr auto b_block_space_offset     = a_block_space_size_aligned.value;
+        static constexpr auto b1_block_space_offset    = 0;
+        static constexpr auto p_block_space_offset     = 0;
+        static constexpr auto ygrad_block_space_offset = p_block_space_size_aligned.value;
+
+        // LDS allocation for reduction
+        static constexpr index_t reduction_space_size_aligned =
+            math::integer_least_multiple(BlockSize, max_lds_align);
+
+        static constexpr auto reduction_space_offset = 0;
+
+        // LDS allocation for C shuffle in LDS
+        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
+            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
+        static constexpr auto c_block_space_size =
+            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
+    };
+
+    // P / dP Gemm (type 1 rcr)
+    struct Gemm0
+    {
+        private:
+        static constexpr auto a_block_desc = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
+        static constexpr auto b_block_desc = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
+
+        public:
+        template <typename GridDesc_K0_M_K1>
+        using ABlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
+            ThisThreadBlock,
+            AElementwiseOperation,
+            tensor_operation::element_wise::PassThrough,
+            InMemoryDataOperationEnum::Set,
+            Sequence<AK0, MPerBlock, AK1>,
+            ABlockTransferThreadClusterLengths_AK0_M_AK1,
+            ABlockTransferThreadClusterArrangeOrder,
+            DataType,
+            DataType,
+            GridDesc_K0_M_K1,
+            decltype(a_block_desc),
+            ABlockTransferSrcAccessOrder,
+            Sequence<1, 0, 2>,
+            ABlockTransferSrcVectorDim,
+            2,
+            ABlockTransferSrcScalarPerVector,
+            ABlockTransferDstScalarPerVector_AK1,
+            1,
+            1,
+            true, // SrcResetCoord
+            true, // DstResetCoord
+            NumGemmKPrefetchStage>;
+
+        template <typename GridDesc_K0_N_K1>
+        using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
+            ThisThreadBlock,
+            BElementwiseOperation,
+            tensor_operation::element_wise::PassThrough,
+            InMemoryDataOperationEnum::Set,
+            Sequence<BK0, NPerBlock, BK1>,
+            BBlockTransferThreadClusterLengths_BK0_N_BK1,
+            BBlockTransferThreadClusterArrangeOrder,
+            DataType,
+            DataType,
+            GridDesc_K0_N_K1,
+            decltype(b_block_desc),
+            BBlockTransferSrcAccessOrder,
+            Sequence<1, 0, 2>,
+            BBlockTransferSrcVectorDim,
+            2,
+            BBlockTransferSrcScalarPerVector,
+            BBlockTransferDstScalarPerVector_BK1,
+            1,
+            1,
+            true, // SrcResetCoord
+            true, // DstResetCoord
+            NumGemmKPrefetchStage>;
 
         static constexpr index_t KPack = math::max(
             math::lcm(AK1, BK1), MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
 
-        public:
-        using BlockwiseGemm = BlockwiseGemmXdlops_v2<
-            BlockSize,
-            DataType,
-            FloatGemmAcc,
-            decltype(ygrad_block_desc_o0_m_o1),
-            decltype(v_block_desc_o0_n_o1),
-            decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(ygrad_block_desc_o0_m_o1)),
-            decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(v_block_desc_o0_n_o1)),
-            MPerBlock,
-            NPerBlock,
-            KPerBlock,
+        // Blockwise gemm with transposed XDL output
+        using BlockwiseGemm =
+            BlockwiseGemmXdlops_v2<BlockSize,
+                                   DataType,
+                                   FloatGemmAcc,
+                                   decltype(a_block_desc),
+                                   decltype(b_block_desc),
+                                   decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(
+                                       a_block_desc)),
+                                   decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(
+                                       b_block_desc)),
+                                   MPerBlock,
+                                   NPerBlock,
+                                   KPerBlock,
@@ -474,8 +572,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                    MXdlPerWave,
                                    NXdlPerWave,
                                    KPack,
-            true>;
+                                   true>; // TransposeC
+
+        static constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
+        static constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
+    };
+
+    // PGrad Gemm has the same layout as P = Q * K^T Gemm (A row-major B col-major)
+    struct PGradGemmTile_M_N_O
+    {
+        // TODO ANT:
         // Should have made all input tensors 2D and transform them into appropriate 3D form in
         // kernel to make things more concise - if we can get the compiler to behave
         template <typename YGradGridDesc_M0_O_M1_>
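The new Gemm0 struct above is the heart of this refactor: the long ThreadGroupTensorSliceTransfer_v4r1 parameter lists are written once as templated aliases with only the source grid descriptor left free, so later hunks can instantiate the same copy configuration for both the P = Q * K^T gemm and the dP gemm. A stripped-down sketch of the pattern with hypothetical types (the real CK signatures carry many more parameters):

    // Hypothetical illustration: fix most template parameters in one place,
    // leave the source descriptor as the only free parameter.
    template <typename SrcDesc, int ThreadsPerBlock, int VectorSize>
    struct TileCopy
    {
        explicit TileCopy(const SrcDesc& desc) : desc_(desc) {}
        SrcDesc desc_;
    };

    struct Gemm0Config
    {
        static constexpr int kThreads = 256; // assumed block size
        static constexpr int kVector  = 8;   // assumed vector width

        template <typename SrcDesc>
        using ABlockwiseCopy = TileCopy<SrcDesc, kThreads, kVector>;
    };

    struct QDesc {};     // stands in for decltype(q_grid_desc_k0_m_k1)
    struct YGradDesc {}; // stands in for decltype(ygrad_grid_desc_o0_m_o1)

    // Both gemms share one configuration; only the descriptor differs:
    auto q_copy     = Gemm0Config::ABlockwiseCopy<QDesc>(QDesc{});
    auto ygrad_copy = Gemm0Config::ABlockwiseCopy<YGradDesc>(YGradDesc{});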
@@ -597,52 +703,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         }
     };
 
-    struct SharedMemTrait
-    {
-        // LDS allocation for A and B: be careful of alignment
-        static constexpr auto a_block_desc_ak0_m_ak1 =
-            GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
-        static constexpr auto b_block_desc_bk0_n_bk1 =
-            GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
-        static constexpr auto b1_block_desc_bk0_n_bk1 =
-            GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
-        static constexpr auto p_block_desc_m0_n_m1 =
-            VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
-        static constexpr auto ygrad_block_desc_m0_o_m1 =
-            VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();
-
-        static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};
-
-        static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
-            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
-        static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
-            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
-        static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
-            b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
-        static constexpr auto p_block_space_size_aligned =
-            math::integer_least_multiple(p_block_desc_m0_n_m1.GetElementSpaceSize(), max_lds_align);
-        static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
-            ygrad_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);
-
-        static constexpr auto a_block_space_offset     = 0;
-        static constexpr auto b_block_space_offset     = a_block_space_size_aligned.value;
-        static constexpr auto b1_block_space_offset    = 0;
-        static constexpr auto p_block_space_offset     = 0;
-        static constexpr auto ygrad_block_space_offset = p_block_space_size_aligned.value;
-
-        // LDS allocation for reduction
-        static constexpr index_t reduction_space_size_aligned =
-            math::integer_least_multiple(BlockSize, max_lds_align);
-
-        static constexpr auto reduction_space_offset = 0;
-
-        // LDS allocation for C shuffle in LDS
-        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
-            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
-        static constexpr auto c_block_space_size =
-            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
-    };
-
     template <bool HasMainKBlockLoop,
               typename Block2CTileMap,
               typename C0MatrixMask,
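The block deleted above is the same SharedMemTrait that the first hunk hoisted to the top of the struct; only its location changes. Its layout packs the A and B tiles back to back (b_block_space_offset begins where the aligned A tile ends), while the b1, P, and dY tiles restart at offset 0, on the assumption that they are live in later pipeline phases and may reuse the same LDS. A small sketch of the round-up-and-pack arithmetic, assuming integer_least_multiple rounds a size up to a multiple of the alignment and a 2-byte DataType (so max_lds_align is 8 elements):

    #include <cstddef>

    // Assumed behavior of math::integer_least_multiple.
    constexpr std::size_t integer_least_multiple(std::size_t size, std::size_t align)
    {
        return ((size + align - 1) / align) * align;
    }

    constexpr std::size_t max_lds_align = 8; // 16 bytes / sizeof(fp16)
    constexpr std::size_t a_size = integer_least_multiple(4100, max_lds_align); // 4104
    constexpr std::size_t a_off  = 0;
    constexpr std::size_t b_off  = a_size; // B starts right after the aligned A tile
    constexpr std::size_t p_off  = 0;      // later-phase tile reuses the same LDS
    static_assert(b_off % max_lds_align == 0, "B tile stays aligned");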
@@ -703,47 +763,26 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             return;
         }
 
-        // HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR
+        // HACK: this force m/o_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
-        const index_t gemm1_n_block_data_idx_on_grid =
+        const index_t o_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
 
+        //
+        // set up P / dP Gemm (type 1 rcr)
+        //
+
         // A matrix in LDS memory, dst of blockwise copy
         constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
 
         // B matrix in LDS memory, dst of blockwise copy
         constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
 
-        //
-        // set up Gemm0
-        //
-
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-                                                AElementwiseOperation,
-                                                tensor_operation::element_wise::PassThrough,
-                                                InMemoryDataOperationEnum::Set,
-                                                Sequence<AK0, MPerBlock, AK1>,
-                                                ABlockTransferThreadClusterLengths_AK0_M_AK1,
-                                                ABlockTransferThreadClusterArrangeOrder,
-                                                DataType,
-                                                DataType,
-                                                decltype(q_grid_desc_k0_m_k1),
-                                                decltype(a_block_desc_ak0_m_ak1),
-                                                ABlockTransferSrcAccessOrder,
-                                                Sequence<1, 0, 2>,
-                                                ABlockTransferSrcVectorDim,
-                                                2,
-                                                ABlockTransferSrcScalarPerVector,
-                                                ABlockTransferDstScalarPerVector_AK1,
-                                                1,
-                                                1,
-                                                true, // SrcResetCoord
-                                                true, // DstResetCoord
-                                                NumGemmKPrefetchStage>(
+            typename Gemm0::template ABlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
                 q_grid_desc_k0_m_k1,
                 make_multi_index(0, m_block_data_idx_on_grid, 0),
                 a_element_op,
@@ -753,28 +792,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
 
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-                                                BElementwiseOperation,
-                                                tensor_operation::element_wise::PassThrough,
-                                                InMemoryDataOperationEnum::Set,
-                                                Sequence<BK0, NPerBlock, BK1>,
-                                                BBlockTransferThreadClusterLengths_BK0_N_BK1,
-                                                BBlockTransferThreadClusterArrangeOrder,
-                                                DataType,
-                                                DataType,
-                                                decltype(k_grid_desc_k0_n_k1),
-                                                decltype(b_block_desc_bk0_n_bk1),
-                                                BBlockTransferSrcAccessOrder,
-                                                Sequence<1, 0, 2>,
-                                                BBlockTransferSrcVectorDim,
-                                                2,
-                                                BBlockTransferSrcScalarPerVector,
-                                                BBlockTransferDstScalarPerVector_BK1,
-                                                1,
-                                                1,
-                                                true, // SrcResetCoord
-                                                true, // DstResetCoord
-                                                NumGemmKPrefetchStage>(
+            typename Gemm0::template BBlockwiseCopy<decltype(k_grid_desc_k0_n_k1)>(
                 k_grid_desc_k0_n_k1,
                 make_multi_index(0, 0, 0), // will loop over GemmN dimension
                 b_element_op,
@@ -782,35 +800,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 make_multi_index(0, 0, 0),
                 tensor_operation::element_wise::PassThrough{});
 
-        // Fused Gemm+Gemm pipeline
-        // for n in N0:
-        //   for k in K0:
-        //     acc[m][n] += A[m][k] * B0[k][n]
-        //   acc1[m][o] += acc[m][n] * B1[n][o]
-
-        // sanity check
-        constexpr index_t KPack = math::max(
-            math::lcm(AK1, BK1), MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
-
-        auto blockwise_gemm = BlockwiseGemmXdlops_v2<
-            BlockSize,
-            DataType,
-            FloatGemmAcc,
-            decltype(a_block_desc_ak0_m_ak1),
-            decltype(b_block_desc_bk0_n_bk1),
-            decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
-            decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
-            MPerBlock,
-            NPerBlock,
-            KPerBlock,
-            MPerXdl,
-            NPerXdl,
-            MXdlPerWave,
-            NXdlPerWave,
-            KPack,
-            true>{}; // TransposeC
-
-        auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer();
+        auto s_blockwise_gemm = typename Gemm0::BlockwiseGemm{}; // TransposeC
+
+        auto s_slash_p_thread_buf = s_blockwise_gemm.GetCThreadBuffer();
 
         // LDS allocation for A and B: be careful of alignment
         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
@@ -821,8 +813,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             static_cast<DataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
+        constexpr auto a_block_slice_copy_step = Gemm0::a_block_slice_copy_step;
+        constexpr auto b_block_slice_copy_step = Gemm0::b_block_slice_copy_step;
+
         const auto a_block_reset_copy_step =
             make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
         const auto b_block_reset_copy_step =
@@ -839,12 +832,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             KPerBlock);
 
         //
-        // set up O / dQ Gemm
+        // set up Y / dQ Gemm (type 2 rrr)
         //
 
         // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
         constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
-            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+            s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
 
         constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
         constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
@@ -873,7 +866,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         // A1 matrix in AccVGPR
         // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
         constexpr auto AccN3 =
-            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6);
+            s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6);
 
         constexpr auto A1ThreadSlice_K0_M_K1 =
             make_tuple(Number<Gemm1KPerBlock / n4 / AccN3>{}, Number<m0 * m1 * m2>{}, Number<n4>{});
@@ -889,47 +882,47 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
 
         // A1 matrix blockwise copy
-        // auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
-        //     FloatGemmAcc,
-        //     DataType,
-        //     decltype(acc_thread_desc_k0_m_k1),
-        //     decltype(a1_thread_desc_k0_m_k1),
-        //     tensor_operation::element_wise::PassThrough,
-        //     Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
-        //     Sequence<1, 0, 2>,
-        //     2,
-        //     n4>{tensor_operation::element_wise::PassThrough{}};
+        auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
+            FloatGemmAcc,
+            DataType,
+            decltype(acc_thread_desc_k0_m_k1),
+            decltype(a1_thread_desc_k0_m_k1),
+            tensor_operation::element_wise::PassThrough,
+            Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
+            Sequence<1, 0, 2>,
+            2,
+            n4>{tensor_operation::element_wise::PassThrough{}};
 
         // B1 matrix blockwise copy
-        // auto b1_blockwise_copy =
-        //     ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-        //                                         BElementwiseOperation,
-        //                                         tensor_operation::element_wise::PassThrough,
-        //                                         InMemoryDataOperationEnum::Set,
-        //                                         Sequence<B1K0, Gemm1NPerBlock, B1K1>,
-        //                                         B1BlockTransferThreadClusterLengths_BK0_N_BK1,
-        //                                         B1BlockTransferThreadClusterArrangeOrder,
-        //                                         DataType,
-        //                                         DataType,
-        //                                         decltype(v_grid_desc_n0_o_n1),
-        //                                         decltype(b1_block_desc_bk0_n_bk1),
-        //                                         B1BlockTransferSrcAccessOrder,
-        //                                         Sequence<1, 0, 2>,
-        //                                         B1BlockTransferSrcVectorDim,
-        //                                         2,
-        //                                         B1BlockTransferSrcScalarPerVector,
-        //                                         B1BlockTransferDstScalarPerVector_BK1,
-        //                                         1,
-        //                                         1,
-        //                                         B1ThreadTransferSrcResetCoordinateAfterRun,
-        //                                         true, // DstResetCoord
-        //                                         NumGemmKPrefetchStage>(
-        //         v_grid_desc_n0_o_n1,
-        //         make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
-        //         b1_element_op,
-        //         b1_block_desc_bk0_n_bk1,
-        //         make_multi_index(0, 0, 0),
-        //         tensor_operation::element_wise::PassThrough{});
+        auto b1_blockwise_copy =
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+                                                BElementwiseOperation,
+                                                tensor_operation::element_wise::PassThrough,
+                                                InMemoryDataOperationEnum::Set,
+                                                Sequence<B1K0, Gemm1NPerBlock, B1K1>,
+                                                B1BlockTransferThreadClusterLengths_BK0_N_BK1,
+                                                B1BlockTransferThreadClusterArrangeOrder,
+                                                DataType,
+                                                DataType,
+                                                decltype(v_grid_desc_n0_o_n1),
+                                                decltype(b1_block_desc_bk0_n_bk1),
+                                                B1BlockTransferSrcAccessOrder,
+                                                Sequence<1, 0, 2>,
+                                                B1BlockTransferSrcVectorDim,
+                                                2,
+                                                B1BlockTransferSrcScalarPerVector,
+                                                B1BlockTransferDstScalarPerVector_BK1,
+                                                1,
+                                                1,
+                                                B1ThreadTransferSrcResetCoordinateAfterRun,
+                                                true, // DstResetCoord
+                                                NumGemmKPrefetchStage>(
+                v_grid_desc_n0_o_n1,
+                make_multi_index(0, o_block_data_idx_on_grid, 0),
+                b1_element_op,
+                b1_block_desc_bk0_n_bk1,
+                make_multi_index(0, 0, 0),
+                tensor_operation::element_wise::PassThrough{});
 
         auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
             a1_thread_desc_k0_m_k1.GetElementSpaceSize());
@@ -984,8 +977,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         // get acc0 8D thread cluster
         constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 =
-            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() /
-            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+            s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() /
+            s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
         constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I0);
         constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I1);
         constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I2);
@@ -1021,23 +1014,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                decltype(thread_cluster_desc_m_n),
                                decltype(thread_slice_desc_m_n)>{};
 
-        const index_t num_gemm1_k_block_outer_loop =
-            k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
-        constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
-
-        // Initialize C
-        StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, acc1_thread_buf.Size(), true>
-            c_thread_buf;
-        c_thread_buf.Clear();
-
-        // Initialize running sum and max of exponentiating row vectors
-        using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType;
-        SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new;
-        running_sum     = 0;
-        running_sum_new = 0;
-        running_max     = NumericLimits<FloatGemmAcc>::Lowest();
-        running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
-
         auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl =
             MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(lse_grid_desc_m);
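The block deleted above is the forward pass's online-softmax state: the running row maximum and running row sum that a fused forward kernel carries across key blocks. The backward kernel can drop it because the forward pass already saved the log-sum-exp per row, which this code loads through lse_grid_desc and later feeds to RunWithPreCalcStats. The identity relied on, in the kernel's notation:

    \mathrm{LSE}_m = \log \sum_n e^{S_{mn}}, \qquad
    P_{mn} = \operatorname{softmax}(S)_{mn} = e^{S_{mn} - \mathrm{LSE}_m}

so once LSE is known, P is recomputed in a single pass with no running max or sum.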
@@ -1047,7 +1023,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         auto lse_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatLSE>(
             lse_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());
 
-        auto acc0_thread_origin = blockwise_gemm.CalculateCThreadOriginDataIndex8D(
+        auto acc0_thread_origin = s_blockwise_gemm.CalculateCThreadOriginDataIndex8D(
             Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});
 
         auto lse_thread_copy_global_to_vgpr =
@@ -1068,19 +1044,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                              acc0_thread_origin[I4])}; // mperxdl
 
         //
-        // dV
+        // set up dV / dK Gemm (type 3 crr)
         //
 
         // P vgpr to lds: writes vgprs of a subgroup to LDS and transform into m0_n_m1
        // m0, n0 are m/n repeat per wave
        // m1, n1 are number of waves
         constexpr auto p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
-            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+            s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
 
         constexpr auto p_block_desc_m0_n_m1 = VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
 
         constexpr auto p_block_lengths =
-            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+            s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
         constexpr auto P_M0 = p_block_lengths[I0]; // repeats
         constexpr auto P_N0 = p_block_lengths[I1];
         constexpr auto P_M1 = p_block_lengths[I2]; // waves
@@ -1113,7 +1089,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         const auto p_thread_origin_nd_idx_on_block = [&]() {
             const auto c_thread_mtx_on_block =
-                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
+                s_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
 
             const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
             const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
@@ -1184,17 +1160,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                  p_thread_origin_nd_idx_on_block[I7]),
                 tensor_operation::element_wise::PassThrough{}};
 
-        // Sequence<p_block_slice_lengths_m0_n0_m1_n1[I0],
-        //          p_block_slice_lengths_m0_n0_m1_n1[I1],
-        //          I1,
-        //          I1,
-        //          I1,
-        //          P_N2,
-        //          I1,
-        //          P_N4>{}
-        //     .foo();
-        // 1, 4, 1, 1, 1, 4, 1, 4
-
         constexpr auto sfc_p_m0_n0_m1_n1_m2_n2 =
             SpaceFillingCurve<Sequence<P_M0, P_N0, P_M1, P_N1>,
                               Sequence<0, 1, 2, 3>,
@@ -1229,7 +1194,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 true,
                 1>(ygrad_grid_desc_m0_o_m1,
                    make_multi_index(m_block_data_idx_on_grid / VGradGemmTile_N_O_M::YGrad_M1,
-                                    gemm1_n_block_data_idx_on_grid,
+                                    o_block_data_idx_on_grid,
                                     0),
                    tensor_operation::element_wise::PassThrough{},
                    ygrad_block_desc_m0_o_m1,
@@ -1292,7 +1257,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         const index_t n_thread_data_idx_on_grid = vgrad_thread_mtx_on_block_n_o[I0];
 
         const index_t o_thread_data_idx_on_grid =
-            vgrad_thread_mtx_on_block_n_o[I1] + gemm1_n_block_data_idx_on_grid;
+            vgrad_thread_mtx_on_block_n_o[I1] + o_block_data_idx_on_grid;
 
         const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
             make_tuple(make_merge_transform(make_tuple(VGrad_N0, VGrad_N1, VGrad_N2))),
@@ -1375,8 +1340,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
 #endif
 
         //
-        // dP
+        // set up Y dot dY
         //
+
         constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(make_tuple(
             I1, YDotYGrad_M_O::ThreadSliceLength_M, I1, YDotYGrad_M_O::ThreadSliceLength_O));
 
         constexpr auto y_thread_cluster_desc =
@@ -1424,35 +1390,35 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                         // per-thread LSE data and y_dot_ygrad is
                                                         // tiled the same way
 
+        auto y_dot_ygrad_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
+            FloatGemmAcc,
+            FloatGemmAcc,
+            decltype(y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl),
+            decltype(y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl),
+            Sequence<1, m0, m1, m2>,
+            Sequence<0, 1, 2, 3>,
+            3,
+            m2,
+            1,
+            false>{y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
+                   make_multi_index(I0,                       // mblock
+                                    acc0_thread_origin[I0],   // mrepeat
+                                    acc0_thread_origin[I2],   // mwave
+                                    acc0_thread_origin[I4])}; // mperxdl
+
+        //
+        // set up dP Gemm (type 1 rcr)
+        //
+
+        // transform input and output tensor descriptors
         const auto ygrad_grid_desc_o0_m_o1 =
             PGradGemmTile_M_N_O::MakeYGradGridDesc_O0_M_O1(ygrad_grid_desc_m0_o_m1);
         const auto v_grid_desc_o0_n_o1 =
             PGradGemmTile_M_N_O::MakeVGridDesc_O0_N_O1(v_grid_desc_n0_o_n1);
 
-        // dP Gemm A matrix blockwise copy
+        // dP Gemm A position blockwise copy
         auto pgrad_gemm_tile_ygrad_blockwise_copy =
-            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-                                                tensor_operation::element_wise::PassThrough,
-                                                tensor_operation::element_wise::PassThrough,
-                                                InMemoryDataOperationEnum::Set,
-                                                Sequence<AK0, MPerBlock, AK1>,
-                                                ABlockTransferThreadClusterLengths_AK0_M_AK1,
-                                                ABlockTransferThreadClusterArrangeOrder,
-                                                DataType,
-                                                DataType,
-                                                decltype(ygrad_grid_desc_o0_m_o1),
-                                                decltype(a_block_desc_ak0_m_ak1), // reuse block buf
-                                                ABlockTransferSrcAccessOrder,
-                                                Sequence<1, 0, 2>,
-                                                ABlockTransferSrcVectorDim,
-                                                2,
-                                                ABlockTransferSrcScalarPerVector,
-                                                ABlockTransferDstScalarPerVector_AK1,
-                                                1,
-                                                1,
-                                                true, // SrcResetCoord
-                                                true, // DstResetCoord
-                                                NumGemmKPrefetchStage>(
+            typename Gemm0::template ABlockwiseCopy<decltype(ygrad_grid_desc_o0_m_o1)>(
                 ygrad_grid_desc_o0_m_o1,
                 make_multi_index(0, m_block_data_idx_on_grid, 0),
                 tensor_operation::element_wise::PassThrough{},
@@ -1460,30 +1426,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 make_multi_index(0, 0, 0),
                 tensor_operation::element_wise::PassThrough{});
 
-        // dP Gemm B matrix blockwise copy
+        // dP Gemm B position blockwise copy
         auto pgrad_gemm_tile_v_blockwise_copy =
-            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-                                                tensor_operation::element_wise::PassThrough,
-                                                tensor_operation::element_wise::PassThrough,
-                                                InMemoryDataOperationEnum::Set,
-                                                Sequence<BK0, NPerBlock, BK1>,
-                                                BBlockTransferThreadClusterLengths_BK0_N_BK1,
-                                                BBlockTransferThreadClusterArrangeOrder,
-                                                DataType,
-                                                DataType,
-                                                decltype(v_grid_desc_o0_n_o1),
-                                                decltype(b_block_desc_bk0_n_bk1), // reuse block buf
-                                                BBlockTransferSrcAccessOrder,
-                                                Sequence<1, 0, 2>,
-                                                BBlockTransferSrcVectorDim,
-                                                2,
-                                                BBlockTransferSrcScalarPerVector,
-                                                BBlockTransferDstScalarPerVector_BK1,
-                                                1,
-                                                1,
-                                                true, // SrcResetCoord
-                                                true, // DstResetCoord
-                                                NumGemmKPrefetchStage>(
+            typename Gemm0::template BBlockwiseCopy<decltype(v_grid_desc_o0_n_o1)>(
                 v_grid_desc_o0_n_o1,
                 make_multi_index(0, 0, 0), // will loop over GemmN dimension
                 tensor_operation::element_wise::PassThrough{},
@@ -1491,7 +1436,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 make_multi_index(0, 0, 0),
                 tensor_operation::element_wise::PassThrough{});
 
-        auto pgrad_blockwise_gemm = typename PGradGemmTile_M_N_O::BlockwiseGemm{};
+        auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
         auto pgrad_thread_buf     = pgrad_blockwise_gemm.GetCThreadBuffer();
 
         const auto pgrad_gemm_tile_ygrad_block_reset_copy_step =
             make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), 0, 0);
@@ -1502,25 +1447,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             (ygrad_grid_desc_o0_m_o1.GetLength(I0) * ygrad_grid_desc_o0_m_o1.GetLength(I2)) /
             KPerBlock);
 
-        auto y_dot_ygrad_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
-            FloatGemmAcc,
-            FloatGemmAcc,
-            decltype(y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl),
-            decltype(y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl),
-            Sequence<1, m0, m1, m2>,
-            Sequence<0, 1, 2, 3>,
-            3,
-            m2,
-            1,
-            false>{y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
-                   make_multi_index(I0,                       // mblock
-                                    acc0_thread_origin[I0],   // mrepeat
-                                    acc0_thread_origin[I2],   // mwave
-                                    acc0_thread_origin[I4])}; // mperxdl
-
-        // clear accum buffers
-        y_dot_ygrad_thread_accum_buf.Clear();
-        y_dot_ygrad_block_accum_buf.Clear();
-
 #if 0
         if(hipThreadIdx_x == 0 && hipBlockIdx_x == 0) printf("lds before accum\n");
         if(hipBlockIdx_x == 0)
@@ -1533,7 +1459,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());
 
         //
-        // dQ
+        // set up dQ Gemm (type 2 rrr)
         //
         const auto k_grid_desc_n0_k_n1 =
             QGradGemmTile_M_K_N::MakeKGridDesc_N0_K_N1(k_grid_desc_k0_n_k1);
@@ -1578,7 +1504,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                 true, // DstResetCoord
                                                 NumGemmKPrefetchStage>(
                 k_grid_desc_n0_k_n1,
-                make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
+                make_multi_index(0, o_block_data_idx_on_grid, 0),
                 b1_element_op,
                 b1_block_desc_bk0_n_bk1,
                 make_multi_index(0, 0, 0),
@@ -1607,9 +1533,15 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 make_tuple(0, 0, 0, 0)}; // A_origin
 
         auto qgrad_thread_buf = qgrad_blockwise_gemm.GetCThreadBuffer();
 
         //
         // calculate y dot ygrad
         //
+
+        // clear accum buffers
+        y_dot_ygrad_thread_accum_buf.Clear();
+        y_dot_ygrad_block_accum_buf.Clear();
+
         index_t oblock_idx = 0;
         do
         {
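The do/while loop entered here walks the O dimension block by block, accumulating the row-wise dot product of the forward output Y with its gradient dY into the LDS buffers cleared just above; the result is consumed later when dS is formed. Functionally the reduction is just the following (scalar reference sketch, not the tiled kernel code):

    #include <cstddef>
    #include <vector>

    // D[m] = sum_o Y[m][o] * dY[m][o]; the kernel tiles this over O blocks
    // and reduces across threads through the LDS accumulation buffer.
    std::vector<float> y_dot_ygrad(const std::vector<std::vector<float>>& y,
                                   const std::vector<std::vector<float>>& ygrad)
    {
        std::vector<float> d(y.size(), 0.f);
        for(std::size_t m = 0; m < y.size(); ++m)
            for(std::size_t o = 0; o < y[m].size(); ++o)
                d[m] += y[m][o] * ygrad[m][o];
        return d;
    }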
@@ -1673,7 +1605,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             }
 #endif
 
-        // distribute y_dot_ygrad to threads; LDS accum buffer can be safely accessed after barrier
+        // distribute y_dot_ygrad to threads; LDS accum buffer can be safely reused after barrier
         y_dot_ygrad_thread_copy_lds_to_vgpr.Run(
             y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
             y_dot_ygrad_block_accum_buf,
@@ -1681,7 +1613,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             make_tuple(I0, I0, I0, I0),
             y_dot_ygrad_thread_buf);
 
-#if 1
+#if 0
         if(hipBlockIdx_x < 4 && hipThreadIdx_x % 32 < 4)
         {
             printf("bid %zd tid %zd, y_m0_m1_o0_o1 = %d, %d, %d, %d\n",
@@ -1700,6 +1632,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             make_tuple(I0, I0, I0, I0),
             lse_thread_buf);
 
+        const index_t num_gemm1_k_block_outer_loop =
+            k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
+        constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
+
         // Initialize dQ
         qgrad_thread_buf.Clear();
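The trip counts moved here are plain ratios of tile sizes: the outer loop visits each NPerBlock-wide key block across the full key length N, and the inner loop splits one such slab into Gemm1KPerBlock chunks that feed the second gemm. With assumed sizes (not taken from this commit):

    // Assumed example sizes for illustration only.
    constexpr int N              = 4096; // total key sequence length
    constexpr int NPerBlock      = 128;  // key tile handled per outer iteration
    constexpr int Gemm1KPerBlock = 32;   // K-chunk of the second gemm

    constexpr int num_outer = N / NPerBlock;              // 32 outer iterations
    constexpr int num_inner = NPerBlock / Gemm1KPerBlock; // 4 inner iterations
    static_assert(num_outer * NPerBlock == N, "tiles must divide N evenly");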
@@ -1715,7 +1650,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             {
                 continue;
             }
-            // gemm0
+            // P = Q * K^T
             gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(q_grid_desc_k0_m_k1,
                                                                    a_block_desc_ak0_m_ak1,
                                                                    a_blockwise_copy,
@@ -1728,8 +1663,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                                    k_grid_buf,
                                                                    b_block_buf,
                                                                    b_block_slice_copy_step,
-                                                                   blockwise_gemm,
-                                                                   acc_thread_buf,
+                                                                   s_blockwise_gemm,
+                                                                   s_slash_p_thread_buf,
                                                                    num_k_block_main_loop);
 
             // do MNK padding or upper triangular masking
@@ -1737,11 +1672,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             {
                 // 8d thread_desc in thread scope
                 constexpr auto c_thread_lengths =
-                    blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+                    s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
 
                 // 8d block_desc in block scope
                 constexpr auto c_block_lengths =
-                    blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+                    s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
 
                 constexpr auto M0 = c_block_lengths[I0];
                 constexpr auto N0 = c_block_lengths[I1];
@@ -1760,7 +1695,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                     typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
                     false>; // SnakeCurved
 
-                auto acc0_thread_origin = blockwise_gemm.CalculateCThreadOriginDataIndex8D(
+                auto acc0_thread_origin = s_blockwise_gemm.CalculateCThreadOriginDataIndex8D(
                     Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});
 
                 constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
@@ -1779,11 +1714,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                     auto n_global = n_local + n_block_data_idx_on_grid;
                     if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
                     {
-                        acc_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
+                        s_slash_p_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
                     }
                     else
                     {
-                        acc_element_op(acc_thread_buf(i), acc_thread_buf[i]);
+                        acc_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
                     }
                 });
             }
@@ -1800,32 +1735,32 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             {
                 printf("tid %zd, S[0:3] = %f, %f, %f, %f\n",
                        hipThreadIdx_x,
-                       acc_thread_buf[I0],
-                       acc_thread_buf[I1],
-                       acc_thread_buf[I2],
-                       acc_thread_buf[I3]);
+                       s_slash_p_thread_buf[I0],
+                       s_slash_p_thread_buf[I1],
+                       s_slash_p_thread_buf[I2],
+                       s_slash_p_thread_buf[I3]);
             }
 #endif
 
             // P_i: = softmax(S_i:)
-            blockwise_softmax.RunWithPreCalcStats(acc_thread_buf, lse_thread_buf);
+            blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
 
 #if 0
             if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
             {
                 printf("tid %zd, P[0:3] = %f, %f, %f, %f\n",
                        hipThreadIdx_x,
-                       acc_thread_buf[I0],
-                       acc_thread_buf[I1],
-                       acc_thread_buf[I2],
-                       acc_thread_buf[I3]);
+                       s_slash_p_thread_buf[I0],
+                       s_slash_p_thread_buf[I1],
+                       s_slash_p_thread_buf[I2],
+                       s_slash_p_thread_buf[I3]);
             }
 #endif
 
             block_sync_lds(); // wait for gemm1 LDS read
 
-            SubThreadBlock<BlockSize> p_thread_copy_subgroup(blockwise_gemm.GetWaveIdx()[I0],
-                                                             blockwise_gemm.GetWaveIdx()[I1]);
+            SubThreadBlock<BlockSize> p_thread_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
+                                                             s_blockwise_gemm.GetWaveIdx()[I1]);
 
             constexpr index_t num_vgrad_gemm_loop = MPerBlock / VGradGemmTile_N_O_M::Sum_M;
             static_assert(sfc_p_m0_n0_m1_n1_m2_n2.GetNumOfAccess() == num_vgrad_gemm_loop, "");
@@ -1864,7 +1799,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 p_thread_copy_vgpr_to_lds.Run(
                     p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                     make_tuple(p_nd_idx[I0], p_nd_idx[I1], I0, I0, I0, I0, I0, I0),
-                    acc_thread_buf,
+                    s_slash_p_thread_buf,
                     p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                     p_block_buf);
             }
@@ -1936,7 +1871,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                                    pgrad_blockwise_gemm,
                                                                    pgrad_thread_buf,
                                                                    num_o_block_main_loop);
-#if 1
+#if 0
             if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
             {
                 printf("outer j loop idx %d, tid %zd, dP[0:3] = %f, %f, %f, %f\n",
@@ -1963,10 +1898,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                     pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I1];
                 // dS and P has same thread buf layout
                 sgrad_thread_buf(i) =
-                    acc_thread_buf[i] * (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
+                    s_slash_p_thread_buf[i] * (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
             });
-#if 1
+#if 0
             if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
             {
                 printf("outer j loop idx %d, tid %zd, dS[0:3] = %f, %f, %f, %f\n",
@@ -2016,7 +1951,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                     a1_thread_desc_k0_m_k1,
                     make_tuple(I0, I0, I0),
                     a1_thread_buf);
-#if 1
+#if 0
                 if(hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
                 {
                     printf("inner j loop idx %d, tid %zd, dS downcast[0:3] = %f, %f, %f, %f\n",
@@ -2079,7 +2014,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         // TODO ANT:
         // shuffle dQ and write
-#if 1
+#if 0
         if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
         {
             printf("tid %zd, dQ[0:3] = %f, %f, %f, %f\n",
...