Commit cee23c47 authored by aska-0096's avatar aska-0096
Browse files

tempsave

parent 487a05d6
...@@ -2,6 +2,6 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_mult ...@@ -2,6 +2,6 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_mult
# target_compile_options(example_gemm_multiply_multiply_xdl_fp8 PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker) # target_compile_options(example_gemm_multiply_multiply_xdl_fp8 PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp)
# target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE -save-temps=$PWD -Wno-gnu-line-marker) target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE -save-temps=$PWD -Wno-gnu-line-marker)
add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp)
...@@ -149,14 +149,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu ...@@ -149,14 +149,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
32, 128, 256, 32, 512, 128,
16, 16, 16, 16,
32, 32, 32, 32,
1, 1, 1, 4,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>; ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v2, FP8>;
// kernel 2: 128->32x128x128 // kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
...@@ -180,6 +180,9 @@ int main(int argc, char* argv[]) ...@@ -180,6 +180,9 @@ int main(int argc, char* argv[])
ck::index_t KBatch = 1; ck::index_t KBatch = 1;
ck::index_t Warmup = 50;
ck::index_t Repeat = 50;
if(argc == 1) if(argc == 1)
{ {
// use default case // use default case
...@@ -207,6 +210,26 @@ int main(int argc, char* argv[]) ...@@ -207,6 +210,26 @@ int main(int argc, char* argv[])
KBatch = std::stoi(argv[11]); KBatch = std::stoi(argv[11]);
} }
else if(argc == 14)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
KBatch = std::stoi(argv[11]);
Warmup = std::stoi(argv[12]);
Repeat = std::stoi(argv[13]);
}
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
...@@ -214,6 +237,7 @@ int main(int argc, char* argv[]) ...@@ -214,6 +237,7 @@ int main(int argc, char* argv[])
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf( printf(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n"); "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n");
printf("arg10 to 11: Warmup, Repeat\n");
exit(0); exit(0);
} }
...@@ -321,7 +345,14 @@ int main(int argc, char* argv[]) ...@@ -321,7 +345,14 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 50, true, 50}); size_t total_size =
(M * K * sizeof(A0DataType) + N * K * sizeof(B0DataType) + M * sizeof(D0DataType) +
N * sizeof(D1DataType) + M * N * sizeof(EDataType));
int rotate_buf_num =
ck::math::min(size_t(Repeat), ck::math::integer_divide_ceil(512 * 1024 * 1024, total_size));
float ave_time = invoker.Run(
argument, StreamConfig{nullptr, time_kernel, 0, Warmup, Repeat, true, rotate_buf_num});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
......
...@@ -249,7 +249,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I ...@@ -249,7 +249,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
// Global prefetch A1 B1 // Global prefetch A1 B1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.Run(b_grid_desc, b_blockwise_copy.Run(b_grid_desc,
b_grid_buf, b_grid_buf,
b_block_desc_n0_n1_k0_k1, b_block_desc_n0_n1_k0_k1,
...@@ -258,12 +258,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I ...@@ -258,12 +258,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
__builtin_amdgcn_sched_barrier(0);
// // Local prefill A1 // // Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
// // Global prefetch A2 // // Global prefetch A2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
// Local prefetch A1 // Local prefetch A1
...@@ -296,13 +297,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I ...@@ -296,13 +297,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
b_block_desc_n0_n1_k0_k1, b_block_desc_n0_n1_k0_k1,
b_block_origin_idx, b_block_origin_idx,
b_thread_bufs(local_read_buf)); b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds(); block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
namespace ck { namespace ck {
// Compute optimized pipeline // Compute optimized pipeline
// GlobalPrefetchStages: 2 // GlobalPrefetchStages: 3
// LocalPreFillStages: 2 // LocalPreFillStages: 2
// LocalPreFetchStages: 2 // LocalPreFetchStages: 2
// LocalSharedMemoryBuffer: 2 // LocalSharedMemoryBuffer: 2
...@@ -142,9 +142,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I ...@@ -142,9 +142,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
using Base::AMmaKStride; using Base::AMmaKStride;
using Base::BMmaKStride; using Base::BMmaKStride;
static constexpr index_t PrefetchStages = 2; static constexpr index_t PrefetchStages = 3;
static constexpr index_t PrefillStages = 1; static constexpr index_t PrefillStages = 2;
static constexpr index_t GlobalBufferNum = 1; static constexpr index_t GlobalBufferNum = 2;
template <typename TileDesc_M0_M1_M2_K> template <typename TileDesc_M0_M1_M2_K>
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
...@@ -183,80 +183,24 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I ...@@ -183,80 +183,24 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
__device__ static constexpr auto HotLoopScheduler() __device__ static constexpr auto HotLoopScheduler()
{ {
#if 0 // constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr auto num_ds_read_inst_a =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
? HotLoopInstList::A_LDS_Read_Inst_Num
: HotLoopInstList::A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; // B global + A local
static_for<0, num_buffer_load_inst_b / 2, 1>{}([&](auto i) {
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto num_dsread_a_mfma =
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
// stage 1
constexpr auto num_mfma_stage1 = num_mfma_inst - num_dsread_a_mfma;
constexpr auto num_mfma_per_issue =
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
ignore = i; ignore = i;
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
});
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_issue, 0); // MFMA
});
// stage 2
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
ds_read_a_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(0x100,
num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
ds_read_a_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read B
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read A
}); });
#endif
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
// B global static_for<0, num_buffer_load_inst_b / 2, 1>{}([&](auto i) {
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i; ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read B
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read A
}); });
// A global // A global
...@@ -269,11 +213,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I ...@@ -269,11 +213,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
}); });
// A local // A local
static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) { // static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
ignore = i; // ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA // __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read // __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
}); // });
} }
template <bool HasMainLoop, template <bool HasMainLoop,
...@@ -311,11 +255,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I ...@@ -311,11 +255,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>( auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize()); b_thread_desc_.GetElementSpaceSize());
StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs; StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
// Global prefetch A1, B1 // Global prefetch A1, B1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.Run(b_grid_desc, b_blockwise_copy.Run(b_grid_desc,
b_grid_buf, b_grid_buf,
b_block_desc_n0_n1_k0_k1, b_block_desc_n0_n1_k0_k1,
...@@ -325,11 +270,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I ...@@ -325,11 +270,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// // Local prefill A1 // Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0), I0);
// // Global prefetch A2 // Global prefetch A2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
// Local prefetch A1 // Local prefetch A1
...@@ -341,10 +286,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I ...@@ -341,10 +286,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
a_block_buf.At(I0), a_block_buf.At(I0),
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0), make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf); a_thread_bufs(I0));
}); });
}); });
// Local prefill A2
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1), I1);
// // Global prefetch A3
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
// Initialize C // Initialize C
c_thread_buf.Clear(); c_thread_buf.Clear();
...@@ -357,17 +309,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I ...@@ -357,17 +309,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
do do
{ {
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
block_sync_lds();
b_blockwise_copy.Run(b_grid_desc, b_blockwise_copy.Run(b_grid_desc,
b_grid_buf, b_grid_buf,
b_block_desc_n0_n1_k0_k1, b_block_desc_n0_n1_k0_k1,
b_block_origin_idx, b_block_origin_idx,
b_thread_bufs(local_read_buf)); b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); static_for<0, MRepeat, 1>{}([&](auto m0) {
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_bufs(local_read_buf));
});
});
a_blockwise_copy.RunWrite(
a_block_desc, a_block_buf.At(mfma_reg_buf), mfma_reg_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
...@@ -378,8 +343,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I ...@@ -378,8 +343,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_bufs[mfma_reg_buf]
make_tuple(m0, I0, I0, k0, I0, ik))>{}]; [Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf] b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset( [Number<b_thread_desc_.CalculateOffset(
...@@ -401,19 +367,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I ...@@ -401,19 +367,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
}); });
}); });
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
});
HotLoopScheduler(); HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
}; };
...@@ -422,18 +375,32 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I ...@@ -422,18 +375,32 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
LoopFunc(I1, I0); LoopFunc(I1, I0);
i += 2; i += 2;
} while(i < (num_loop - 2)); } while(i < (num_loop - 3));
} }
// tail // tail
if constexpr(TailNum == TailNumber::Even)
{ auto ReadWriteCompFunc = [&](auto mfma_reg, auto local_read_reg) {
block_sync_lds();
b_blockwise_copy.Run(b_grid_desc, b_blockwise_copy.Run(b_grid_desc,
b_grid_buf, b_grid_buf,
b_block_desc_n0_n1_k0_k1, b_block_desc_n0_n1_k0_k1,
b_block_origin_idx, b_block_origin_idx,
b_thread_bufs(I1)); b_thread_bufs(local_read_reg));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf.At(local_read_reg),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_bufs(local_read_reg));
});
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(mfma_reg), mfma_reg);
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
...@@ -443,10 +410,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I ...@@ -443,10 +410,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_bufs[mfma_reg][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}]; make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset( b_thread_bufs[mfma_reg][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}]; make_tuple(n0, I0, k0, ik))>{}];
}); });
...@@ -463,21 +430,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I ...@@ -463,21 +430,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
}); });
}); });
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
};
auto ReadCompFunc = [&](auto mfma_reg, auto local_read_reg) {
block_sync_lds(); block_sync_lds();
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_reg));
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0), make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf.At(I1), a_block_buf.At(local_read_reg),
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0), make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf); a_thread_bufs(local_read_reg));
}); });
}); });
__builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
...@@ -486,10 +462,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I ...@@ -486,10 +462,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_bufs[mfma_reg][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}]; make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset( b_thread_bufs[mfma_reg][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}]; make_tuple(n0, I0, k0, ik))>{}];
}); });
...@@ -505,12 +481,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I ...@@ -505,12 +481,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
}); });
}); });
}); });
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency HotLoopScheduler();
// __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} };
else
{ auto CompFunc = [&](auto mfma_reg) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
...@@ -519,10 +495,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I ...@@ -519,10 +495,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_bufs[mfma_reg][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}]; make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset( b_thread_bufs[mfma_reg][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}]; make_tuple(n0, I0, k0, ik))>{}];
}); });
...@@ -538,6 +514,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I ...@@ -538,6 +514,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
}); });
}); });
}); });
};
if constexpr(TailNum == TailNumber::Even)
{
ReadCompFunc(I0, I1);
CompFunc(I1);
}
else if constexpr(TailNum == TailNumber::Odd)
{
ReadWriteCompFunc(I0, I1);
ReadCompFunc(I1, I0);
CompFunc(I0);
} }
} }
......
...@@ -711,40 +711,19 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -711,40 +711,19 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
// in some cases. // in some cases.
else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value) else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{ {
constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1 constexpr auto a_lds_block_desc =
? 1 make_naive_tensor_descriptor(make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
: 32 * 4 / KPerBlock / sizeof(LDSTypeA); make_tuple(AK1Number, Number<KPerBlock>{}, I1));
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
make_tuple(AK1Number, Number<KPerBlock * MLdsLayer>{}, I1));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc, a_lds_block_desc,
make_tuple(make_xor_with_modulo_transform(make_tuple( make_tuple(make_xor_with_modulo_transform(
Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})), make_tuple(Number<MPerBlock>{}, Number<AK0Number>{})),
make_pass_through_transform(AK1Number)), make_pass_through_transform(AK1Number)),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
make_tuple(Sequence<1, 0>{}, Sequence<2>{})); make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( return a_lds_block_desc_permuted;
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number<MLdsLayer>{})),
make_pass_through_transform(Number<MPerBlock / MLdsLayer>{}),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_ak0_mldslayer_m_ak1,
make_tuple(make_pass_through_transform(AK0Number),
make_merge_transform_v3_division_mod(
make_tuple(Number<MPerBlock / MLdsLayer>{}, Number<MLdsLayer>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_lds_block_desc_ak0_m_ak1;
} }
else // ColumnMajor A else // ColumnMajor A
{ {
...@@ -1223,7 +1202,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1223,7 +1202,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
1, 1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true, true,
BlockwiseGemmPipe::GlobalBufferNum>( 2>(
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0), make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op, a_element_op,
...@@ -1660,7 +1639,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1660,7 +1639,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
1, 1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true, true,
BlockwiseGemmPipe::GlobalBufferNum>( 2>(
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0), make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op, a_element_op,
......
...@@ -60,47 +60,47 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m ...@@ -60,47 +60,47 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
MultiplyMultiply>>>& MultiplyMultiply>>>&
instances); instances);
#if 0 void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2(
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>&
instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_padding_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>&
instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_padding_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>&
#endif instances);
#endif #endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
...@@ -145,8 +145,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_ ...@@ -145,8 +145,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
PassThrough, PassThrough,
MultiplyMultiply>>>& MultiplyMultiply>>>&
instances); instances);
#if 0
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
...@@ -160,7 +160,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_ ...@@ -160,7 +160,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
MultiplyMultiply>>>& MultiplyMultiply>>>&
instances); instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
...@@ -174,7 +174,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_ ...@@ -174,7 +174,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
MultiplyMultiply>>>& MultiplyMultiply>>>&
instances); instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
...@@ -188,7 +188,6 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_ ...@@ -188,7 +188,6 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
MultiplyMultiply>>>& MultiplyMultiply>>>&
instances); instances);
#endif #endif
#endif
template <typename ADataType, template <typename ADataType,
typename BDataType, typename BDataType,
...@@ -240,14 +239,13 @@ struct DeviceOperationInstanceFactory< ...@@ -240,14 +239,13 @@ struct DeviceOperationInstanceFactory<
op_ptrs); op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances( add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances(
op_ptrs); op_ptrs);
#if 0
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_padding_instances( add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2(
op_ptrs); op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_padding_instances( add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2(
op_ptrs); op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_padding_instances( add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2(
op_ptrs); op_ptrs);
#endif
} }
} }
#endif #endif
...@@ -265,14 +263,13 @@ struct DeviceOperationInstanceFactory< ...@@ -265,14 +263,13 @@ struct DeviceOperationInstanceFactory<
op_ptrs); op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances( add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances(
op_ptrs); op_ptrs);
#if 0
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances( add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2(
op_ptrs); op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances( add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2(
op_ptrs); op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances( add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2(
op_ptrs); op_ptrs);
#endif
} }
} }
#endif #endif
......
...@@ -5,30 +5,30 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES ...@@ -5,30 +5,30 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp
# device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instance.cpp device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp
# device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instance.cpp device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp
# device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instance.cpp device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp
# device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_padding_instance.cpp device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp
# device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_padding_instance.cpp device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp
# device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_padding_instance.cpp device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp
) )
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
# set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
# set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
# set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
# set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
# set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
# set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
add_instance_library(device_gemm_multiply_multiply_weight_preshuffle_instance ${GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES}) add_instance_library(device_gemm_multiply_multiply_weight_preshuffle_instance ${GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES})
...@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_ ...@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_instances< device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_instances<
v1,
GemmDefault>{}); GemmDefault>{});
} }
......
...@@ -8,7 +8,7 @@ namespace tensor_operation { ...@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
...@@ -24,7 +24,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_ ...@@ -24,7 +24,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_instances< device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_instances<
GemmKPadding>{}); v2,
GemmDefault>{});
} }
} // namespace instance } // namespace instance
......
...@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_ ...@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_instances< device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_instances<
v1,
GemmDefault>{}); GemmDefault>{});
} }
......
...@@ -8,7 +8,7 @@ namespace tensor_operation { ...@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
...@@ -24,7 +24,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_ ...@@ -24,7 +24,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_instances< device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_instances<
GemmKPadding>{}); v2,
GemmDefault>{});
} }
} // namespace instance } // namespace instance
......
...@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_ ...@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_instances< device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_instances<
v1,
GemmDefault>{}); GemmDefault>{});
} }
......
...@@ -8,7 +8,7 @@ namespace tensor_operation { ...@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
...@@ -24,7 +24,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_ ...@@ -24,7 +24,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_instances< device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_instances<
GemmKPadding>{}); v2,
GemmDefault>{});
} }
} // namespace instance } // namespace instance
......
...@@ -8,7 +8,7 @@ namespace tensor_operation { ...@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
...@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m ...@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_instances< device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_instances<
v2,
GemmDefault>{}); GemmDefault>{});
} }
......
...@@ -8,7 +8,7 @@ namespace tensor_operation { ...@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_padding_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
...@@ -24,7 +24,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m ...@@ -24,7 +24,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_instances< device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_instances<
GemmKPadding>{}); v1,
GemmDefault>{});
} }
} // namespace instance } // namespace instance
......
...@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m ...@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_instances< device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_instances<
v1,
GemmDefault>{}); GemmDefault>{});
} }
......
...@@ -8,7 +8,7 @@ namespace tensor_operation { ...@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_padding_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
...@@ -24,7 +24,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m ...@@ -24,7 +24,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_instances< device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_instances<
GemmKPadding>{}); v2,
GemmDefault>{});
} }
} // namespace instance } // namespace instance
......
...@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m ...@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_instances< device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_instances<
v1,
GemmDefault>{}); GemmDefault>{});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment