Unverified Commit af30d6b6 authored by Haocong WANG's avatar Haocong WANG Committed by GitHub

Merge pull request #1838 from ROCm/cka8w8_uc_newpipe

Cka8w8 uc newpipe
parents b755f375 add0b222
...@@ -227,13 +227,6 @@ if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500500000)
        add_compile_options("SHELL: -mllvm --lsr-drop-solution=1")
    endif()
endif()
if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600140090)
check_cxx_compiler_flag("-mllvm -enable-post-misched=0" HAS_ENABLE_POST_MISCHED)
if(HAS_ENABLE_POST_MISCHED)
message("Adding the enable-post-misched=0 compiler flag")
add_compile_options("SHELL: -mllvm -enable-post-misched=0")
endif()
endif()
set(check-coerce)
check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce)
if(NOT WIN32 AND check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132)
......
...@@ -2,6 +2,6 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_mult
# target_compile_options(example_gemm_multiply_multiply_xdl_fp8 PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp)
-# target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE -save-temps=$PWD -Wno-gnu-line-marker)
+target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker)
add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp)
...@@ -139,27 +139,16 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
// clang-format off
-///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
-///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
-///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C, D0, D1>|
-///###### RCR
-// kernel 1: 256->32x128x128
-// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
-// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
-    < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
-      AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
-      32, 128, 256,
-      16, 16,
-      32, 32,
-      1, 1,
-      S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
-      S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
-      1, 1, S<1, 16, 1, 16>, S<8, 8, 1>,
-      ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>;
-// kernel 2: 128->32x128x128
-// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
+    < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
+      AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
+      256, 256, 128,
+      16, 16,
+      32, 32,
+      4, 4,
+      S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
+      S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
+      1, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
+      ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
// clang-format on
int main(int argc, char* argv[])
...@@ -180,6 +169,9 @@ int main(int argc, char* argv[])
    ck::index_t KBatch = 1;
ck::index_t Warmup = 50;
ck::index_t Repeat = 50;
    if(argc == 1)
    {
        // use default case
...@@ -207,6 +199,26 @@ int main(int argc, char* argv[])
        KBatch = std::stoi(argv[11]);
    }
else if(argc == 14)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
KBatch = std::stoi(argv[11]);
Warmup = std::stoi(argv[12]);
Repeat = std::stoi(argv[13]);
}
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
...@@ -214,6 +226,7 @@ int main(int argc, char* argv[])
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf(
            "arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n");
        printf("arg12 to 13: Warmup, Repeat\n");
        exit(0);
    }
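For reference, a hypothetical full 13-argument invocation of the bpreshuffle example built above (binary name comes from the CMake target; the sizes, strides and iteration counts are illustrative only, not taken from the PR):

./example_gemm_multiply_multiply_xdl_fp8_bpreshuffle 1 2 1 3840 4096 4096 4096 4096 0 4096 1 50 50

i.e. verify the result, init method 2, time the kernel, M=3840, N=4096, K=4096, StrideA=StrideB=StrideE=4096, StrideD=0, KBatch=1, then 50 warmup and 50 timed iterations.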
...@@ -321,7 +334,14 @@ int main(int argc, char* argv[])
                        "not support this GEMM problem");
    }

-   float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 50, true, 50});
+   size_t total_size =
+       (M * K * sizeof(A0DataType) + N * K * sizeof(B0DataType) + M * sizeof(D0DataType) +
+        N * sizeof(D1DataType) + M * N * sizeof(EDataType));
+   int rotate_buf_num =
+       ck::math::min(size_t(Repeat), ck::math::integer_divide_ceil(512 * 1024 * 1024, total_size));
+
+   float ave_time = invoker.Run(
+       argument, StreamConfig{nullptr, time_kernel, 0, Warmup, Repeat, true, rotate_buf_num});

    std::size_t flop = std::size_t(2) * M * N * K;
    std::size_t num_btype =
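As a sanity check on the rotating-buffer arithmetic above, here is a standalone sketch (not code from the PR; std::min and a plain ceiling division stand in for ck::math::min and ck::math::integer_divide_ceil, and the element sizes assume fp8 A/B, fp32 D scales and fp16 E):

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main()
{
    // Illustrative problem: M = N = K = 4096, fp8 A/B (1 byte each), fp32 scales, fp16 E.
    std::size_t M = 4096, N = 4096, K = 4096;
    std::size_t total_size = M * K * 1 + N * K * 1 + M * 4 + N * 4 + M * N * 2;

    std::size_t repeat = 50;
    std::size_t budget = 512ull * 1024 * 1024; // same 512 MiB rotation budget as the code above
    std::size_t rotate_buf_num = std::min(repeat, (budget + total_size - 1) / total_size);

    // total_size is roughly 64 MiB here, so about 8 rotating copies are used rather than 50.
    std::printf("total_size = %zu bytes, rotate_buf_num = %zu\n", total_size, rotate_buf_num);
    return 0;
}

Rotating through several copies of the tensors during timing presumably keeps repeated runs from re-hitting the same cache-resident buffers, which would otherwise inflate the measured bandwidth.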
......
...@@ -5,6 +5,7 @@
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp"

namespace ck {

template <BlockGemmPipelineVersion BlkGemmPipelineVer,
...@@ -76,6 +77,30 @@ constexpr auto BlockGemmBPreshufflePipeline_Selector()
                                                       NRepeat,
                                                       KPack>{};
    }
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3");
return BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
    else
    {
        std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
......
...@@ -144,7 +144,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 1;
-   static constexpr index_t GlobalBufferNum = 1;
+   static constexpr index_t GlobalBufferNum = 2;

    template <typename TileDesc_M0_M1_M2_K>
    __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
...@@ -249,7 +249,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
        constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);

        // Global prefetch A1 B1
-       a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+       a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        b_blockwise_copy.Run(b_grid_desc,
                             b_grid_buf,
                             b_block_desc_n0_n1_k0_k1,
...@@ -258,12 +258,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+       __builtin_amdgcn_sched_barrier(0);

        // // Local prefill A1
-       a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+       a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);

        // // Global prefetch A2
-       a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+       a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

        // Local prefetch A1
...@@ -296,13 +297,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
                                 b_block_desc_n0_n1_k0_k1,
                                 b_block_origin_idx,
                                 b_thread_bufs(local_read_buf));
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

            block_sync_lds();
-           a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
-           a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+           a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
+           a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

            static_for<0, MRepeat, 1>{}([&](auto m0) {
......
...@@ -8,7 +8,7 @@
namespace ck {

// Compute optimized pipeline
-// GlobalPrefetchStages: 2
+// GlobalPrefetchStages: 3
// LocalPreFillStages: 2
// LocalPreFetchStages: 2
// LocalSharedMemoryBuffer: 2
...@@ -142,9 +142,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
    using Base::AMmaKStride;
    using Base::BMmaKStride;

-   static constexpr index_t PrefetchStages  = 2;
-   static constexpr index_t PrefillStages   = 1;
-   static constexpr index_t GlobalBufferNum = 1;
+   static constexpr index_t PrefetchStages  = 3;
+   static constexpr index_t PrefillStages   = 2;
+   static constexpr index_t GlobalBufferNum = 2;

    template <typename TileDesc_M0_M1_M2_K>
    __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
...@@ -183,80 +183,24 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
    __device__ static constexpr auto HotLoopScheduler()
    {
-#if 0
+       // constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr auto num_ds_read_inst_a =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
? HotLoopInstList::A_LDS_Read_Inst_Num
: HotLoopInstList::A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
-       constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
+       // B global + A local
static_for<0, num_buffer_load_inst_b / 2, 1>{}([&](auto i) {
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto num_dsread_a_mfma =
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
// stage 1
constexpr auto num_mfma_stage1 = num_mfma_inst - num_dsread_a_mfma;
constexpr auto num_mfma_per_issue =
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
            ignore = i;
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
});
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_issue, 0); // MFMA
});
// stage 2
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
ds_read_a_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(0x100,
num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
ds_read_a_mfma_rate,
0); // DS read
}
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read B
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read A
        });
#endif
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
-       // B global
-       static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
+       static_for<0, num_buffer_load_inst_b / 2, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
-           __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
+           __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read B
+           __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read A
        });

        // A global
...@@ -269,11 +213,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
        });

        // A local
-       static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
-           ignore = i;
-           __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
-           __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
-       });
+       // static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
+       //     ignore = i;
+       //     __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
+       //     __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
+       // });
    }
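A brief reminder of the scheduling-intrinsic encoding used in this function (background on the AMDGPU builtin, not part of the change itself):

// __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) pins `size` instructions of the
// class selected by `mask` at this point of the instruction schedule. The masks used here are
// the AMDGPU backend's standard values: 0x008 = MFMA/WMMA, 0x020 = VMEM read,
// 0x040 = VMEM write, 0x100 = DS (LDS) read, 0x200 = DS (LDS) write.
// Each iteration of the first static_for above therefore requests the pattern
//   MFMA, buffer_load (B tile), MFMA, ds_read (A tile)
// so global B traffic and LDS A traffic are interleaved between MFMAs instead of batched.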
    template <bool HasMainLoop,
...@@ -311,11 +255,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

+       StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
        StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
        constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);

        // Global prefetch A1, B1
-       a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+       a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        b_blockwise_copy.Run(b_grid_desc,
                             b_grid_buf,
                             b_block_desc_n0_n1_k0_k1,
...@@ -325,11 +270,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

-       // // Local prefill A1
-       a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
+       // Local prefill A1
+       a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0), I0);

-       // // Global prefetch A2
-       a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+       // Global prefetch A2
+       a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

        // Local prefetch A1
...@@ -341,10 +286,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
                                   a_block_buf.At(I0),
                                   a_thread_desc_,
                                   make_tuple(m0, I0, I0, k0, I0, I0),
-                                  a_thread_buf);
+                                  a_thread_bufs(I0));
            });
        });
// Local prefill A2
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1), I1);
// // Global prefetch A3
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        // Initialize C
        c_thread_buf.Clear();
...@@ -357,17 +309,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
            do
            {
                auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
block_sync_lds();
                    b_blockwise_copy.Run(b_grid_desc,
                                         b_grid_buf,
                                         b_block_desc_n0_n1_k0_k1,
                                         b_block_origin_idx,
                                         b_thread_bufs(local_read_buf));
                    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
-                   a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
-                   a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+                   static_for<0, MRepeat, 1>{}([&](auto m0) {
+                       static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_bufs(local_read_buf));
});
});
a_blockwise_copy.RunWrite(
a_block_desc, a_block_buf.At(mfma_reg_buf), mfma_reg_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
                    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

                    static_for<0, MRepeat, 1>{}([&](auto m0) {
...@@ -378,8 +343,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
                                static_for<0, KPack, 1>{}([&](auto ik) {
                                    a_thread_vec.template AsType<ComputeDataType>()(ik) =
-                                       a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                                           make_tuple(m0, I0, I0, k0, I0, ik))>{}];
+                                       a_thread_bufs[mfma_reg_buf]
+                                           [Number<a_thread_desc_.CalculateOffset(
+                                               make_tuple(m0, I0, I0, k0, I0, ik))>{}];
                                    b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        b_thread_bufs[mfma_reg_buf]
                                            [Number<b_thread_desc_.CalculateOffset(
...@@ -401,19 +367,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
                        });
                    });
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
});
                    HotLoopScheduler();
                    __builtin_amdgcn_sched_barrier(0);
                };
...@@ -422,18 +375,32 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
                LoopFunc(I1, I0);

                i += 2;
-           } while(i < (num_loop - 2));
+           } while(i < (num_loop - 3));
        }
        // tail
-       if constexpr(TailNum == TailNumber::Even)
-       {
+       auto ReadWriteCompFunc = [&](auto mfma_reg, auto local_read_reg) {
+           block_sync_lds();

            b_blockwise_copy.Run(b_grid_desc,
                                 b_grid_buf,
                                 b_block_desc_n0_n1_k0_k1,
                                 b_block_origin_idx,
-                                b_thread_bufs(I1));
+                                b_thread_bufs(local_read_reg));
+           b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf.At(local_read_reg),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_bufs(local_read_reg));
});
});
-           a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
+           a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(mfma_reg), mfma_reg);

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
...@@ -443,10 +410,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
-                               a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                               a_thread_bufs[mfma_reg][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, I0, k0, I0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
-                               b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
+                               b_thread_bufs[mfma_reg][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });
...@@ -463,21 +430,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
                });
            });
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
};
auto ReadCompFunc = [&](auto mfma_reg, auto local_read_reg) {
            block_sync_lds();
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_reg));
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
                                       make_tuple(m0, I0, I0, k0, I0, I0),
-                                      a_block_buf.At(I1),
+                                      a_block_buf.At(local_read_reg),
                                       a_thread_desc_,
                                       make_tuple(m0, I0, I0, k0, I0, I0),
-                                      a_thread_buf);
+                                      a_thread_bufs(local_read_reg));
                });
            });
__builtin_amdgcn_sched_barrier(0);
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    static_for<0, KRepeat, 1>{}([&](auto k0) {
...@@ -486,10 +462,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
-                               a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                               a_thread_bufs[mfma_reg][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, I0, k0, I0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
-                               b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
+                               b_thread_bufs[mfma_reg][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });
...@@ -505,12 +481,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
                    });
                });
            });
-           // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
-           // latency
-           // __builtin_amdgcn_sched_barrier(0);
-       }
-       else
-       {
+           HotLoopScheduler();
+           __builtin_amdgcn_sched_barrier(0);
+       };

+       auto CompFunc = [&](auto mfma_reg) {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    static_for<0, KRepeat, 1>{}([&](auto k0) {
...@@ -519,10 +495,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
-                               a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                               a_thread_bufs[mfma_reg][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, I0, k0, I0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
-                               b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
+                               b_thread_bufs[mfma_reg][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });
...@@ -538,6 +514,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
                    });
                });
            });
};
if constexpr(TailNum == TailNumber::Even)
{
ReadCompFunc(I0, I1);
CompFunc(I1);
}
else if constexpr(TailNum == TailNumber::Odd)
{
ReadWriteCompFunc(I0, I1);
ReadCompFunc(I1, I0);
CompFunc(I0);
        }
    }
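To summarize the new tail handling (restating what the code above does; the iteration counts are my reading of it rather than anything stated in the PR):

// Tail dispatch after the main loop, which now stops at i < num_loop - 3:
//   TailNumber::Even -> ReadCompFunc(I0, I1); CompFunc(I1);                              // drains the last two k-blocks
//   TailNumber::Odd  -> ReadWriteCompFunc(I0, I1); ReadCompFunc(I1, I0); CompFunc(I0);   // drains the last three
// ReadWriteCompFunc still prefetches B and writes the next A tile to LDS, ReadCompFunc reads
// B and the final A tile without any LDS write, and CompFunc only issues the remaining MFMAs.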
......
...@@ -296,7 +296,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
                }
            }
        }
-       else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
+       else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
+                         BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
        {
            if(arg.KBatch > 1)
            {
...@@ -351,7 +352,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
        }
        else
        {
-           throw std::runtime_error("todo: only v1 & v2 support now");
+           throw std::runtime_error("todo: only v1 v2 and v3 support now");
        }
    }
#if 0
...@@ -359,6 +360,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
    {
        if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
        {
#if 0
            if(arg.KBatch > 1)
            {
                if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
...@@ -405,8 +407,29 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
                    Run(kernel);
                }
            }
#endif
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
        }
-       else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
+       else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
+                         BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
        {
            if(arg.KBatch > 1)
            {
...@@ -602,7 +625,9 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
            {BlockGemmPipelineScheduler::Interwave, "Interwave"}};

        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
-           {BlockGemmPipelineVersion::v1, "v1"}, {BlockGemmPipelineVersion::v2, "v2"}};
+           {BlockGemmPipelineVersion::v1, "v1"},
+           {BlockGemmPipelineVersion::v2, "v2"},
+           {BlockGemmPipelineVersion::v3, "v3"}};
        // clang-format off
        str << "DeviceGemmXdlUniversal"
......
...@@ -172,7 +172,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
    static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
    static constexpr index_t NLane = NPerXdl;
    static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
static_assert(NWave * warpSize == BlockSize);
    static constexpr auto MakeDsGridPointer()
    {
...@@ -711,40 +710,19 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
        // in some cases.
        else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
        {
-           constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1
-                                          ? 1
-                                          : 32 * 4 / KPerBlock / sizeof(LDSTypeA);
-           constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
-               make_tuple(
-                   AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
-               make_tuple(AK1Number, Number<KPerBlock * MLdsLayer>{}, I1));
+           constexpr auto a_lds_block_desc =
+               make_naive_tensor_descriptor(make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
+                                            make_tuple(AK1Number, Number<KPerBlock>{}, I1));

            constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
                a_lds_block_desc,
-               make_tuple(make_xor_with_modulo_transform(make_tuple(
-                              Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
+               make_tuple(make_xor_with_modulo_transform(
+                              make_tuple(Number<MPerBlock>{}, Number<AK0Number>{})),
                           make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}));

-           constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
+           return a_lds_block_desc_permuted;
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number<MLdsLayer>{})),
make_pass_through_transform(Number<MPerBlock / MLdsLayer>{}),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_ak0_mldslayer_m_ak1,
make_tuple(make_pass_through_transform(AK0Number),
make_merge_transform_v3_division_mod(
make_tuple(Number<MPerBlock / MLdsLayer>{}, Number<MLdsLayer>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_lds_block_desc_ak0_m_ak1;
        }
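For intuition, the simplified descriptor above keeps only an XOR-with-modulo permutation of the (MPerBlock, AK0) coordinates. A generic illustration of the idea behind such an LDS swizzle (illustrative only, not the exact semantics of make_xor_with_modulo_transform):

// Illustrative only: XOR-swizzle the K0 coordinate of an A tile before it is written to LDS,
// so threads of a wave that later read one K0 column hit distinct LDS banks.
// Assumes K0 is a power of two so the result stays in range.
unsigned swizzled_k0(unsigned m, unsigned k0, unsigned K0) { return k0 ^ (m % K0); }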
        else // ColumnMajor A
        {
...@@ -1242,13 +1220,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
            decltype(b_grid_desc_bpreshuffled),
            decltype(b_block_desc_bk0_n_bk1),
            Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
-           Sequence<0, 1, 2, 3>,
+           Sequence<1, 2, 0, 3>,
            3,
            BBlockTransferSrcScalarPerVector,
            BThreadTransferSrcResetCoordinateAfterRun,
            true>(b_grid_desc_bpreshuffled,
                  make_multi_index(n_block_data_idx_on_grid,
-                                  get_warp_local_1d_id(),
+                                  get_warp_local_1d_id() % NWave,
                                   0,
                                   KPack * (get_thread_local_1d_id() % warpSize)));
...@@ -1660,7 +1638,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
            1,
            AThreadTransferSrcResetCoordinateAfterRun,
            true,
-           BlockwiseGemmPipe::GlobalBufferNum>(
+           2>(
            a_grid_desc_ak0_m_ak1,
            make_multi_index(0, m_block_data_idx_on_grid, 0),
            a_element_op,
...@@ -1682,13 +1660,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
            decltype(b_grid_desc_bpreshuffled),
            decltype(b_block_desc_bk0_n_bk1),
            Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
-           Sequence<0, 1, 2, 3>,
+           Sequence<1, 2, 0, 3>,
            3,
            BBlockTransferSrcScalarPerVector,
            BThreadTransferSrcResetCoordinateAfterRun,
            true>(b_grid_desc_bpreshuffled,
                  make_multi_index(n_block_data_idx_on_grid,
-                                  get_warp_local_1d_id(),
+                                  get_warp_local_1d_id() % NWave,
                                   0,
                                   KPack * (get_thread_local_1d_id() % warpSize)));
......
...@@ -60,47 +60,61 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
                                                                    MultiplyMultiply>>>&
    instances);

-#if 0
-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_padding_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
-                                                                   MultiplyMultiply>>>& instances);
+                                                                   MultiplyMultiply>>>&
instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>&
instances);
-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_padding_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
-                                                                   MultiplyMultiply>>>& instances);
+                                                                   MultiplyMultiply>>>&
instances);
-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_padding_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
-                                                                   MultiplyMultiply>>>& instances);
-#endif
+                                                                   MultiplyMultiply>>>&
+   instances);
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
...@@ -145,8 +159,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
                                                                    PassThrough,
                                                                    MultiplyMultiply>>>&
    instances);
-#if 0
-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
...@@ -160,7 +174,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
MultiplyMultiply>>>& MultiplyMultiply>>>&
instances); instances);
-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
...@@ -174,7 +188,21 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
MultiplyMultiply>>>& MultiplyMultiply>>>&
instances); instances);
-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>&
instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
...@@ -187,7 +215,6 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
PassThrough, PassThrough,
MultiplyMultiply>>>& MultiplyMultiply>>>&
    instances);
#endif
#endif #endif
template <typename ADataType, template <typename ADataType,
...@@ -240,14 +267,16 @@ struct DeviceOperationInstanceFactory<
op_ptrs); op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances( add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances(
op_ptrs); op_ptrs);
-#if 0
-           add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_padding_instances(
+           add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2(
                op_ptrs);
-           add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_padding_instances(
+           add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2(
                op_ptrs);
-           add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_padding_instances(
+           add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2(
+               op_ptrs);
+           add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances(
                op_ptrs);
-#endif
} }
} }
#endif #endif
...@@ -265,14 +294,16 @@ struct DeviceOperationInstanceFactory<
op_ptrs); op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances( add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances(
op_ptrs); op_ptrs);
-#if 0
-           add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances(
+           add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2(
                op_ptrs);
-           add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances(
+           add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2(
                op_ptrs);
-           add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances(
+           add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2(
+               op_ptrs);
+           add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances(
                op_ptrs);
-#endif
} }
} }
#endif #endif
......
...@@ -5,30 +5,34 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES
  device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp
  device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp
  device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp
-  # device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instance.cpp
-  # device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instance.cpp
-  # device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instance.cpp
+  device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp
+  device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp
+  device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance.cpp
  device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp
  device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp
  device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp
-  # device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_padding_instance.cpp
-  # device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_padding_instance.cpp
-  # device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_padding_instance.cpp
+  device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp
+  device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp
+  device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance.cpp
)
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
-# set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
-# set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
-# set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
-# set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
-# set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
-# set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
add_instance_library(device_gemm_multiply_multiply_weight_preshuffle_instance ${GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_instances<
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
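For context on the new compute_default translation unit above: like the existing p1/p2/p3 files, it appends one concrete device-op object per entry of a compile-time instance list to the caller's vector of type-erased pointers. Below is a minimal, self-contained sketch of that registration pattern under simplified assumptions; BaseOp, OpA, OpB and add_operation_instances are hypothetical stand-ins, not CK's actual types (the real code registers DeviceGemmMultipleDSplitKBPreShuffle instances via add_device_operation_instances).

// Illustrative sketch only: simplified analogue of the instance-registration
// pattern used by the compute_default translation unit above. BaseOp, OpA,
// OpB and add_operation_instances are hypothetical stand-ins.
#include <iostream>
#include <memory>
#include <tuple>
#include <vector>

struct BaseOp
{
    virtual ~BaseOp() = default;
    virtual const char* GetTypeString() const = 0;
};

struct OpA : BaseOp { const char* GetTypeString() const override { return "OpA"; } };
struct OpB : BaseOp { const char* GetTypeString() const override { return "OpB"; } };

// Append one heap-allocated object per concrete type in the tuple to the
// caller-provided vector of type-erased pointers (C++17 fold expression).
template <typename... Ops>
void add_operation_instances(std::vector<std::unique_ptr<BaseOp>>& instances, std::tuple<Ops...>)
{
    (instances.push_back(std::make_unique<Ops>()), ...);
}

int main()
{
    std::vector<std::unique_ptr<BaseOp>> instances;
    add_operation_instances(instances, std::tuple<OpA, OpB>{}); // like passing *_compute_instances<GemmDefault>{}
    for(const auto& op : instances)
        std::cout << op->GetTypeString() << '\n';
    return 0;
}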
@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_instances<
+v1,
GemmDefault>{});
}
...
@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col,
Tuple<Row, Col>,
@@ -24,7 +24,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_instances<
-GemmKPadding>{});
+v2,
+GemmDefault>{});
}
} // namespace instance
...
@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_instances<
+v1,
GemmDefault>{});
}
...
@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col,
Tuple<Row, Col>,
@@ -24,7 +24,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_instances<
-GemmKPadding>{});
+v2,
+GemmDefault>{});
}
} // namespace instance
...
@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_instances<
+v1,
GemmDefault>{});
}
...
@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col,
Tuple<Row, Col>,
@@ -24,7 +24,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_instances<
-GemmKPadding>{});
+v2,
+GemmDefault>{});
}
} // namespace instance
...
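Taken together, these hunks repurpose the former *_pN_padding_instances entry points as *_pN_default_instances_v2 and add a leading version-tag template argument to the *_pN_instances lists (v1 in the original default files, v2 in the renamed _v2 files) ahead of the GemmSpecialization argument, with the _v2 files now using GemmDefault instead of GemmKPadding. Below is a minimal sketch of how such a compile-time version tag can select between two instance variants; PipelineVersion, GemmSpec, InstanceDesc, make_instances and the tile names are illustrative stand-ins, not the actual CK template parameters.

// Illustrative sketch only: a compile-time version tag choosing between two
// instance variants, mirroring the v1/v2 argument added in this change.
// PipelineVersion, GemmSpec, InstanceDesc and make_instances are hypothetical.
#include <iostream>
#include <string>
#include <vector>

enum class PipelineVersion { v1, v2 };
enum class GemmSpec { Default, KPadding };

struct InstanceDesc { std::string name; };

// The version tag is a template parameter, so each (version, spec) combination
// instantiates a distinct instance list at compile time, much like the
// *_pN_instances<...> lists in the diff above.
template <PipelineVersion Version, GemmSpec Spec>
std::vector<InstanceDesc> make_instances()
{
    const std::string tag  = (Version == PipelineVersion::v1) ? "variant_v1" : "variant_v2";
    const std::string spec = (Spec == GemmSpec::Default) ? "default" : "kpadding";
    return {{tag + "_" + spec + "_tile_256x128"}, {tag + "_" + spec + "_tile_128x64"}};
}

int main()
{
    // The original entry points now register v1 instances; the renamed *_v2
    // entry points register v2 instances with the same Default specialization.
    for(const auto& inst : make_instances<PipelineVersion::v1, GemmSpec::Default>())
        std::cout << inst.name << '\n';
    for(const auto& inst : make_instances<PipelineVersion::v2, GemmSpec::Default>())
        std::cout << inst.name << '\n';
    return 0;
}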