"...resnet50_tensorflow.git" did not exist on "2f43cff2b72e9a5bee26d31e4e8af3087a5618e1"
Commit 26d5174e authored by aska-0096

Update instances and LDS layout strategy

parent ea90b01f
@@ -615,11 +615,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
    __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A matrix in LDS memory, dst of blockwise copy
-       if constexpr(ABlockLdsExtraM)
+       if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
        {
+           // Writing this layout into LDS causes bank conflicts, but in v4 the whole main loop
+           // is available to hide them, and it may save some VALU in address computation.
            return make_naive_tensor_descriptor(
                make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
-               make_tuple(AK1Number, Number<KPerBlock + ABlockLdsExtraM>{}, I1));
+               make_tuple(Number<MPerBlock>{} * AK1Number, AK1Number, I1));
        }
        // xor tensor transformation requires more vgpr usage and can cause register spill
        // in some cases.
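Note on the change above (the same applies to the B descriptor in the next hunk): when the v4 pipeline is selected, the block descriptor drops the extra-M/extra-N padding and stores the tile contiguously in LDS. The standalone sketch below (assumed tile sizes and padding, not the CK descriptor code) contrasts the two addressing schemes, to show why the unpadded layout needs less address arithmetic even though its LDS writes incur bank conflicts.

```cpp
// Standalone illustration with assumed sizes; not the CK descriptor code.
#include <cstdio>

constexpr int K0 = 8, M = 128, K1 = 8; // assumed (K0, MPerBlock, K1) block tile
constexpr int PadM = 4;                // assumed extra-M padding in elements

// Padded layout: each K0 slab reserves (M + PadM) rows so accesses spread across
// LDS banks (fewer conflicts), at the price of a non-power-of-two slab stride in
// every address computation.
int offset_padded(int k0, int m, int k1) { return (k0 * (M + PadM) + m) * K1 + k1; }

// Unpadded v4-style layout: plain row-major, a single multiply-add with
// power-of-two strides; the LDS-write bank conflicts are hidden behind the main loop.
int offset_plain(int k0, int m, int k1) { return (k0 * M + m) * K1 + k1; }

int main()
{
    std::printf("padded: %d  plain: %d\n", offset_padded(1, 2, 3), offset_plain(1, 2, 3));
    return 0;
}
```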
@@ -752,11 +754,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
    __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B matrix in LDS memory, dst of blockwise copy
-       if constexpr(BBlockLdsExtraN)
+       if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
        {
+           // Writing this layout into LDS causes bank conflicts, but in v4 the whole main loop
+           // is available to hide them, and it may save some VALU in address computation.
            return make_naive_tensor_descriptor(
                make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
-               make_tuple(BK1Number, Number<KPerBlock + BBlockLdsExtraN>{}, I1));
+               make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
        }
        else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
        {
@@ -676,11 +676,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
    __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A matrix in LDS memory, dst of blockwise copy
-       if constexpr(ABlockLdsExtraM)
+       if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
        {
+           // Writing this layout into LDS causes bank conflicts, but in v4 the whole main loop
+           // is available to hide them, and it may save some VALU in address computation.
            return make_naive_tensor_descriptor(
                make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
-               make_tuple(AK1Number, Number<KPerBlock + ABlockLdsExtraM>{}, I1));
+               make_tuple(Number<MPerBlock>{} * AK1Number, AK1Number, I1));
        }
        // xor tensor transformation requires more vgpr usage and can cause register spill
        // in some cases.
@@ -813,11 +815,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
    __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B matrix in LDS memory, dst of blockwise copy
-       if constexpr(BBlockLdsExtraN)
+       if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
        {
+           // Writing this layout into LDS causes bank conflicts, but in v4 the whole main loop
+           // is available to hide them, and it may save some VALU in address computation.
            return make_naive_tensor_descriptor(
                make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
-               make_tuple(BK1Number, Number<KPerBlock + BBlockLdsExtraN>{}, I1));
+               make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
        }
        else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
        {
@@ -1216,6 +1220,38 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
__device__ static constexpr auto EpilogueScheduler()
{
constexpr auto epilogue_tile = MPerBlock * NPerBlock * CShuffleMXdlPerWavePerShuffle *
CShuffleNXdlPerWavePerShuffle / (MXdlPerWave * NXdlPerWave);
constexpr auto num_mfma_inst = BlockwiseGemmPipe::HotLoopInstList::C_MFMA_Inst_Num *
CShuffleMXdlPerWavePerShuffle *
CShuffleNXdlPerWavePerShuffle / (MXdlPerWave * NXdlPerWave);
constexpr auto num_ds_write_inst =
epilogue_tile / BlockSize; // DefaultMFMA, per-element write
constexpr auto num_ds_read_inst =
epilogue_tile / BlockSize / CShuffleBlockTransferScalarPerVector_NPerBlock;
constexpr auto num_buffer_store_inst = num_ds_read_inst;
// MFMA:ds_write=1:2
constexpr auto num_ds_write_issue = num_ds_write_inst / 2;
constexpr auto num_mfma_block_sync = (num_mfma_inst - num_ds_write_issue) / 2;
constexpr auto mfma_ds_write_rate = MXdlPerWave == 16 ? 2 : 4;
// Hide ds_write issue latency
static_for<0, num_ds_write_issue, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, mfma_ds_write_rate, 0); // DS write
});
// Hide block_sync + ds_read latency
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_block_sync, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, num_ds_read_inst, 0); // DS read
// Hide block_sync latency
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_block_sync, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x040, num_buffer_store_inst, 0); // VMEM write
}
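__builtin_amdgcn_sched_group_barrier(mask, size, sync_id) is a compiler scheduling hint: it asks the backend scheduler to form a group of `size` instructions matching `mask` (0x008 = MFMA, 0x100 = DS read, 0x200 = DS write, 0x040 = VMEM write) at that point. The host-side sketch below (instruction counts are assumed for readability, not derived from an actual tile configuration) just prints the issue pattern that EpilogueScheduler() requests, to make its three phases easier to follow.

```cpp
// Illustration only: prints the issue pattern requested by the scheduling hints above.
// All counts are assumed for readability, not taken from a real kernel configuration.
#include <cstdio>

int main()
{
    const int num_mfma_inst      = 32; // assumed MFMA count left for the last K block
    const int num_ds_write_inst  = 32; // assumed LDS writes for the C shuffle
    const int num_ds_read_inst   = 4;  // assumed LDS reads per access
    const int num_buffer_store   = 4;  // assumed global stores per access
    const int mfma_ds_write_rate = 4;  // DS writes grouped with each MFMA

    const int num_ds_write_issue  = num_ds_write_inst / 2;
    const int num_mfma_block_sync = (num_mfma_inst - num_ds_write_issue) / 2;

    // Phase 1: pair MFMA with DS writes to hide the LDS-write issue latency.
    for(int i = 0; i < num_ds_write_issue; ++i)
        std::printf("1 MFMA | %d DS write\n", mfma_ds_write_rate);
    // Phase 2: a block of MFMA covers the barrier plus the LDS reads.
    std::printf("%d MFMA | %d DS read\n", num_mfma_block_sync, num_ds_read_inst);
    // Phase 3: the remaining MFMA covers the barrier while global stores drain.
    std::printf("%d MFMA | %d buffer store\n", num_mfma_block_sync, num_buffer_store);
    return 0;
}
```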
// return block_id to C matrix tile idx (m0, n0) mapping
// if arch = gfx942
using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
@@ -1393,6 +1429,15 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
constexpr auto a_thread_desc = blockwise_gemm_pipeline.a_thread_desc_;
constexpr auto b_thread_desc = blockwise_gemm_pipeline.b_thread_desc_;
constexpr auto c_thread_desc = blockwise_gemm_pipeline.c_thread_desc_;
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
b_thread_desc.GetElementSpaceSize());
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
@@ -1410,10 +1455,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
a_thread_buf,
b_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
{
// Last block MFMA
auto xdlops_gemm = blockwise_gemm_pipeline.xdlops_gemm;
constexpr auto KRepeat = blockwise_gemm_pipeline.KRepeat;
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
@@ -1573,6 +1624,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
// C: LDS -> VGPR
// D: Global -> VGPR
// E: =Epilogue(C, D), VGPR -> Global
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
@@ -1631,10 +1685,77 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeA>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
__builtin_amdgcn_sched_barrier(0);
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr(access_id < num_access - 1)
{
constexpr auto shuffle_m0 =
sfc_c_vgpr.GetIndexTupleOfNumber(access_id + Number<1>{})[Number<0>{}];
constexpr auto shuffle_n0 =
sfc_c_vgpr.GetIndexTupleOfNumber(access_id + Number<1>{})[Number<1>{}];
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeA>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
// each thread writes its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
@@ -1668,6 +1789,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
I0,
cde_lds_and_global_step);
EpilogueScheduler();
}
});
}
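The restructured epilogue above is a small software pipeline: the MFMA work of the last K block is no longer finished before the C shuffle. Instead, the MFMA for shuffle tile 0 is issued up front, and inside the store loop the MFMA for tile access_id + 1 is interleaved with the LDS shuffle, global stores, and EpilogueScheduler() hints for the current tile. A minimal host-side schematic of that lookahead structure (hypothetical helper names, not the CK code):

```cpp
// Schematic only: shows the lookahead ordering, with hypothetical helpers standing in
// for the MFMA block and the C-shuffle/store code in the kernel above.
#include <cstdio>

constexpr int kNumAccess = 4; // assumed number of C-shuffle accesses per block tile

void run_last_kblock_mfma(int tile) { std::printf("MFMA for shuffle tile %d\n", tile); }
void shuffle_and_store(int tile)    { std::printf("LDS shuffle + store of tile %d\n", tile); }

int main()
{
    run_last_kblock_mfma(0); // tile 0 is computed before the store loop starts
    for(int access_id = 0; access_id < kNumAccess; ++access_id)
    {
        if(access_id < kNumAccess - 1)
            run_last_kblock_mfma(access_id + 1); // overlap next tile's math ...
        shuffle_and_store(access_id);            // ... with this tile's data movement
    }
    return 0;
}
```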
@@ -1860,6 +1983,15 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
constexpr auto a_thread_desc = blockwise_gemm_pipeline.a_thread_desc_;
constexpr auto b_thread_desc = blockwise_gemm_pipeline.b_thread_desc_;
constexpr auto c_thread_desc = blockwise_gemm_pipeline.c_thread_desc_;
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
b_thread_desc.GetElementSpaceSize());
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
@@ -1877,10 +2009,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
b_block_bufs,
b_block_slice_copy_step,
c_thread_buf,
a_thread_buf,
b_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
{
// Last block MFMA
auto xdlops_gemm = blockwise_gemm_pipeline.xdlops_gemm;
constexpr auto KRepeat = blockwise_gemm_pipeline.KRepeat;
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
@@ -2098,10 +2236,77 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeA>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
__builtin_amdgcn_sched_barrier(0);
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr(access_id < num_access - 1)
{
constexpr auto shuffle_m0 =
sfc_c_vgpr.GetIndexTupleOfNumber(access_id + Number<1>{})[Number<0>{}];
constexpr auto shuffle_n0 =
sfc_c_vgpr.GetIndexTupleOfNumber(access_id + Number<1>{})[Number<1>{}];
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeA>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
// each thread writes its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
@@ -2135,6 +2340,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
I0,
cde_lds_and_global_step);
EpilogueScheduler();
}
});
}
@@ -17,7 +17,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
-void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances(
+void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part1(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
@@ -30,7 +30,33 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_inst
PassThrough,
MultiplyMultiply>>>& instances);
-void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
+void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances_part1(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances_part2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
@@ -221,9 +247,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
-add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances(
+add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part1(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances_part1(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part2(
op_ptrs);
-add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
+add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances_part2(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part1(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part1<GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part2<GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances_part1(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part1<GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances_part2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part2<GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck