Commit 988478d4 authored by chenjun

edit fp8 ab scale for Scale_Block_M=1

parent f728087c
@@ -26,7 +26,6 @@ using S = ck::Sequence<Is...>;
 using BF16 = ck::bhalf_t;
 using FP8 = ck::f8_t;
-using F16 = ck::half_t;
 using F32 = float;

 using Row = ck::tensor_layout::gemm::RowMajor;
@@ -68,11 +67,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_
     256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
     128, 128,
     128, 16, 16,
-    32, 32,
-    2, 2,
+    16, 16,
+    4, 4,
     S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
     S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
-    1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
+    1, 2, S<1, 32, 1, 8>, S<8>,
     ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
 // clang-format on
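The tile shape moves from 32×32 XDL tiles with 2×2 per-wave repeats to 16×16 tiles with 4×4 repeats; both cover the same 128×128 block tile with a 256-thread workgroup. Below is a standalone check of that accounting — reading the changed arguments as MPerXdl/NPerXdl and MXdlPerWave/NXdlPerWave is my assumption about the template's argument order, not something the diff states:

```cpp
// Hedged check: both tunings cover the same 128x128 block tile with a
// 256-thread (4-wave) workgroup. Assumes the changed template arguments
// are MPerXdl/NPerXdl (32,32 -> 16,16) and MXdlPerWave/NXdlPerWave
// (2,2 -> 4,4).
#include <cassert>

constexpr int BlockSize = 256, WaveSize = 64;
constexpr int MPerBlock = 128, NPerBlock = 128;

constexpr bool covers(int m_per_xdl, int n_per_xdl, int m_repeat, int n_repeat)
{
    // waves are laid out MWaves x NWaves; each wave owns an
    // (m_repeat * m_per_xdl) x (n_repeat * n_per_xdl) piece of the tile
    const int m_waves = MPerBlock / (m_repeat * m_per_xdl);
    const int n_waves = NPerBlock / (n_repeat * n_per_xdl);
    return m_waves * n_waves == BlockSize / WaveSize;
}

static_assert(covers(32, 32, 2, 2), "old tuning: 2x2 waves, 64x64 each");
static_assert(covers(16, 16, 4, 4), "new tuning: 2x2 waves, 64x64 each");

int main() {}
```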
@@ -83,9 +82,9 @@ int main(int argc, char* argv[])
     bool time_kernel = false;

     // GEMM shape
-    ck::index_t M = 3840;
-    ck::index_t N = 4096;
-    ck::index_t K = 4096;
+    ck::index_t M = 128;
+    ck::index_t N = 1024;
+    ck::index_t K = 1024;

     ck::index_t StrideA = K;
     ck::index_t StrideB = K;
@@ -101,7 +100,7 @@ int main(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
-    else if(argc == 10)
+    else if(argc == 7)
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
@@ -111,9 +110,9 @@ int main(int argc, char* argv[])
         N = std::stoi(argv[5]);
         K = std::stoi(argv[6]);

-        StrideA = std::stoi(argv[7]);
-        StrideB = std::stoi(argv[8]);
-        StrideE = std::stoi(argv[9]);
+        StrideA = K;
+        StrideB = K;
+        StrideE = N;
     }
     else
     {
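With `argc == 7` the example takes only the three flags plus M, N, K and derives the strides from the fixed layouts (`StrideA = StrideB = K`, `StrideE = N`); the former explicit stride arguments `argv[7]`..`argv[9]` are gone. An invocation would look like `./example 1 2 1 128 1024 1024` (binary name hypothetical).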
@@ -185,20 +184,10 @@ int main(int argc, char* argv[])
     case 4:
         a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
         b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
+        // a1_m_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
         a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
-        b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
-        break;
-    case 5:
-        a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
-        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
-        a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
-        b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
-        break;
-    case 6:
-        a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
-        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
-        a1_m_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
-        b1_k_n.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
+        // b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
+        b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
         break;
     default:
         a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
...
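With Scale_Block_M = 1 the A scale degenerates to one factor per output row per K scale block, which is what the commit title refers to. Below is a host-side sketch of the dequantization math such an example verifies — plain float arrays and my own names, assuming the usual CK AB-scale semantics E[m][n] = Σ_k a_scale[m/SBM][k/SBK] · b_scale[n/SBN][k/SBK] · A[m][k] · B[k][n]:

```cpp
// Hedged reference: blockwise AB-scaled GEMM on plain float arrays.
// a_scale has one entry per (Scale_Block_M x Scale_Block_K) tile of A,
// b_scale one per (Scale_Block_N x Scale_Block_K) tile of B; with
// Scale_Block_M == 1 the A scale is per-row along M.
#include <vector>

void reference_ab_scale_gemm(int M, int N, int K,
                             int SBM, int SBN, int SBK,
                             const std::vector<float>& a,       // M x K, row-major
                             const std::vector<float>& b,       // N x K rows (K x N column-major)
                             const std::vector<float>& a_scale, // (M/SBM) x (K/SBK)
                             const std::vector<float>& b_scale, // (N/SBN) x (K/SBK)
                             std::vector<float>& e)             // M x N, row-major
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
            {
                const float sa = a_scale[(m / SBM) * (K / SBK) + k / SBK];
                const float sb = b_scale[(n / SBN) * (K / SBK) + k / SBK];
                acc += sa * sb * a[m * K + k] * b[n * K + k];
            }
            e[m * N + n] = acc;
        }
}

int main()
{
    const int M = 2, N = 2, K = 4, SBM = 1, SBN = 2, SBK = 2;
    std::vector<float> a(M * K, 1.f), b(N * K, 1.f);
    std::vector<float> sa((M / SBM) * (K / SBK), 0.5f);
    std::vector<float> sb((N / SBN) * (K / SBK), 2.f);
    std::vector<float> e(M * N);
    reference_ab_scale_gemm(M, N, K, SBM, SBN, SBK, a, b, sa, sb, e);
    // each e entry: 4 * (0.5 * 2 * 1 * 1) = 4
}
```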
@@ -96,7 +96,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intrawave,
                 NPerXDL,
                 MRepeat,
                 NRepeat,
-                KPack>
+                KPack,
+                true>
 {
     using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
@@ -117,7 +118,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intrawave,
                 NPerXDL,
                 MRepeat,
                 NRepeat,
-                KPack>;
+                KPack,
+                true>;

     using Base::I0;
     using Base::KRepeat;
     using Base::xdlops_gemm;
@@ -338,18 +340,32 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intrawave,
         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

+        // a_scale_thread_copy.Run(a_scale_grid_desc,
+        //                         a_scale_grid_buf,
+        //                         a_scale_thread_desc,
+        //                         make_tuple(I0, I0),
+        //                         a_scale_thread_buf);
-        static_for<0, MRepeat, 1>{}([&](auto m0){
+        static_for<0, MRepeat, 1>{}([&](auto m0) {
             a_scale_thread_copy.Run(a_scale_grid_desc,
                                     a_scale_grid_buf,
                                     a_scale_thread_desc,
                                     make_tuple(m0, I0),
                                     a_scale_thread_buf);

             a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                                    a_scale_thread_copy_step.At(Number<0>{}));
         });

-        a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
-                                               a_scale_thread_copy_step.At(Number<1>{}));
+        if(num_loop_per_scale == 1)
+        {
+            a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
+                                                   a_scale_thread_copy_step.At(Number<2>{}));
+        }
+        else
+        {
+            a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
+                                                   a_scale_thread_copy_step.At(Number<1>{}));
+        }

         b_scale_thread_copy.Run(b_scale_grid_desc,
                                 b_scale_grid_buf,
@@ -357,6 +373,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intrawave,
                                 make_tuple(I0, I0),
                                 b_scale_thread_buf);

+        // a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
         b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);

         // Local prefill 1
         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
@@ -468,18 +485,32 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intrawave,
                                     b_thread_buf);
                 });
             });

+            // a_scale_thread_copy.Run(a_scale_grid_desc,
+            //                         a_scale_grid_buf,
+            //                         a_scale_thread_desc,
+            //                         make_tuple(I0, I0),
+            //                         a_scale_thread_buf);
-            static_for<0,MRepeat,1>{}([&](auto m0){
+            static_for<0, MRepeat, 1>{}([&](auto m0) {
                 a_scale_thread_copy.Run(a_scale_grid_desc,
                                         a_scale_grid_buf,
                                         a_scale_thread_desc,
                                         make_tuple(m0, I0),
                                         a_scale_thread_buf);

                 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                                        a_scale_thread_copy_step.At(Number<0>{}));
             });

-            a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
-                                                   a_scale_thread_copy_step.At(Number<1>{}));
+            if(num_loop_per_scale == 1)
+            {
+                a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
+                                                       a_scale_thread_copy_step.At(Number<2>{}));
+            }
+            else
+            {
+                a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
+                                                       a_scale_thread_copy_step.At(Number<1>{}));
+            }

             b_scale_thread_copy.Run(b_scale_grid_desc,
                                     b_scale_grid_buf,
@@ -487,6 +518,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intrawave,
                                     make_tuple(I0, I0),
                                     b_scale_thread_buf);

+            // a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
             b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);

             HotLoopScheduler();
             __builtin_amdgcn_sched_barrier(0);
...
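Both copies of the loop above follow the same pattern: MRepeat forward steps of MWaves · MPerXdl scale rows, then a rewind, with the K-scale coordinate advancing only when one K tile consumes a whole scale block (`num_loop_per_scale == 1`). Here is a standalone trace of that walk using the step values from the gridwise hunk below; the extents are the new instance's, and how the else branch eventually advances K is outside this sketch:

```cpp
// Hedged sketch: the A-scale window walk for a few hot-loop iterations.
// Coordinates are (row, k_scale) into the A-scale grid; step values mirror
// a_scale_thread_copy_step in the gridwise kernel below.
#include <cstdio>

int main()
{
    const int MRepeat = 4, MWaves = 2, MPerXdl = 16;
    const int MPerBlock = MRepeat * MWaves * MPerXdl; // 128
    int row = 0, k_scale = 0;

    for(int iter = 0; iter < 3; ++iter) // three K tiles of the hot loop
    {
        const bool num_loop_per_scale_is_1 = true; // ScaleBlockK == KPerBlock

        for(int m0 = 0; m0 < MRepeat; ++m0)
        {
            std::printf("read a_scale at (%d, %d)\n", row, k_scale);
            row += MWaves * MPerXdl;            // step At(Number<0>{})
        }
        if(num_loop_per_scale_is_1)
        {
            row -= MPerBlock; k_scale += 1;     // At(Number<2>{}): rewind M, next K scale
        }
        else
        {
            row -= MPerBlock;                   // At(Number<1>{}): rewind M, hold K scale
        }
    }
}
```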
@@ -1363,16 +1363,16 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
         constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
             make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));

-        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
-        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
-        auto a_thread_offset =
-            get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl;
-        // auto a_thread_offset = get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 128) * MPerXdl;

         constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
-            make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
+            make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));

+        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
+        // auto a_thread_offset =
+        //     get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) % MWaves * MPerXdl;
+        auto a_thread_offset = get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 128) * MPerXdl;

         auto a_scale_thread_copy =
             ThreadwiseTensorSliceTransfer_v2<AScaleType,
                                              AScaleType,
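The newly active `a_thread_offset` line assigns each lane a scale row as `tid % MPerXdl + (tid / 128) * MPerXdl`. Tracing it for a 256-thread block with MPerXdl = 16 (the new instance's values) shows the 32 rows of MWaves · MPerXdl being covered:

```cpp
// Hedged sketch: per-thread A-scale row offset under the new mapping.
// Assumes a 256-thread block (wave size 64) and MPerXdl = 16.
#include <cstdio>
#include <initializer_list>

int main()
{
    const int MPerXdl = 16;
    for(int tid : {0, 15, 64, 127, 128, 255})
    {
        // lanes 0..15 of each wave pick rows 0..15; threads 128..255
        // (the second half of the block) shift by one MPerXdl tile
        const int row = tid % MPerXdl + (tid / 128) * MPerXdl;
        std::printf("tid %3d -> a_scale row %d\n", tid, row);
    }
}
```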
@@ -1384,7 +1384,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
                                              1,
                                              1,
                                              false>(
-                a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0));
+                a_scale_grid_desc_am_ak,
+                make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0));

         auto b_scale_thread_copy =
             ThreadwiseTensorSliceTransfer_v2<BScaleType,
@@ -1399,8 +1400,11 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
                                              false>(
                 b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));

+        // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
         constexpr auto a_scale_thread_slice_copy_step =
-            make_tuple(make_multi_index(MWaves * MPerXdl, 0), make_multi_index(-MPerBlock, 1));
+            make_tuple(make_multi_index(MWaves * MPerXdl, 0),
+                       make_multi_index(-MPerBlock, 0),
+                       make_multi_index(-MPerBlock, 1));
         constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, 1);

         const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
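The step tuple grows from two entries to three: index 0 walks forward one wave-row of scales, index 1 rewinds M while holding the K-scale coordinate, and index 2 rewinds M and advances K. A compile-time sketch of the bookkeeping, with MWaves = 2, MPerXdl = 16, MRepeat = 4 as my assumed values for the new instance:

```cpp
// Hedged sketch: the three A-scale window steps as plain (m, k) deltas.
// MRepeat forward steps are exactly undone by either rewind entry, so each
// hot-loop iteration starts the walk at the same row.
#include <utility>

constexpr int MWaves = 2, MPerXdl = 16, MRepeat = 4, MPerBlock = 128;

constexpr std::pair<int, int> step0{MWaves * MPerXdl, 0}; // At(Number<0>{}): next wave-row
constexpr std::pair<int, int> step1{-MPerBlock, 0};       // At(Number<1>{}): rewind, hold K scale
constexpr std::pair<int, int> step2{-MPerBlock, 1};       // At(Number<2>{}): rewind, next K scale

static_assert(MRepeat * step0.first + step1.first == 0, "rewind cancels the walk");
static_assert(MRepeat * step0.first + step2.first == 0, "rewind cancels the walk");

int main() {}
```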
@@ -1443,24 +1447,28 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
         constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
         constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

-        // TODO: hacky, fix it!
-        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
-            blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
-
-        // TODO: hacky, fix it!
-        // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
-        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
-            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
-
-        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
-        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
-        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
-        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
-        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
-        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
-        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
-        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
+        // transposed XDL
+        // // TODO: hacky, fix it!
+        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
+            blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+
+        // // TODO: hacky, fix it!
+        // only used to get lengths
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
+            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+
+        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
+        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
+        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
+        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
+        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
+        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
+        constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
+        constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);

         constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
             GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
@@ -1469,24 +1477,24 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
                 static_cast<CShuffleDataType*>(p_shared),
                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

-        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
             make_tuple(
                 make_freeze_transform(I0),
                 make_unmerge_transform(make_tuple(
                     Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
                     M1,                                      // M1 = MWave
-                    M2,                                      // M2 * M3 * M4 = MPerXdl
-                    M3,
-                    M4)),
+                    M2)),                                    // M2 = MPerXdl
                 make_freeze_transform(I0),
                 make_unmerge_transform(make_tuple(
                     Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
                     N1,                                      // N1 = NWave
-                    N2))),                                   // N2 = NPerXdl
+                    N2,                                      // N2 * N3 * N4 = NPerXdl
+                    N3,
+                    N4))),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(
-                Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
+                Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{}));

         // calculate origin of thread output tensor on global memory
         // blockwise GEMM c matrix starting index
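The reworked unmerge sends the M factors to upper ids (0, 2, 4) and the N factors to (1, 3, 5, 6, 7), producing the interleaved (m0, n0, m1, n1, m2, n2, n3, n4) order the VGPR-to-LDS copy expects. A standalone round-trip check with assumed extents (M1 = N1 = 2 waves, M2 = 16, N2·N3·N4 = 1·4·4):

```cpp
// Hedged sketch: flat (mperblock, nperblock) LDS coordinates split into the
// interleaved 8-dimensional order used by the copy. Extents are assumed,
// not read from CK.
#include <cassert>

int main()
{
    const int M1 = 2, M2 = 16;               // M0 = 1 per shuffle
    const int N1 = 2, N2 = 1, N3 = 4, N4 = 4; // N0 = 1 per shuffle

    const int m = 23, n = 29; // a flat coordinate inside the 32x32 shuffle tile

    // unmerge m -> (m0, m1, m2): upper ids 0, 2, 4 in the transform
    const int m2 = m % M2, m1 = (m / M2) % M1, m0 = m / (M2 * M1);
    // unmerge n -> (n0, n1, n2, n3, n4): upper ids 1, 3, 5, 6, 7
    const int n4 = n % N4, n3 = (n / N4) % N3, n2 = (n / (N4 * N3)) % N2,
              n1 = (n / (N4 * N3 * N2)) % N1, n0 = n / (N4 * N3 * N2 * N1);

    // round trip: merging back recovers the flat coordinates
    assert(((m0 * M1 + m1) * M2 + m2) == m);
    assert(((((n0 * N1 + n1) * N2 + n2) * N3 + n3) * N4 + n4) == n);
}
```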
@@ -1496,57 +1504,57 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
         const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
         const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

-        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
+        const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
             make_single_stage_tensor_adaptor(
-                make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
-                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
+                make_tuple(make_merge_transform(make_tuple(M0, M1, M2))),
+                make_tuple(Sequence<0, 1, 2>{}),
                 make_tuple(Sequence<0>{}));

         const auto m_thread_data_on_block_idx =
-            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
+            m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
                 make_multi_index(m_thread_data_on_block));

-        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
+        const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
             make_single_stage_tensor_adaptor(
-                make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
-                make_tuple(Sequence<0, 1, 2>{}),
+                make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
+                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                 make_tuple(Sequence<0>{}));

         const auto n_thread_data_on_block_idx =
-            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
+            n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
                 make_multi_index(n_thread_data_on_block));

         // shuffle: threadwise copy C from VGPR to LDS
         auto c_thread_copy_vgpr_to_lds =
             ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                CShuffleDataType,
-                                               decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
-                                               decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
-                                               ck::tensor_operation::element_wise::PassThrough,
+                                               decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
+                                               decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
+                                               tensor_operation::element_wise::PassThrough,
                                                Sequence<CShuffleMXdlPerWavePerShuffle,
                                                         CShuffleNXdlPerWavePerShuffle,
                                                         I1,
                                                         I1,
-                                                        M2,
                                                         I1,
-                                                        M4,
-                                                        I1>,
+                                                        N2,
+                                                        I1,
+                                                        N4>,
                                                Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                7,
                                                1,
                                                InMemoryDataOperationEnum::Set,
                                                1,
                                                true>{
-                c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                 make_multi_index(0,
                                  0,
                                  m_thread_data_on_block_idx[I1],
                                  n_thread_data_on_block_idx[I1],
                                  m_thread_data_on_block_idx[I2],
-                                 m_thread_data_on_block_idx[I3],
-                                 m_thread_data_on_block_idx[I4],
-                                 n_thread_data_on_block_idx[I2]),
-                ck::tensor_operation::element_wise::PassThrough{}};
+                                 n_thread_data_on_block_idx[I2],
+                                 n_thread_data_on_block_idx[I3],
+                                 n_thread_data_on_block_idx[I4]),
+                tensor_operation::element_wise::PassThrough{}};

         using EDataType = CDataType;
@@ -1628,18 +1636,17 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
                 make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
                 c_element_op};

-        // space filling curve for threadwise C in VGPR
         constexpr auto sfc_c_vgpr =
-            SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
+            SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, 1, N2, 1, N4>,
                               Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                               Sequence<CShuffleMXdlPerWavePerShuffle,
                                        CShuffleNXdlPerWavePerShuffle,
                                        1,
                                        1,
-                                       M2,
                                        1,
-                                       M4,
-                                       1>>{};
+                                       N2,
+                                       1,
+                                       N4>>{};

         constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
@@ -1659,10 +1666,10 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
             block_sync_lds();

             // each thread write its data from VGPR to LDS
-            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                           sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                           c_thread_buf,
-                                          c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                          c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                           c_shuffle_block_buf);

             // make sure it's safe to read from LDS
...
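Net effect of the renames in this file: the fine-grained per-lane axes of the C fragment move from the M side (M2·M3·M4 = MPerXdl) to the N side (N2·N3·N4 = NPerXdl), consistent with a transposed XDL output layout, while slice sizes stay the same. A small sketch of the unchanged access count, taking the instance's `1, 2` arguments as CShuffleM/NXdlPerWavePerShuffle (my reading) and MXdlPerWave = NXdlPerWave = 4:

```cpp
// Hedged sketch: the space-filling-curve access count does not change with
// the transposed layout; only axis names moved. Values assumed as stated
// in the lead-in, not read from CK.
#include <cassert>

int main()
{
    const int MXdlPerWave = 4, NXdlPerWave = 4;
    const int CShuffleM = 1, CShuffleN = 2; // Xdl tiles shuffled per access

    // each access shuffles a CShuffleM x CShuffleN group of Xdl tiles;
    // the per-lane axes (M2/M4 before, N2/N4 after) are copied whole
    const int num_access = (MXdlPerWave / CShuffleM) * (NXdlPerWave / CShuffleN);
    assert(num_access == 8);
}
```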