Commit 804e6803 authored by Harisankar Sadasivan

files modified for 1s cold and warm runs

parent 87efbb63
...
@@ -20,6 +20,33 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
#if CK_TIME_KERNEL
    if(stream_config.time_kernel_)
    {
+        if(ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942")
+        {
+            hipEvent_t start, stop;
+
+            hip_check_error(hipEventCreate(&start));
+            hip_check_error(hipEventCreate(&stop));
+
+            hip_check_error(hipDeviceSynchronize());
+            hip_check_error(hipEventRecord(start, stream_config.stream_id_));
+
+            for(int i = 0; i < stream_config.nrepeat_; ++i)
+            {
+                kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+                hip_check_error(hipGetLastError());
+            }
+
+            hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
+            hip_check_error(hipEventSynchronize(stop));
+
+            float total_time = 0;
+            hip_check_error(hipEventElapsedTime(&total_time, start, stop));
+            total_time /= 10;
+
+            // we need a longer runtime to ramp up the clock on MI300s
+            stream_config.cold_niters_ = (1000.0 / total_time);
+            stream_config.nrepeat_     = stream_config.cold_niters_;
+        }
+
#if DEBUG_LOG
        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
               __func__,
...
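
The new gfx94x branch times a short burst of back-to-back launches with HIP events, then rescales `cold_niters_` and `nrepeat_` so that the timed run lasts roughly one second (`1000.0 / total_time`, with `total_time` in milliseconds), which gives the MI300 clocks time to ramp up. Below is a minimal, standalone sketch of just that calibration arithmetic; `calibrate_iters` and its parameters are illustrative names, not part of composable_kernel.

```cpp
#include <algorithm>
#include <cstdio>

// Given the total time (ms) measured for `timed_iters` back-to-back launches,
// return how many iterations keep the GPU busy for roughly `target_ms`
// milliseconds (the role played by cold_niters_ / nrepeat_ on gfx940/941/942).
int calibrate_iters(float total_ms, int timed_iters, float target_ms = 1000.0f)
{
    const float avg_ms = total_ms / timed_iters; // average time per launch
    return std::max(1, static_cast<int>(target_ms / avg_ms));
}

int main()
{
    // Example: 10 launches took 2.5 ms in total -> 0.25 ms per launch,
    // so ~4000 iterations are needed for a ~1 s timed run.
    std::printf("%d\n", calibrate_iters(2.5f, 10));
    return 0;
}
```

With numbers like these, a kernel averaging 0.25 ms per launch would be timed over about 4000 warm iterations, preceded by the same number of cold iterations.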
...
@@ -11,6 +11,6 @@ struct StreamConfig
    hipStream_t stream_id_ = nullptr;
    bool time_kernel_      = false;
    int log_level_         = 0;
-    int cold_niters_       = 5;
-    int nrepeat_           = 50;
+    mutable int cold_niters_ = 5;
+    mutable int nrepeat_     = 50;
};
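
Because `launch_and_time_kernel` receives the `StreamConfig` by `const` reference, the calibration code above can only rewrite `cold_niters_` and `nrepeat_` once those members are declared `mutable`. A minimal sketch of the idiom, using a hypothetical `Config` struct rather than the real `StreamConfig`:

```cpp
#include <cstdio>

struct Config
{
    bool time_kernel = true;
    // mutable lets timing code adjust these through a const Config&.
    mutable int cold_niters = 5;
    mutable int nrepeat     = 50;
};

void retune(const Config& cfg, float avg_ms)
{
    // Allowed despite the const reference because the members are mutable.
    cfg.nrepeat     = static_cast<int>(1000.0f / avg_ms);
    cfg.cold_niters = cfg.nrepeat;
}

int main()
{
    const Config cfg{};
    retune(cfg, 0.25f);
    std::printf("nrepeat = %d\n", cfg.nrepeat); // prints: nrepeat = 4000
    return 0;
}
```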
...
@@ -669,450 +669,434 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
        const BElementwiseOperation b_element_op = BElementwiseOperation{},
        const CElementwiseOperation c_element_op = CElementwiseOperation{})
    {
+        for(auto i = 0; i < 1500; i++)
+        {
            const FloatA* p_a_grid = karg.p_a_grid;
            const FloatB* p_b_grid = karg.p_b_grid;
            FloatC* p_c_grid       = karg.p_c_grid;

            const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
                karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0Padded, karg.KPadded);
            const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
                karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0Padded, karg.KPadded);
            const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);

            const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
                MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);

            const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
            const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
            auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            // divide block work by [KBatch, M, N]
            const auto block_work_idx =
                block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

            if(!block_2_ctile_map.ValidCTileIndex(
                   block_work_idx,
                   make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                              c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
            {
                return;
            }

            const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
            const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]);
            const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);

            // HACK: this force m/n_block_data_idx_on_grid into SGPR
            const index_t m_block_data_idx_on_grid =
                __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
            const index_t n_block_data_idx_on_grid =
                __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);

            // lds max alignment
            constexpr auto max_lds_align = K1;

            // A matrix in LDS memory, dst of blockwise copy
            constexpr auto a_k0_m_k1_block_desc = [&]() {
                if constexpr(ABlockLdsExtraM)
                {
                    return make_naive_tensor_descriptor(
                        make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                        make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
                }
                else
                {
                    return make_naive_tensor_descriptor_aligned(
                        make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
                }
            }();

            constexpr auto a_b_k0_m_k1_block_desc = [&]() {
                if constexpr(ABlockLdsExtraM)
                {
                    return make_naive_tensor_descriptor(
                        make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                        make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
                                   Number<MPerBlock + 1>{} * K1,
                                   K1,
                                   I1));
                }
                else
                {
                    return make_naive_tensor_descriptor_aligned(
                        make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                        max_lds_align);
                }
            }();

            // B matrix in LDS memory, dst of blockwise copy
            constexpr auto b_k0_n_k1_block_desc = [&]() {
                if constexpr(BBlockLdsExtraN)
                {
                    return make_naive_tensor_descriptor(
                        make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                        make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
                }
                else
                {
                    return make_naive_tensor_descriptor_aligned(
                        make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
                }
            }();

            constexpr auto b_b_k0_n_k1_block_desc = [&]() {
                if constexpr(BBlockLdsExtraN)
                {
                    return make_naive_tensor_descriptor(
                        make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                        make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
                                   Number<NPerBlock + 1>{} * K1,
                                   K1,
                                   I1));
                }
                else
                {
                    return make_naive_tensor_descriptor_aligned(
                        make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                        max_lds_align);
                }
            }();

            // A matrix blockwise copy
            auto a_blockwise_copy =
                ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                    AElementwiseOperation,
                                                    ck::tensor_operation::element_wise::PassThrough,
                                                    InMemoryDataOperationEnum::Set,
                                                    Sequence<1, K0PerBlock, MPerBlock, K1>,
                                                    ABlockTransferThreadClusterLengths_K0_M_K1,
                                                    ABlockTransferThreadClusterArrangeOrder,
                                                    FloatA,
                                                    LDSTypeA,
                                                    decltype(a_b_k0_m_k1_grid_desc),
                                                    decltype(a_b_k0_m_k1_block_desc),
                                                    ABlockTransferSrcAccessOrder,
                                                    Sequence<0, 2, 1, 3>,
                                                    ABlockTransferSrcVectorDim,
                                                    3,
                                                    ABlockTransferSrcScalarPerVector,
                                                    ABlockTransferDstScalarPerVector_K1,
                                                    1,
                                                    1,
                                                    AThreadTransferSrcResetCoordinateAfterRun,
                                                    true>(
                    a_b_k0_m_k1_grid_desc,
                    make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
                    a_element_op,
                    a_b_k0_m_k1_block_desc,
                    make_multi_index(0, 0, 0, 0),
                    ck::tensor_operation::element_wise::PassThrough{});

            // B matrix blockwise copy
            auto b_blockwise_copy =
                ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                    BElementwiseOperation,
                                                    ck::tensor_operation::element_wise::PassThrough,
                                                    InMemoryDataOperationEnum::Set,
                                                    Sequence<1, K0PerBlock, NPerBlock, K1>,
                                                    BBlockTransferThreadClusterLengths_K0_N_K1,
                                                    BBlockTransferThreadClusterArrangeOrder,
                                                    FloatB,
                                                    LDSTypeB,
                                                    decltype(b_b_k0_n_k1_grid_desc),
                                                    decltype(b_b_k0_n_k1_block_desc),
                                                    BBlockTransferSrcAccessOrder,
                                                    Sequence<0, 2, 1, 3>,
                                                    BBlockTransferSrcVectorDim,
                                                    3,
                                                    BBlockTransferSrcScalarPerVector,
                                                    BBlockTransferDstScalarPerVector_K1,
                                                    1,
                                                    1,
                                                    BThreadTransferSrcResetCoordinateAfterRun,
                                                    true>(
                    b_b_k0_n_k1_grid_desc,
                    make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
                    b_element_op,
                    b_b_k0_n_k1_block_desc,
                    make_multi_index(0, 0, 0, 0),
                    ck::tensor_operation::element_wise::PassThrough{});

            // GEMM definition
            // c_mtx += transpose(a_mtx) * b_mtx
            // a_mtx[K0PerBlock, MPerBlock] is in LDS
            // b_mtx[K0PerBlock, NPerBlock] is in LDS
            // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
            // register
            // sanity check
            auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
                BlockSize,
                LDSTypeA,
                LDSTypeB,
                FloatAcc,
                decltype(a_k0_m_k1_block_desc),
                decltype(b_k0_n_k1_block_desc),
                MPerXDL,
                NPerXDL,
                MRepeat,
                NRepeat,
                K1,
                LoopSched,
                ComputeTypeA,
                ComputeTypeB>();

            auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

            // LDS allocation for A and B: be careful of alignment
            constexpr auto a_block_space_size = math::integer_least_multiple(
                a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

            auto p_a_block = reinterpret_cast<LDSTypeA*>(p_shared_block);
            auto p_b_block = reinterpret_cast<LDSTypeB*>(p_a_block + a_block_space_size);

            constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
            constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);

            auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
            auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());

            // gridwise GEMM pipeline
            const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
                (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) /
                (K0PerBlock * K1));

            const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};

            gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
                                                                   a_b_k0_m_k1_block_desc,
                                                                   a_blockwise_copy,
                                                                   a_grid_buf,
                                                                   a_block_buf,
                                                                   a_block_slice_copy_step,
                                                                   b_b_k0_n_k1_grid_desc,
                                                                   b_b_k0_n_k1_block_desc,
                                                                   b_blockwise_copy,
                                                                   b_grid_buf,
                                                                   b_block_buf,
                                                                   b_block_slice_copy_step,
                                                                   blockwise_gemm,
                                                                   c_thread_buf,
                                                                   num_k_block_main_loop);

            // output: register to global memory
            {
                constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
                constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);

                constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
                    blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

                constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
                    blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

                constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
                constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
                constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
                constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
                constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
                constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
                constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
                constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);

                constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
                    GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

                auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                    static_cast<FloatC*>(p_shared_block),
                    c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

                constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                    c_block_desc_mblock_mperblock_nblock_nperblock,
                    make_tuple(
                        make_freeze_transform(I0), // freeze mblock
                        make_unmerge_transform(
                            make_tuple(CShuffleMRepeatPerShuffle,
                                       M1,
                                       M2,
                                       M3,
                                       M4)),       // M1 = MWave, M2 * M3 * M4 = MPerXDL
                        make_freeze_transform(I0), // freeze nblock
                        make_unmerge_transform(
                            make_tuple(CShuffleNRepeatPerShuffle,
                                       N1,
                                       N2))),      // M1 = MWave, M2 * M3 * M4 = MPerXDL
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                    make_tuple(
                        Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

                // calculate origin of thread output tensor on global memory
                // blockwise GEMM c matrix starting index
                const auto c_thread_mtx_on_block =
                    blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

                const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
                const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

                const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                    make_single_stage_tensor_adaptor(
                        make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                        make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                        make_tuple(Sequence<0>{}));

                const auto m_thread_data_on_block_idx =
                    m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                        make_multi_index(m_thread_data_on_block));

                const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                    make_single_stage_tensor_adaptor(
                        make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                        make_tuple(Sequence<0, 1, 2>{}),
                        make_tuple(Sequence<0>{}));

                const auto n_thread_data_on_block_idx =
                    n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                        make_multi_index(n_thread_data_on_block));

                // VGPR to LDS
                auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
                    FloatAcc,
                    FloatC,
                    decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
                    decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                    ck::tensor_operation::element_wise::PassThrough,
                    Sequence<CShuffleMRepeatPerShuffle,
                             CShuffleNRepeatPerShuffle,
                             I1,
                             I1,
                             M2,
                             I1,
                             M4,
                             I1>,
                    Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                    7,
                    1,
                    InMemoryDataOperationEnum::Set,
                    1,
                    true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                          make_multi_index(0,
                                           0,
                                           m_thread_data_on_block_idx[I1],
                                           n_thread_data_on_block_idx[I1],
                                           m_thread_data_on_block_idx[I2],
                                           m_thread_data_on_block_idx[I3],
                                           m_thread_data_on_block_idx[I4],
                                           n_thread_data_on_block_idx[I2]),
                          ck::tensor_operation::element_wise::PassThrough{}};

                // LDS to global
                auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                    ThisThreadBlock,            // index_t BlockSize,
                    CElementwiseOperation,      // ElementwiseOperation,
                    CGlobalMemoryDataOperation, // DstInMemOp,
                    Sequence<1,
                             CShuffleMRepeatPerShuffle * MWave * MPerXDL,
                             1,
                             CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
                    CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                    Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                    FloatC,               // typename SrcData,
                    FloatC,               // typename DstData,
                    decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
                    decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                    Sequence<0, 1, 2, 3>,                       // typename DimAccessOrder,
                    3,                                          // index_t VectorDim,
                    CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
                    true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
                    false> // bool ThreadTransferDstResetCoordinateAfterRun
                    {c_block_desc_mblock_mperblock_nblock_nperblock,
                     make_multi_index(0, 0, 0, 0),
                     c_grid_desc_mblock_mperblock_nblock_nperblock,
                     make_multi_index(block_m_id, 0, block_n_id, 0),
                     c_element_op};

                constexpr auto mxdlperwave_forward_step =
                    make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0);
                constexpr auto nxdlperwave_forward_step =
                    make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL);
                constexpr auto nxdlperwave_backward_step =
                    make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL);

                static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
                    constexpr auto mxdlperwave = mxdlperwave_iter;

                    static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
                        constexpr bool nxdlperwave_forward_sweep =
                            (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);

                        constexpr index_t nxdlperwave_value =
                            nxdlperwave_forward_sweep
                                ? nxdlperwave_iter
                                : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);

                        constexpr auto nxdlperwave = Number<nxdlperwave_value>{};

                        // make sure it's safe to do ds_write
                        block_sync_lds();

                        // VGPR to LDS
                        c_thread_copy_vgpr_to_lds.Run(
                            c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                            make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
                            c_thread_buf,
                            c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                            c_block_buf);

                        // make sure it's safe to do ds_read
                        block_sync_lds();

                        // LDS to global
                        c_block_copy_lds_to_global.Run(
                            c_block_desc_mblock_mperblock_nblock_nperblock,
                            c_block_buf,
                            c_grid_desc_mblock_mperblock_nblock_nperblock,
                            c_grid_buf);

                        // move on nxdlperwave dimension
                        if constexpr(nxdlperwave_forward_sweep &&
                                     (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
                        {
                            c_block_copy_lds_to_global.MoveDstSliceWindow(
                                c_grid_desc_mblock_mperblock_nblock_nperblock,
                                nxdlperwave_forward_step);
                        }
                        else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
                        {
                            c_block_copy_lds_to_global.MoveDstSliceWindow(
                                c_grid_desc_mblock_mperblock_nblock_nperblock,
                                nxdlperwave_backward_step);
                        }
                    });

                    // move on mxdlperwave dimension
                    if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
                    {
                        c_block_copy_lds_to_global.MoveDstSliceWindow(
                            c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step);
                    }
                });
            }
+        }
    }
...
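
The only functional change in this hunk is the added `for(auto i = 0; i < 1500; i++)` loop: the entire per-workgroup GEMM body is replayed 1500 times inside one launch, so a single kernel invocation runs long enough for the clocks to settle and for the cold and warm measurements to converge. A stripped-down HIP sketch of the same pattern is below; the SAXPY kernel, sizes, and repeat count are illustrative stand-ins for the CK gridwise GEMM, not its actual implementation.

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>
#include <vector>

// Toy kernel: the useful work (a SAXPY) is wrapped in a repeat loop so a
// single launch stays resident on the GPU long enough for clocks to ramp up,
// mirroring the 1500-iteration loop added around the GEMM body.
__global__ void saxpy_repeat(float a, const float* x, float* y, int n, int repeat)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i >= n)
        return;
    for(int r = 0; r < repeat; ++r) // e.g. repeat = 1500, as in the commit
        y[i] = a * x[i] + y[i];
}

int main()
{
    const int n = 1 << 20;
    std::vector<float> hx(n, 1.0f), hy(n, 2.0f);

    float *dx = nullptr, *dy = nullptr;
    hipMalloc(reinterpret_cast<void**>(&dx), n * sizeof(float));
    hipMalloc(reinterpret_cast<void**>(&dy), n * sizeof(float));
    hipMemcpy(dx, hx.data(), n * sizeof(float), hipMemcpyHostToDevice);
    hipMemcpy(dy, hy.data(), n * sizeof(float), hipMemcpyHostToDevice);

    saxpy_repeat<<<(n + 255) / 256, 256>>>(2.0f, dx, dy, n, 1500);
    hipDeviceSynchronize();

    hipMemcpy(hy.data(), dy, n * sizeof(float), hipMemcpyDeviceToHost);
    std::printf("y[0] after 1500 repeats: %f\n", hy[0]);

    hipFree(dx);
    hipFree(dy);
    return 0;
}
```

Repeating the body on the device keeps the measured launch count small while still producing roughly one second of sustained work per timed launch, at the cost of the extra register/scalar overhead of the outer loop.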