Commit c798cff9 authored by Anthony Chang

refactor gemm0

parent db7f7bed
@@ -641,7 +641,10 @@ int run(int argc, char* argv[])
std::cout << "Checking qgrad:\n";
pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
- qgrad_gs_ms_ks_host_result.mData);
+ qgrad_gs_ms_ks_host_result.mData,
+ "error",
+ 1e-2,
+ 1e-2);
std::cout << "Checking kgrad:\n";
pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
kgrad_gs_ns_ks_host_result.mData);
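
The qgrad check now passes a message label and explicit tolerances to `ck::utils::check_err`, while the kgrad call below keeps the defaults. For reference, a minimal sketch of the mixed absolute/relative criterion such a (rtol, atol) pair usually expresses; the exact semantics of `check_err`'s parameters may differ:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative only: one common reading of (rtol, atol) = (1e-2, 1e-2).
// ck::utils::check_err may weight, count, or report errors differently.
bool almost_equal(const std::vector<float>& out, const std::vector<float>& ref,
                  float rtol = 1e-2f, float atol = 1e-2f)
{
    if(out.size() != ref.size())
        return false;
    for(std::size_t i = 0; i < out.size(); ++i)
    {
        // accept when |out - ref| <= atol + rtol * |ref|
        if(std::fabs(out[i] - ref[i]) > atol + rtol * std::fabs(ref[i]))
            return false;
    }
    return true;
}
```
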
......
@@ -446,36 +446,142 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
- // PGrad Gemm has the same layout as P = Q * K^T Gemm (A row-major B col-major)
- struct PGradGemmTile_M_N_O
+ struct SharedMemTrait
{
- private:
- static constexpr auto ygrad_block_desc_o0_m_o1 =
+ // LDS allocation for A and B: be careful of alignment
+ static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
- static constexpr auto v_block_desc_o0_n_o1 =
+ static constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
- static constexpr index_t KPack = math::max(
- math::lcm(AK1, BK1), MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+ static constexpr auto b1_block_desc_bk0_n_bk1 =
+ GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
+ static constexpr auto p_block_desc_m0_n_m1 =
+ VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
+ static constexpr auto ygrad_block_desc_m0_o_m1 =
+ VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();
+ static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};
+ static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
+ a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
+ static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
+ b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
+ static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
+ b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
+ static constexpr auto p_block_space_size_aligned =
+ math::integer_least_multiple(p_block_desc_m0_n_m1.GetElementSpaceSize(), max_lds_align);
+ static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
+ ygrad_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);
+ static constexpr auto a_block_space_offset = 0;
+ static constexpr auto b_block_space_offset = a_block_space_size_aligned.value;
+ static constexpr auto b1_block_space_offset = 0;
+ static constexpr auto p_block_space_offset = 0;
+ static constexpr auto ygrad_block_space_offset = p_block_space_size_aligned.value;
+ // LDS allocation for reduction
+ static constexpr index_t reduction_space_size_aligned =
+ math::integer_least_multiple(BlockSize, max_lds_align);
+ static constexpr auto reduction_space_offset = 0;
+ // LDS allocation for C shuffle in LDS
+ static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
+ GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
+ static constexpr auto c_block_space_size =
+ c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
+ };
+ // P / dP Gemm (type 1 rcr)
+ struct Gemm0
+ {
+ private:
+ static constexpr auto a_block_desc = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
+ static constexpr auto b_block_desc = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
public:
- using BlockwiseGemm = BlockwiseGemmXdlops_v2<
- BlockSize,
+ template <typename GridDesc_K0_M_K1>
+ using ABlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
+ ThisThreadBlock,
+ AElementwiseOperation,
+ tensor_operation::element_wise::PassThrough,
+ InMemoryDataOperationEnum::Set,
+ Sequence<AK0, MPerBlock, AK1>,
+ ABlockTransferThreadClusterLengths_AK0_M_AK1,
+ ABlockTransferThreadClusterArrangeOrder,
DataType,
- FloatGemmAcc,
- decltype(ygrad_block_desc_o0_m_o1),
- decltype(v_block_desc_o0_n_o1),
- decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(ygrad_block_desc_o0_m_o1)),
- decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(v_block_desc_o0_n_o1)),
- MPerBlock,
- NPerBlock,
- KPerBlock,
- MPerXdl,
- NPerXdl,
- MXdlPerWave,
- NXdlPerWave,
- KPack,
- true>;
+ DataType,
+ GridDesc_K0_M_K1,
+ decltype(a_block_desc),
+ ABlockTransferSrcAccessOrder,
+ Sequence<1, 0, 2>,
+ ABlockTransferSrcVectorDim,
+ 2,
+ ABlockTransferSrcScalarPerVector,
+ ABlockTransferDstScalarPerVector_AK1,
+ 1,
+ 1,
+ true, // SrcResetCoord
+ true, // DstResetCoord
+ NumGemmKPrefetchStage>;
+ template <typename GridDesc_K0_N_K1>
+ using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
+ ThisThreadBlock,
+ BElementwiseOperation,
+ tensor_operation::element_wise::PassThrough,
+ InMemoryDataOperationEnum::Set,
+ Sequence<BK0, NPerBlock, BK1>,
+ BBlockTransferThreadClusterLengths_BK0_N_BK1,
+ BBlockTransferThreadClusterArrangeOrder,
+ DataType,
+ DataType,
+ GridDesc_K0_N_K1,
+ decltype(b_block_desc),
+ BBlockTransferSrcAccessOrder,
+ Sequence<1, 0, 2>,
+ BBlockTransferSrcVectorDim,
+ 2,
+ BBlockTransferSrcScalarPerVector,
+ BBlockTransferDstScalarPerVector_BK1,
+ 1,
+ 1,
+ true, // SrcResetCoord
+ true, // DstResetCoord
+ NumGemmKPrefetchStage>;
+ static constexpr index_t KPack = math::max(
+ math::lcm(AK1, BK1), MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+ // Blockwise gemm with transposed XDL output
+ using BlockwiseGemm =
+ BlockwiseGemmXdlops_v2<BlockSize,
+ DataType,
+ FloatGemmAcc,
+ decltype(a_block_desc),
+ decltype(b_block_desc),
+ decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(
+ a_block_desc)),
+ decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(
+ b_block_desc)),
+ MPerBlock,
+ NPerBlock,
+ KPerBlock,
+ MPerXdl,
+ NPerXdl,
+ MXdlPerWave,
+ NXdlPerWave,
+ KPack,
+ true>; // TransposeC
+ static constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
+ static constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
+ };
+ // PGrad Gemm has the same layout as P = Q * K^T Gemm (A row-major B col-major)
+ struct PGradGemmTile_M_N_O
+ {
// TODO ANT:
// Should have made all input tensors 2D and transform them into appropriate 3D form in
// kernel to make things more concise - if we can get the compiler to behave
template <typename YGradGridDesc_M0_O_M1_>
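
Two things happen in this hunk. SharedMemTrait moves ahead of the gemm tile structs and centralizes the LDS budget; note that `b1_block_space_offset` and `p_block_space_offset` are both 0, so those tiles appear to deliberately alias the region holding Gemm0's A/B tiles, which is presumably safe only because the gemms execute in distinct phases separated by block-wide LDS barriers. The new Gemm0 struct then bundles the A/B blockwise-copy aliases and the blockwise gemm, so the P = Q*K^T and dP = dY*V^T gemms, which share the same row-major/col-major ("type 1 rcr") layout, instantiate one set of types instead of two hand-copied parameter lists. A self-contained sketch of that policy-struct pattern, with illustrative stand-in names rather than CK's:

```cpp
#include <cstdio>

// Hypothetical mini version of the Gemm0 policy struct: block-scope
// configuration is fixed once, and a template alias is instantiated per
// call site with its own grid descriptor.
struct Gemm0Policy
{
    static constexpr int KPerBlock = 32;

    template <typename GridDesc>
    struct BlockwiseCopy // stand-in for ThreadGroupTensorSliceTransfer_v4r1
    {
        GridDesc desc;
        void Run() const { std::printf("advance %d along K per step\n", KPerBlock); }
    };
};

struct QGridDesc     {}; // descriptor for Q,  feeding P  = Q  * K^T
struct YGradGridDesc {}; // descriptor for dY, feeding dP = dY * V^T

int main()
{
    // the same machinery serves both gemms of the backward pass
    Gemm0Policy::BlockwiseCopy<QGridDesc>     p_copy{};
    Gemm0Policy::BlockwiseCopy<YGradGridDesc> dp_copy{};
    p_copy.Run();
    dp_copy.Run();
}
```
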
@@ -597,52 +703,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
};
- struct SharedMemTrait
- {
- // LDS allocation for A and B: be careful of alignment
- static constexpr auto a_block_desc_ak0_m_ak1 =
- GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
- static constexpr auto b_block_desc_bk0_n_bk1 =
- GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
- static constexpr auto b1_block_desc_bk0_n_bk1 =
- GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
- static constexpr auto p_block_desc_m0_n_m1 =
- VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
- static constexpr auto ygrad_block_desc_m0_o_m1 =
- VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();
- static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};
- static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
- a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
- static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
- b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
- static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
- b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
- static constexpr auto p_block_space_size_aligned =
- math::integer_least_multiple(p_block_desc_m0_n_m1.GetElementSpaceSize(), max_lds_align);
- static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
- ygrad_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);
- static constexpr auto a_block_space_offset = 0;
- static constexpr auto b_block_space_offset = a_block_space_size_aligned.value;
- static constexpr auto b1_block_space_offset = 0;
- static constexpr auto p_block_space_offset = 0;
- static constexpr auto ygrad_block_space_offset = p_block_space_size_aligned.value;
- // LDS allocation for reduction
- static constexpr index_t reduction_space_size_aligned =
- math::integer_least_multiple(BlockSize, max_lds_align);
- static constexpr auto reduction_space_offset = 0;
- // LDS allocation for C shuffle in LDS
- static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
- GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
- static constexpr auto c_block_space_size =
- c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
- };
template <bool HasMainKBlockLoop,
typename Block2CTileMap,
typename C0MatrixMask,
@@ -703,47 +763,26 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
return;
}
- // HACK: this forces m/gemm1_n_block_data_idx_on_grid into SGPR
+ // HACK: this forces m/o_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
- const index_t gemm1_n_block_data_idx_on_grid =
+ const index_t o_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
+ //
+ // set up P / dP Gemm (type 1 rcr)
+ //
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
- //
- // set up Gemm0
- //
// A matrix blockwise copy
auto a_blockwise_copy =
- ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
- AElementwiseOperation,
- tensor_operation::element_wise::PassThrough,
- InMemoryDataOperationEnum::Set,
- Sequence<AK0, MPerBlock, AK1>,
- ABlockTransferThreadClusterLengths_AK0_M_AK1,
- ABlockTransferThreadClusterArrangeOrder,
- DataType,
- DataType,
- decltype(q_grid_desc_k0_m_k1),
- decltype(a_block_desc_ak0_m_ak1),
- ABlockTransferSrcAccessOrder,
- Sequence<1, 0, 2>,
- ABlockTransferSrcVectorDim,
- 2,
- ABlockTransferSrcScalarPerVector,
- ABlockTransferDstScalarPerVector_AK1,
- 1,
- 1,
- true, // SrcResetCoord
- true, // DstResetCoord
- NumGemmKPrefetchStage>(
+ typename Gemm0::template ABlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
q_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
@@ -753,28 +792,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// B matrix blockwise copy
auto b_blockwise_copy =
- ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
- BElementwiseOperation,
- tensor_operation::element_wise::PassThrough,
- InMemoryDataOperationEnum::Set,
- Sequence<BK0, NPerBlock, BK1>,
- BBlockTransferThreadClusterLengths_BK0_N_BK1,
- BBlockTransferThreadClusterArrangeOrder,
- DataType,
- DataType,
- decltype(k_grid_desc_k0_n_k1),
- decltype(b_block_desc_bk0_n_bk1),
- BBlockTransferSrcAccessOrder,
- Sequence<1, 0, 2>,
- BBlockTransferSrcVectorDim,
- 2,
- BBlockTransferSrcScalarPerVector,
- BBlockTransferDstScalarPerVector_BK1,
- 1,
- 1,
- true, // SrcResetCoord
- true, // DstResetCoord
- NumGemmKPrefetchStage>(
+ typename Gemm0::template BBlockwiseCopy<decltype(k_grid_desc_k0_n_k1)>(
k_grid_desc_k0_n_k1,
make_multi_index(0, 0, 0), // will loop over GemmN dimension
b_element_op,
@@ -782,35 +800,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
- // Fused Gemm+Gemm pipeline
- // for n in N0:
- //   for k in K0:
- //     acc[m][n] += A[m][k] * B0[k][n]
- //   acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check
- constexpr index_t KPack = math::max(
- math::lcm(AK1, BK1), MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
- auto blockwise_gemm = BlockwiseGemmXdlops_v2<
- BlockSize,
- DataType,
- FloatGemmAcc,
- decltype(a_block_desc_ak0_m_ak1),
- decltype(b_block_desc_bk0_n_bk1),
- decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
- decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
- MPerBlock,
- NPerBlock,
- KPerBlock,
- MPerXdl,
- NPerXdl,
- MXdlPerWave,
- NXdlPerWave,
- KPack,
- true>{}; // TransposeC
+ auto s_blockwise_gemm = typename Gemm0::BlockwiseGemm{}; // TransposeC
- auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer();
+ auto s_slash_p_thread_buf = s_blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
@@ -821,8 +813,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
static_cast<DataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
- constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
- constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
+ constexpr auto a_block_slice_copy_step = Gemm0::a_block_slice_copy_step;
+ constexpr auto b_block_slice_copy_step = Gemm0::b_block_slice_copy_step;
const auto a_block_reset_copy_step =
make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
const auto b_block_reset_copy_step =
@@ -839,12 +832,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
KPerBlock);
//
- // set up O / dQ Gemm
+ // set up Y / dQ Gemm (type 2 rrr)
//
// Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
- blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+ s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
@@ -873,7 +866,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// A1 matrix in AccVGPR
// N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
constexpr auto AccN3 =
- blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6);
+ s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6);
constexpr auto A1ThreadSlice_K0_M_K1 =
make_tuple(Number<Gemm1KPerBlock / n4 / AccN3>{}, Number<m0 * m1 * m2>{}, Number<n4>{});
@@ -889,47 +882,47 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A1 matrix blockwise copy
- // auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
- //     FloatGemmAcc,
- //     DataType,
- //     decltype(acc_thread_desc_k0_m_k1),
- //     decltype(a1_thread_desc_k0_m_k1),
- //     tensor_operation::element_wise::PassThrough,
- //     Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
- //     Sequence<1, 0, 2>,
- //     2,
- //     n4>{tensor_operation::element_wise::PassThrough{}};
+ auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
+ FloatGemmAcc,
+ DataType,
+ decltype(acc_thread_desc_k0_m_k1),
+ decltype(a1_thread_desc_k0_m_k1),
+ tensor_operation::element_wise::PassThrough,
+ Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
+ Sequence<1, 0, 2>,
+ 2,
+ n4>{tensor_operation::element_wise::PassThrough{}};
// B1 matrix blockwise copy
- // auto b1_blockwise_copy =
- //     ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
- //         BElementwiseOperation,
- //         tensor_operation::element_wise::PassThrough,
- //         InMemoryDataOperationEnum::Set,
- //         Sequence<B1K0, Gemm1NPerBlock, B1K1>,
- //         B1BlockTransferThreadClusterLengths_BK0_N_BK1,
- //         B1BlockTransferThreadClusterArrangeOrder,
- //         DataType,
- //         DataType,
- //         decltype(v_grid_desc_n0_o_n1),
- //         decltype(b1_block_desc_bk0_n_bk1),
- //         B1BlockTransferSrcAccessOrder,
- //         Sequence<1, 0, 2>,
- //         B1BlockTransferSrcVectorDim,
- //         2,
- //         B1BlockTransferSrcScalarPerVector,
- //         B1BlockTransferDstScalarPerVector_BK1,
- //         1,
- //         1,
- //         B1ThreadTransferSrcResetCoordinateAfterRun,
- //         true, // DstResetCoord
- //         NumGemmKPrefetchStage>(
- //         v_grid_desc_n0_o_n1,
- //         make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
- //         b1_element_op,
- //         b1_block_desc_bk0_n_bk1,
- //         make_multi_index(0, 0, 0),
- //         tensor_operation::element_wise::PassThrough{});
+ auto b1_blockwise_copy =
+ ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+ BElementwiseOperation,
+ tensor_operation::element_wise::PassThrough,
+ InMemoryDataOperationEnum::Set,
+ Sequence<B1K0, Gemm1NPerBlock, B1K1>,
+ B1BlockTransferThreadClusterLengths_BK0_N_BK1,
+ B1BlockTransferThreadClusterArrangeOrder,
+ DataType,
+ DataType,
+ decltype(v_grid_desc_n0_o_n1),
+ decltype(b1_block_desc_bk0_n_bk1),
+ B1BlockTransferSrcAccessOrder,
+ Sequence<1, 0, 2>,
+ B1BlockTransferSrcVectorDim,
+ 2,
+ B1BlockTransferSrcScalarPerVector,
+ B1BlockTransferDstScalarPerVector_BK1,
+ 1,
+ 1,
+ B1ThreadTransferSrcResetCoordinateAfterRun,
+ true, // DstResetCoord
+ NumGemmKPrefetchStage>(
+ v_grid_desc_n0_o_n1,
+ make_multi_index(0, o_block_data_idx_on_grid, 0),
+ b1_element_op,
+ b1_block_desc_bk0_n_bk1,
+ make_multi_index(0, 0, 0),
+ tensor_operation::element_wise::PassThrough{});
auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
a1_thread_desc_k0_m_k1.GetElementSpaceSize());
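
The re-enabled `a1_blockwise_copy` feeds the second gemm straight from the first gemm's accumulator registers, downcasting FloatGemmAcc (fp32) to the XDL input type on the way, with no round trip through LDS. A plain-C++ stand-in for that per-element conversion (the real transfer is a compile-time-shaped register-to-register slice copy, and `Dst` stands in for CK's DataType, e.g. half_t):

```cpp
#include <array>
#include <cstddef>

// Illustrative stand-in for ThreadwiseTensorSliceTransfer_StaticToStatic:
// gemm0's fp32 accumulator fragment becomes gemm1's low-precision A fragment.
template <typename Dst, std::size_t N>
void downcast_acc_to_a1(const std::array<float, N>& acc_thread_buf,
                        std::array<Dst, N>& a1_thread_buf)
{
    for(std::size_t i = 0; i < N; ++i)
        a1_thread_buf[i] = static_cast<Dst>(acc_thread_buf[i]); // narrow per element
}
```
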
@@ -984,8 +977,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// get acc0 8D thread cluster
constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 =
- blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() /
- blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+ s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() /
+ s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I0);
constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I1);
constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I2);
@@ -1021,23 +1014,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
decltype(thread_cluster_desc_m_n),
decltype(thread_slice_desc_m_n)>{};
- const index_t num_gemm1_k_block_outer_loop =
- k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
- constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
- // Initialize C
- StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, acc1_thread_buf.Size(), true>
- c_thread_buf;
- c_thread_buf.Clear();
- // Initialize running sum and max of exponentiating row vectors
- using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType;
- SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new;
- running_sum = 0;
- running_sum_new = 0;
- running_max = NumericLimits<FloatGemmAcc>::Lowest();
- running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl =
MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(lse_grid_desc_m);
@@ -1047,7 +1023,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
auto lse_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatLSE>(
lse_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());
- auto acc0_thread_origin = blockwise_gemm.CalculateCThreadOriginDataIndex8D(
+ auto acc0_thread_origin = s_blockwise_gemm.CalculateCThreadOriginDataIndex8D(
Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});
auto lse_thread_copy_global_to_vgpr =
@@ -1068,19 +1044,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
acc0_thread_origin[I4])}; // mperxdl
//
- // dV
+ // set up dV / dK Gemm (type 3 crr)
//
// P vgpr to lds: writes vgprs of a subgroup to LDS and transforms them into m0_n_m1
// m0, n0 are m/n repeat per wave
// m1, n1 are number of waves
constexpr auto p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
- blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+ s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto p_block_desc_m0_n_m1 = VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
constexpr auto p_block_lengths =
- blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+ s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
constexpr auto P_M0 = p_block_lengths[I0]; // repeats
constexpr auto P_N0 = p_block_lengths[I1];
constexpr auto P_M1 = p_block_lengths[I2]; // waves
@@ -1113,7 +1089,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const auto p_thread_origin_nd_idx_on_block = [&]() {
const auto c_thread_mtx_on_block =
- blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
+ s_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
@@ -1184,17 +1160,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
p_thread_origin_nd_idx_on_block[I7]),
tensor_operation::element_wise::PassThrough{}};
- // Sequence<p_block_slice_lengths_m0_n0_m1_n1[I0],
- //          p_block_slice_lengths_m0_n0_m1_n1[I1],
- //          I1,
- //          I1,
- //          I1,
- //          P_N2,
- //          I1,
- //          P_N4>{}
- //     .foo();
- // 1, 4, 1, 1, 1, 4, 1, 4
constexpr auto sfc_p_m0_n0_m1_n1_m2_n2 =
SpaceFillingCurve<Sequence<P_M0, P_N0, P_M1, P_N1>,
Sequence<0, 1, 2, 3>,
@@ -1229,7 +1194,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
true,
1>(ygrad_grid_desc_m0_o_m1,
make_multi_index(m_block_data_idx_on_grid / VGradGemmTile_N_O_M::YGrad_M1,
- gemm1_n_block_data_idx_on_grid,
+ o_block_data_idx_on_grid,
0),
tensor_operation::element_wise::PassThrough{},
ygrad_block_desc_m0_o_m1,
@@ -1292,7 +1257,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const index_t n_thread_data_idx_on_grid = vgrad_thread_mtx_on_block_n_o[I0];
const index_t o_thread_data_idx_on_grid =
- vgrad_thread_mtx_on_block_n_o[I1] + gemm1_n_block_data_idx_on_grid;
+ vgrad_thread_mtx_on_block_n_o[I1] + o_block_data_idx_on_grid;
const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(VGrad_N0, VGrad_N1, VGrad_N2))),
@@ -1375,8 +1340,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
#endif
//
- // dP
+ // set up Y dot dY
//
constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(make_tuple(
I1, YDotYGrad_M_O::ThreadSliceLength_M, I1, YDotYGrad_M_O::ThreadSliceLength_O));
constexpr auto y_thread_cluster_desc =
@@ -1424,35 +1390,35 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// per-thread LSE data and y_dot_ygrad is
// tiled the same way
+ auto y_dot_ygrad_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
+ FloatGemmAcc,
+ FloatGemmAcc,
+ decltype(y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl),
+ decltype(y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl),
+ Sequence<1, m0, m1, m2>,
+ Sequence<0, 1, 2, 3>,
+ 3,
+ m2,
+ 1,
+ false>{y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
+ make_multi_index(I0, // mblock
+ acc0_thread_origin[I0], // mrepeat
+ acc0_thread_origin[I2], // mwave
+ acc0_thread_origin[I4])}; // mperxdl
+ //
+ // set up dP Gemm (type 1 rcr)
+ //
// transform input and output tensor descriptors
const auto ygrad_grid_desc_o0_m_o1 =
PGradGemmTile_M_N_O::MakeYGradGridDesc_O0_M_O1(ygrad_grid_desc_m0_o_m1);
const auto v_grid_desc_o0_n_o1 =
PGradGemmTile_M_N_O::MakeVGridDesc_O0_N_O1(v_grid_desc_n0_o_n1);
- // dP Gemm A matrix blockwise copy
+ // dP Gemm A position blockwise copy
auto pgrad_gemm_tile_ygrad_blockwise_copy =
- ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
- tensor_operation::element_wise::PassThrough,
- tensor_operation::element_wise::PassThrough,
- InMemoryDataOperationEnum::Set,
- Sequence<AK0, MPerBlock, AK1>,
- ABlockTransferThreadClusterLengths_AK0_M_AK1,
- ABlockTransferThreadClusterArrangeOrder,
- DataType,
- DataType,
- decltype(ygrad_grid_desc_o0_m_o1),
- decltype(a_block_desc_ak0_m_ak1), // reuse block buf
- ABlockTransferSrcAccessOrder,
- Sequence<1, 0, 2>,
- ABlockTransferSrcVectorDim,
- 2,
- ABlockTransferSrcScalarPerVector,
- ABlockTransferDstScalarPerVector_AK1,
- 1,
- 1,
- true, // SrcResetCoord
- true, // DstResetCoord
- NumGemmKPrefetchStage>(
+ typename Gemm0::template ABlockwiseCopy<decltype(ygrad_grid_desc_o0_m_o1)>(
ygrad_grid_desc_o0_m_o1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
tensor_operation::element_wise::PassThrough{},
@@ -1460,30 +1426,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
- // dP Gemm B matrix blockwise copy
+ // dP Gemm B position blockwise copy
auto pgrad_gemm_tile_v_blockwise_copy =
- ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
- tensor_operation::element_wise::PassThrough,
- tensor_operation::element_wise::PassThrough,
- InMemoryDataOperationEnum::Set,
- Sequence<BK0, NPerBlock, BK1>,
- BBlockTransferThreadClusterLengths_BK0_N_BK1,
- BBlockTransferThreadClusterArrangeOrder,
- DataType,
- DataType,
- decltype(v_grid_desc_o0_n_o1),
- decltype(b_block_desc_bk0_n_bk1), // reuse block buf
- BBlockTransferSrcAccessOrder,
- Sequence<1, 0, 2>,
- BBlockTransferSrcVectorDim,
- 2,
- BBlockTransferSrcScalarPerVector,
- BBlockTransferDstScalarPerVector_BK1,
- 1,
- 1,
- true, // SrcResetCoord
- true, // DstResetCoord
- NumGemmKPrefetchStage>(
+ typename Gemm0::template BBlockwiseCopy<decltype(v_grid_desc_o0_n_o1)>(
v_grid_desc_o0_n_o1,
make_multi_index(0, 0, 0), // will loop over GemmN dimension
tensor_operation::element_wise::PassThrough{},
@@ -1491,7 +1436,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
- auto pgrad_blockwise_gemm = typename PGradGemmTile_M_N_O::BlockwiseGemm{};
+ auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
auto pgrad_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer();
const auto pgrad_gemm_tile_ygrad_block_reset_copy_step =
make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), 0, 0);
@@ -1502,25 +1447,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
(ygrad_grid_desc_o0_m_o1.GetLength(I0) * ygrad_grid_desc_o0_m_o1.GetLength(I2)) /
KPerBlock);
- auto y_dot_ygrad_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
- FloatGemmAcc,
- FloatGemmAcc,
- decltype(y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl),
- decltype(y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl),
- Sequence<1, m0, m1, m2>,
- Sequence<0, 1, 2, 3>,
- 3,
- m2,
- 1,
- false>{y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
- make_multi_index(I0, // mblock
- acc0_thread_origin[I0], // mrepeat
- acc0_thread_origin[I2], // mwave
- acc0_thread_origin[I4])}; // mperxdl
- // clear accum buffers
- y_dot_ygrad_thread_accum_buf.Clear();
- y_dot_ygrad_block_accum_buf.Clear();
#if 0
if(hipThreadIdx_x == 0 && hipBlockIdx_x == 0) printf("lds before accum\n");
if(hipBlockIdx_x == 0)
@@ -1533,7 +1459,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());
//
- // dQ
+ // set up dQ Gemm (type 2 rrr)
//
const auto k_grid_desc_n0_k_n1 =
QGradGemmTile_M_K_N::MakeKGridDesc_N0_K_N1(k_grid_desc_k0_n_k1);
@@ -1578,7 +1504,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
true, // DstResetCoord
NumGemmKPrefetchStage>(
k_grid_desc_n0_k_n1,
- make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
+ make_multi_index(0, o_block_data_idx_on_grid, 0),
b1_element_op,
b1_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
@@ -1607,9 +1533,15 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_tuple(0, 0, 0, 0)}; // A_origin
auto qgrad_thread_buf = qgrad_blockwise_gemm.GetCThreadBuffer();
+ //
+ // calculate y dot ygrad
+ //
+ // clear accum buffers
+ y_dot_ygrad_thread_accum_buf.Clear();
+ y_dot_ygrad_block_accum_buf.Clear();
index_t oblock_idx = 0;
do
{
@@ -1673,7 +1605,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
#endif
- // distribute y_dot_ygrad to threads; LDS accum buffer can be safely accessed after barrier
+ // distribute y_dot_ygrad to threads; LDS accum buffer can be safely reused after barrier
y_dot_ygrad_thread_copy_lds_to_vgpr.Run(
y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
y_dot_ygrad_block_accum_buf,
@@ -1681,7 +1613,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_tuple(I0, I0, I0, I0),
y_dot_ygrad_thread_buf);
- #if 1
+ #if 0
if(hipBlockIdx_x < 4 && hipThreadIdx_x % 32 < 4)
{
printf("bid %zd tid %zd, y_m0_m1_o0_o1 = %d, %d, %d, %d\n",
@@ -1700,6 +1632,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_tuple(I0, I0, I0, I0),
lse_thread_buf);
+ const index_t num_gemm1_k_block_outer_loop =
+ k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
+ constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
// Initialize dQ
qgrad_thread_buf.Clear();
@@ -1715,7 +1650,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
continue;
}
- // gemm0
+ // P = Q * K^T
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(q_grid_desc_k0_m_k1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
@@ -1728,8 +1663,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
k_grid_buf,
b_block_buf,
b_block_slice_copy_step,
- blockwise_gemm,
- acc_thread_buf,
+ s_blockwise_gemm,
+ s_slash_p_thread_buf,
num_k_block_main_loop);
// do MNK padding or upper triangular masking
@@ -1737,11 +1672,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
- blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+ s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
// 8d block_desc in block scope
constexpr auto c_block_lengths =
- blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+ s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
@@ -1760,7 +1695,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
- auto acc0_thread_origin = blockwise_gemm.CalculateCThreadOriginDataIndex8D(
+ auto acc0_thread_origin = s_blockwise_gemm.CalculateCThreadOriginDataIndex8D(
Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
@@ -1779,11 +1714,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
auto n_global = n_local + n_block_data_idx_on_grid;
if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
{
- acc_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
+ s_slash_p_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
}
else
{
- acc_element_op(acc_thread_buf(i), acc_thread_buf[i]);
+ acc_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
}
});
}
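
The masking branch leans on the softmax that follows: writing negative infinity into a masked score makes its probability vanish exactly, since

```math
P_{ij} = e^{\,S_{ij} - \mathrm{LSE}_i} \quad\Longrightarrow\quad S_{ij} = -\infty \;\Rightarrow\; P_{ij} = 0,
```

so masked positions contribute nothing to the row reductions either.
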
@@ -1800,32 +1735,32 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
printf("tid %zd, S[0:3] = %f, %f, %f, %f\n",
hipThreadIdx_x,
- acc_thread_buf[I0],
- acc_thread_buf[I1],
- acc_thread_buf[I2],
- acc_thread_buf[I3]);
+ s_slash_p_thread_buf[I0],
+ s_slash_p_thread_buf[I1],
+ s_slash_p_thread_buf[I2],
+ s_slash_p_thread_buf[I3]);
}
#endif
// P_i: = softmax(S_i:)
- blockwise_softmax.RunWithPreCalcStats(acc_thread_buf, lse_thread_buf);
+ blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
#if 0
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("tid %zd, P[0:3] = %f, %f, %f, %f\n",
hipThreadIdx_x,
- acc_thread_buf[I0],
- acc_thread_buf[I1],
- acc_thread_buf[I2],
- acc_thread_buf[I3]);
+ s_slash_p_thread_buf[I0],
+ s_slash_p_thread_buf[I1],
+ s_slash_p_thread_buf[I2],
+ s_slash_p_thread_buf[I3]);
}
#endif
block_sync_lds(); // wait for gemm1 LDS read
- SubThreadBlock<BlockSize> p_thread_copy_subgroup(blockwise_gemm.GetWaveIdx()[I0],
- blockwise_gemm.GetWaveIdx()[I1]);
+ SubThreadBlock<BlockSize> p_thread_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
+ s_blockwise_gemm.GetWaveIdx()[I1]);
constexpr index_t num_vgrad_gemm_loop = MPerBlock / VGradGemmTile_N_O_M::Sum_M;
static_assert(sfc_p_m0_n0_m1_n1_m2_n2.GetNumOfAccess() == num_vgrad_gemm_loop, "");
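
The `RunWithPreCalcStats` call above reconstructs P from the raw scores without redoing the row reductions: the forward pass stored the per-row log-sum-exp, so the backward pass can apply the standard recompute-from-stats identity used by flash-attention-style kernels:

```math
P_{ij} = \exp\!\left(S_{ij} - \mathrm{LSE}_i\right),
\qquad
\mathrm{LSE}_i = \log \sum_{j} e^{S_{ij}}.
```
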
@@ -1864,7 +1799,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
p_thread_copy_vgpr_to_lds.Run(
p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(p_nd_idx[I0], p_nd_idx[I1], I0, I0, I0, I0, I0, I0),
- acc_thread_buf,
+ s_slash_p_thread_buf,
p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
p_block_buf);
}
@@ -1936,7 +1871,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
pgrad_blockwise_gemm,
pgrad_thread_buf,
num_o_block_main_loop);
- #if 1
+ #if 0
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("outer j loop idx %d, tid %zd, dP[0:3] = %f, %f, %f, %f\n",
@@ -1963,10 +1898,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I1];
// dS and P have the same thread buf layout
sgrad_thread_buf(i) =
- acc_thread_buf[i] * (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
+ s_slash_p_thread_buf[i] * (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
});
- #if 1
+ #if 0
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("outer j loop idx %d, tid %zd, dS[0:3] = %f, %f, %f, %f\n",
@@ -2016,7 +1951,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
- #if 1
+ #if 0
if(hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("inner j loop idx %d, tid %zd, dS downcast[0:3] = %f, %f, %f, %f\n",
@@ -2079,7 +2014,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// TODO ANT:
// shuffle dQ and write
- #if 1
+ #if 0
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("tid %zd, dQ[0:3] = %f, %f, %f, %f\n",
......