gaoqiong / composable_kernel_ROCM · Commits

Commit cee23c47, authored Jan 17, 2025 by aska-0096
Commit message: tempsave
Parent: 487a05d6

Changes: 21 files, +311 −299 (page 1 of 2)
Showing 20 changed files with 311 additions and 299 deletions:

  example/65_gemm_multiply_multiply/CMakeLists.txt  (+1, -1)
  example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp  (+37, -6)
  include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp  (+6, -6)
  include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp  (+107, -119)
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp  (+8, -29)
  library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle.hpp  (+48, -51)
  library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/CMakeLists.txt  (+12, -12)
  library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp  (+35, -32)
  library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp  (+1, -0)
  library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp  (+3, -2)
  library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp  (+1, -0)
  library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp  (+3, -2)
  library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp  (+1, -0)
  library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp  (+3, -2)
  library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp  (+35, -32)
  library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp  (+2, -1)
  library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp  (+3, -2)
  library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp  (+1, -0)
  library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp  (+3, -2)
  library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp  (+1, -0)
example/65_gemm_multiply_multiply/CMakeLists.txt   View file @ cee23c47

@@ -2,6 +2,6 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_mult
 # target_compile_options(example_gemm_multiply_multiply_xdl_fp8 PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker)
 add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp)
 add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp)
-# target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE -save-temps=$PWD -Wno-gnu-line-marker)
+target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE -save-temps=$PWD -Wno-gnu-line-marker)
 add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
 add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp)
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp   View file @ cee23c47
@@ -149,14 +149,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
-    < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>;
+    < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 512, 128, 16, 16, 32, 32, 1, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v2, FP8>;
 // kernel 2: 128->32x128x128
 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
@@ -180,6 +180,9 @@ int main(int argc, char* argv[])
     ck::index_t KBatch = 1;
+    ck::index_t Warmup = 50;
+    ck::index_t Repeat = 50;

     if(argc == 1)
     {
         // use default case
@@ -207,6 +210,26 @@ int main(int argc, char* argv[])
         KBatch = std::stoi(argv[11]);
     }
+    else if(argc == 14)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+
+        M = std::stoi(argv[4]);
+        N = std::stoi(argv[5]);
+        K = std::stoi(argv[6]);
+
+        StrideA = std::stoi(argv[7]);
+        StrideB = std::stoi(argv[8]);
+        StrideD = std::stoi(argv[9]);
+        StrideE = std::stoi(argv[10]);
+
+        KBatch = std::stoi(argv[11]);
+        Warmup = std::stoi(argv[12]);
+        Repeat = std::stoi(argv[13]);
+    }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");

@@ -214,6 +237,7 @@ int main(int argc, char* argv[])
         printf("arg3: time kernel (0=no, 1=yes)\n");
         printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n");
+        printf("arg10 to 11: Warmup, Repeat\n");
         exit(0);
     }
@@ -321,7 +345,14 @@ int main(int argc, char* argv[])
         "not support this GEMM problem");
     }

-    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 50, true, 50});
+    size_t total_size = (M * K * sizeof(A0DataType) + N * K * sizeof(B0DataType) +
+                         M * sizeof(D0DataType) + N * sizeof(D1DataType) + M * N * sizeof(EDataType));
+    int rotate_buf_num =
+        ck::math::min(size_t(Repeat), ck::math::integer_divide_ceil(512 * 1024 * 1024, total_size));
+
+    float ave_time = invoker.Run(
+        argument, StreamConfig{nullptr, time_kernel, 0, Warmup, Repeat, true, rotate_buf_num});

     std::size_t flop = std::size_t(2) * M * N * K;
     std::size_t num_btype =
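The rotating-buffer count used for timing above is derived from the problem footprint: the benchmark cycles through roughly 512 MiB worth of operand copies, capped at the number of timed repeats, so back-to-back runs are not served from cache. A minimal host-side sketch of the same arithmetic, with an illustrative problem shape and element sizes that are assumptions here (fp8 A/B, fp32 scales, bf16 output), not values taken from the commit:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main()
{
    // Illustrative problem shape; element widths assume fp8 A/B, one fp32 scale per
    // row and per column (D0/D1), and a bf16 output E.
    std::size_t M = 3840, N = 4096, K = 4096, Repeat = 50;

    std::size_t total_size = M * K * 1    // A0 (fp8)
                             + N * K * 1  // B0 (fp8)
                             + M * 4      // D0 (fp32)
                             + N * 4      // D1 (fp32)
                             + M * N * 2; // E (bf16)

    // Same formula as the example: ceil(512 MiB / footprint), capped by Repeat.
    std::size_t bytes_512mib   = 512ull * 1024 * 1024;
    std::size_t rotate_buf_num = std::min(Repeat, (bytes_512mib + total_size - 1) / total_size);

    std::printf("footprint = %zu bytes, rotating copies = %zu\n", total_size, rotate_buf_num);
    return 0;
}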
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp   View file @ cee23c47
@@ -249,7 +249,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
         constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);

         // Global prefetch A1 B1
-        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
         b_blockwise_copy.Run(b_grid_desc,
                              b_grid_buf,
                              b_block_desc_n0_n1_k0_k1,
@@ -258,12 +258,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

+        __builtin_amdgcn_sched_barrier(0);
         // // Local prefill A1
-        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);

         // // Global prefetch A2
-        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

         // Local prefetch A1
@@ -296,13 +297,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
                                      b_block_desc_n0_n1_k0_k1,
                                      b_block_origin_idx,
                                      b_thread_bufs(local_read_buf));
                 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                 block_sync_lds();
-                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
-                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
+                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

                 static_for<0, MRepeat, 1>{}([&](auto m0) {
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp   View file @ cee23c47

@@ -8,7 +8,7 @@
 namespace ck
 {

 // Compute optimized pipeline
-// GlobalPrefetchStages: 2
+// GlobalPrefetchStages: 3
 // LocalPreFillStages: 2
 // LocalPreFetchStages: 2
 // LocalSharedMemoryBuffer: 2
@@ -142,9 +142,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
     using Base::AMmaKStride;
     using Base::BMmaKStride;

-    static constexpr index_t PrefetchStages  = 2;
-    static constexpr index_t PrefillStages   = 1;
-    static constexpr index_t GlobalBufferNum = 1;
+    static constexpr index_t PrefetchStages  = 3;
+    static constexpr index_t PrefillStages   = 2;
+    static constexpr index_t GlobalBufferNum = 2;

     template <typename TileDesc_M0_M1_M2_K>
     __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
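The bumped constants describe a deeper software pipeline: three global-prefetch stages in flight, two LDS prefill stages before the hot loop, and two staging buffers for the global loads, matching the doubled LDS buffer the loop body ping-pongs between. A schematic host-side sketch of the stage bookkeeping this implies (illustrative only, not the CK pipeline code):

#include <cstdio>

// Toy model of a hot loop with PrefetchStages = 3, PrefillStages = 2, GlobalBufferNum = 2:
// two buffers alternate for both the global-load staging registers and the LDS tiles.
int main()
{
    const int num_loop = 8; // assumed K-block count, for illustration

    // Prologue: three global loads issued, the first two written into the two LDS buffers.
    std::printf("prologue: global load A1/A2/A3, LDS prefill A1->lds0, A2->lds1\n");

    int mfma_buf = 0; // LDS/register buffer the MFMAs consume this iteration
    for(int i = 0; i < num_loop - 3; ++i)
    {
        int fill_buf = mfma_buf ^ 1;
        std::printf("iter %d: mfma from lds%d | refill lds%d from prefetched registers | "
                    "issue next global load into reg%d\n",
                    i, mfma_buf, fill_buf, fill_buf);
        mfma_buf ^= 1;
    }
    std::printf("tail: drain the three stages still in flight without new global loads\n");
    return 0;
}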
@@ -183,80 +183,24 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
     __device__ static constexpr auto HotLoopScheduler()
     {
#if 0
        // A/B split schedule
        // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
                ? HotLoopInstList::A_LDS_Read_Inst_Num
                : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
        constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
        // constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
        constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;

        constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
        constexpr auto ds_read_a_issue_cycle =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
        constexpr auto ds_read_a_mfma_rate =
            (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
        constexpr auto num_dsread_a_mfma =
            (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;

        // stage 1
        constexpr auto num_mfma_stage1 = num_mfma_inst - num_dsread_a_mfma;
        constexpr auto num_mfma_per_issue =
            num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
        constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;

        static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
            // B global + A local
            static_for<0, num_buffer_load_inst_b / 2, 1>{}([&](auto i) {
                ignore = i;
                static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
                    ignore = idswrite;
                    __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                });
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
                __builtin_amdgcn_sched_group_barrier(
                    0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
            });
            static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
                ignore = i;
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
                __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_issue, 0); // MFMA
            });
            // stage 2
            static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
                if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
                             ds_read_a_mfma_rate)
                {
                    __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
                }
                else
                {
                    __builtin_amdgcn_sched_group_barrier(0x100,
                                                         num_ds_read_inst_a -
                                                             (num_dsread_a_mfma - 1) *
                                                                 ds_read_a_mfma_rate,
                                                         0); // DS read
                }
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read B
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read A
            });
#endif
        constexpr auto num_ds_read_inst_a     = HotLoopInstList::A_LDS_Read_Inst_Num;
        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;

        // B global
-        static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
+        static_for<0, num_buffer_load_inst_b / 2, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read B
            __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read A
        });
        // A global
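Each __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) call pins `size` instructions of the class selected by `mask` at that point in the schedule; the masks used in this file are 0x008 (MFMA), 0x020 (VMEM read), 0x100 (DS read) and 0x200 (DS write), as the trailing comments state. A small host-side model of the per-iteration pattern the "B global" loop requests (the iteration count depends on the tile configuration and is an assumption here):

#include <cstdio>

// Host-side model of the sched_group_barrier sequence requested per iteration of the
// "B global" loop: one MFMA, two VMEM reads for the B tile, one DS read for the A tile.
struct SchedGroup { unsigned mask; int count; const char* what; };

int main()
{
    const SchedGroup per_b_load_pair[] = {
        {0x008, 1, "MFMA"},
        {0x020, 1, "VMEM read"},
        {0x020, 1, "VMEM read B"},
        {0x100, 1, "DS read A"},
    };
    // num_buffer_load_inst_b / 2 iterations of the pattern are requested; 4 is an
    // illustrative assumption, the real value comes from the tile configuration.
    for(int i = 0; i < 4; ++i)
        for(const auto& g : per_b_load_pair)
            std::printf("sched_group_barrier(0x%03x, %d, 0)  // %s\n", g.mask, g.count, g.what);
    return 0;
}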
@@ -269,11 +213,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
         });

         // A local
-        static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
-            ignore = i;
-            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
-            __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
-        });
+        // static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
+        //     ignore = i;
+        //     __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
+        //     __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
+        // });
     }

     template <bool HasMainLoop,

@@ -311,11 +255,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
             b_thread_desc_.GetElementSpaceSize());

+        StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
         StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;

         constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);

         // Global prefetch A1, B1
-        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
         b_blockwise_copy.Run(b_grid_desc,
                              b_grid_buf,
                              b_block_desc_n0_n1_k0_k1,
@@ -325,11 +270,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

-        // // Local prefill A1
-        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
+        // Local prefill A1
+        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0), I0);

-        // // Global prefetch A2
-        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+        // Global prefetch A2
+        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

         // Local prefetch A1
@@ -341,10 +286,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
                                    a_block_buf.At(I0),
                                    a_thread_desc_,
                                    make_tuple(m0, I0, I0, k0, I0, I0),
-                                   a_thread_buf);
+                                   a_thread_bufs(I0));
             });
         });

+        // Local prefill A2
+        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1), I1);
+
+        // // Global prefetch A3
+        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
+        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+
         // Initialize C
         c_thread_buf.Clear();
@@ -357,17 +309,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
         do
         {
             auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
                 block_sync_lds();

                 b_blockwise_copy.Run(b_grid_desc,
                                      b_grid_buf,
                                      b_block_desc_n0_n1_k0_k1,
                                      b_block_origin_idx,
                                      b_thread_bufs(local_read_buf));
                 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

-                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
-                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+                static_for<0, MRepeat, 1>{}([&](auto m0) {
+                    static_for<0, KRepeat, 1>{}([&](auto k0) {
+                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
+                                           make_tuple(m0, I0, I0, k0, I0, I0),
+                                           a_block_buf.At(local_read_buf),
+                                           a_thread_desc_,
+                                           make_tuple(m0, I0, I0, k0, I0, I0),
+                                           a_thread_bufs(local_read_buf));
+                    });
+                });
+
+                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(mfma_reg_buf), mfma_reg_buf);
+                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

                 static_for<0, MRepeat, 1>{}([&](auto m0) {

@@ -378,8 +343,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
                         static_for<0, KPack, 1>{}([&](auto ik) {
-                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
-                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                                    make_tuple(m0, I0, I0, k0, I0, ik))>{}];
+                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
+                                a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
+                                    make_tuple(m0, I0, I0, k0, I0, ik))>{}];
                             b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                 b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(

@@ -401,19 +367,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
                     });
                 });

-                block_sync_lds();
-                static_for<0, MRepeat, 1>{}([&](auto m0) {
-                    static_for<0, KRepeat, 1>{}([&](auto k0) {
-                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
-                                           make_tuple(m0, I0, I0, k0, I0, I0),
-                                           a_block_buf.At(local_read_buf),
-                                           a_thread_desc_,
-                                           make_tuple(m0, I0, I0, k0, I0, I0),
-                                           a_thread_buf);
-                    });
-                });
                 HotLoopScheduler();
                 __builtin_amdgcn_sched_barrier(0);
             };
@@ -422,18 +375,32 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
             LoopFunc(I1, I0);

             i += 2;
-        } while(i < (num_loop - 2));
+        } while(i < (num_loop - 3));
         }

         // tail
         if constexpr(TailNum == TailNumber::Even)
         {
             auto ReadWriteCompFunc = [&](auto mfma_reg, auto local_read_reg) {
                 block_sync_lds();

                 b_blockwise_copy.Run(b_grid_desc,
                                      b_grid_buf,
                                      b_block_desc_n0_n1_k0_k1,
                                      b_block_origin_idx,
-                                     b_thread_bufs(I1));
+                                     b_thread_bufs(local_read_reg));
                 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

+                static_for<0, MRepeat, 1>{}([&](auto m0) {
+                    static_for<0, KRepeat, 1>{}([&](auto k0) {
+                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
+                                           make_tuple(m0, I0, I0, k0, I0, I0),
+                                           a_block_buf.At(local_read_reg),
+                                           a_thread_desc_,
+                                           make_tuple(m0, I0, I0, k0, I0, I0),
+                                           a_thread_bufs(local_read_reg));
+                    });
+                });
+
-                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
+                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(mfma_reg), mfma_reg);

                 static_for<0, MRepeat, 1>{}([&](auto m0) {
                     static_for<0, NRepeat, 1>{}([&](auto n0) {

@@ -443,10 +410,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
                         static_for<0, KPack, 1>{}([&](auto ik) {
-                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
-                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
+                                a_thread_bufs[mfma_reg][Number<a_thread_desc_.CalculateOffset(
                                     make_tuple(m0, I0, I0, k0, I0, ik))>{}];
-                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
-                                b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
+                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
+                                b_thread_bufs[mfma_reg][Number<b_thread_desc_.CalculateOffset(
                                     make_tuple(n0, I0, k0, ik))>{}];
                         });
@@ -463,21 +430,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
                     });
                 });

                 HotLoopScheduler();
                 __builtin_amdgcn_sched_barrier(0);
             };

             auto ReadCompFunc = [&](auto mfma_reg, auto local_read_reg) {
                 block_sync_lds();

                 b_blockwise_copy.Run(b_grid_desc,
                                      b_grid_buf,
                                      b_block_desc_n0_n1_k0_k1,
                                      b_block_origin_idx,
                                      b_thread_bufs(local_read_reg));

                 static_for<0, MRepeat, 1>{}([&](auto m0) {
                     static_for<0, KRepeat, 1>{}([&](auto k0) {
                         a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
                                            make_tuple(m0, I0, I0, k0, I0, I0),
-                                           a_block_buf.At(I1),
+                                           a_block_buf.At(local_read_reg),
                                            a_thread_desc_,
                                            make_tuple(m0, I0, I0, k0, I0, I0),
-                                           a_thread_buf);
+                                           a_thread_bufs(local_read_reg));
                     });
                 });
                 __builtin_amdgcn_sched_barrier(0);

                 static_for<0, MRepeat, 1>{}([&](auto m0) {
                     static_for<0, NRepeat, 1>{}([&](auto n0) {
                         static_for<0, KRepeat, 1>{}([&](auto k0) {

@@ -486,10 +462,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
                             static_for<0, KPack, 1>{}([&](auto ik) {
-                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
-                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
+                                    a_thread_bufs[mfma_reg][Number<a_thread_desc_.CalculateOffset(
                                         make_tuple(m0, I0, I0, k0, I0, ik))>{}];
-                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
-                                    b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
+                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
+                                    b_thread_bufs[mfma_reg][Number<b_thread_desc_.CalculateOffset(
                                         make_tuple(n0, I0, k0, ik))>{}];
                             });
@@ -505,12 +481,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
                         });
                     });
                 });

-                // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
-                // latency
-                // __builtin_amdgcn_sched_barrier(0);
-            }
-            else
-            {
-                HotLoopScheduler();
-                __builtin_amdgcn_sched_barrier(0);
-            }
+            };

             auto CompFunc = [&](auto mfma_reg) {
                 static_for<0, MRepeat, 1>{}([&](auto m0) {
                     static_for<0, NRepeat, 1>{}([&](auto n0) {
                         static_for<0, KRepeat, 1>{}([&](auto k0) {

@@ -519,10 +495,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
                             static_for<0, KPack, 1>{}([&](auto ik) {
-                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
-                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
+                                    a_thread_bufs[mfma_reg][Number<a_thread_desc_.CalculateOffset(
                                         make_tuple(m0, I0, I0, k0, I0, ik))>{}];
-                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
-                                    b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
+                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
+                                    b_thread_bufs[mfma_reg][Number<b_thread_desc_.CalculateOffset(
                                         make_tuple(n0, I0, k0, ik))>{}];
                             });
@@ -538,6 +514,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
                         });
                     });
                 });
             };

+            if constexpr(TailNum == TailNumber::Even)
+            {
+                ReadCompFunc(I0, I1);
+                CompFunc(I1);
+            }
+            else if constexpr(TailNum == TailNumber::Odd)
+            {
+                ReadWriteCompFunc(I0, I1);
+                ReadCompFunc(I1, I0);
+                CompFunc(I0);
+            }
         }
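With three prefetch stages in flight, the epilogue has to drain either two or three buffered K-blocks depending on whether the main loop left an even or an odd tail, which is why the dispatch above alternates the buffer indices differently for the two cases. A toy model of that dispatch (buffer indices only, no real work; the helper names mirror the lambdas in the diff):

#include <cstdio>

enum class TailNumber { Even, Odd };

// Models the drain order: even tails need two steps, odd tails three, each step
// consuming one buffer while the other is (re)filled or read.
void drain(TailNumber tail)
{
    auto read_write_comp = [](int mfma, int read) { std::printf("ReadWriteCompFunc(buf %d, buf %d)\n", mfma, read); };
    auto read_comp       = [](int mfma, int read) { std::printf("ReadCompFunc(buf %d, buf %d)\n", mfma, read); };
    auto comp            = [](int mfma)           { std::printf("CompFunc(buf %d)\n", mfma); };

    if(tail == TailNumber::Even) { read_comp(0, 1); comp(1); }
    else                         { read_write_comp(0, 1); read_comp(1, 0); comp(0); }
}

int main()
{
    std::printf("even tail:\n"); drain(TailNumber::Even);
    std::printf("odd tail:\n");  drain(TailNumber::Odd);
    return 0;
}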
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp   View file @ cee23c47
@@ -711,40 +711,19 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
         // in some cases.
         else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
         {
-            constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1
-                                           ? 1
-                                           : 32 * 4 / KPerBlock / sizeof(LDSTypeA);
-            constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
-                make_tuple(AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
-                make_tuple(AK1Number, Number<KPerBlock * MLdsLayer>{}, I1));
+            constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
+                make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
+                make_tuple(AK1Number, Number<KPerBlock>{}, I1));

             constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
                 a_lds_block_desc,
-                make_tuple(make_xor_with_modulo_transform(
-                               make_tuple(Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
+                make_tuple(make_xor_with_modulo_transform(
+                               make_tuple(Number<MPerBlock>{}, Number<AK0Number>{})),
                            make_pass_through_transform(AK1Number)),
                 make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
                 make_tuple(Sequence<1, 0>{}, Sequence<2>{}));

-            constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
-                a_lds_block_desc_permuted,
-                make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number<MLdsLayer>{})),
-                           make_pass_through_transform(Number<MPerBlock / MLdsLayer>{}),
-                           make_pass_through_transform(AK1Number)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));
-
-            constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
-                a_lds_block_desc_ak0_mldslayer_m_ak1,
-                make_tuple(make_pass_through_transform(AK0Number),
-                           make_merge_transform_v3_division_mod(
-                               make_tuple(Number<MPerBlock / MLdsLayer>{}, Number<MLdsLayer>{})),
-                           make_pass_through_transform(AK1Number)),
-                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
-
-            return a_lds_block_desc_ak0_m_ak1;
+            return a_lds_block_desc_permuted;
         }
         else // ColumnMajor A
         {
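Both the removed and the kept branch place the row-major A tile in LDS through make_xor_with_modulo_transform, which swizzles the M row an element lands in as a function of its K0 index so the reads of one K-slice spread across LDS banks; the new branch drops the MLdsLayer widening factor and returns the permuted descriptor directly. A host-side sketch of the xor-with-modulo idea (an approximation for illustration, not the CK descriptor math):

#include <cstdio>

int main()
{
    const int MPerBlock = 8, AK0 = 8; // illustrative tile sizes

    // Naive placement: row m of K-slice k0 stays in LDS row m (same bank pattern for every k0).
    // XOR-with-modulo placement: row m moves to roughly m ^ (k0 % MPerBlock), so consecutive
    // K-slices hit different banks.
    for(int k0 = 0; k0 < AK0; ++k0)
        for(int m = 0; m < MPerBlock; ++m)
            std::printf("k0=%d m=%d -> naive row %d, swizzled row %d\n",
                        k0, m, m, m ^ (k0 % MPerBlock));
    return 0;
}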
@@ -1223,7 +1202,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
             1,
             AThreadTransferSrcResetCoordinateAfterRun,
             true,
-            BlockwiseGemmPipe::GlobalBufferNum>(
+            2>(
             a_grid_desc_ak0_m_ak1,
             make_multi_index(0, m_block_data_idx_on_grid, 0),
             a_element_op,

@@ -1660,7 +1639,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
             1,
             AThreadTransferSrcResetCoordinateAfterRun,
             true,
-            BlockwiseGemmPipe::GlobalBufferNum>(
+            2>(
             a_grid_desc_ak0_m_ak1,
             make_multi_index(0, m_block_data_idx_on_grid, 0),
             a_element_op,
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle.hpp   View file @ cee23c47
@@ -60,47 +60,47 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
     MultiplyMultiply>>>& instances);

-#if 0
-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_padding_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2(
     std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
                                                                      Col,
                                                                      Tuple<Row, Col>,
                                                                      Row,
                                                                      F8,
                                                                      F8,
                                                                      Tuple<F32, F32>,
                                                                      F16,
                                                                      PassThrough,
                                                                      PassThrough,
                                                                      MultiplyMultiply>>>& instances);

-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_padding_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2(
     std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
                                                                      Col,
                                                                      Tuple<Row, Col>,
                                                                      Row,
                                                                      F8,
                                                                      F8,
                                                                      Tuple<F32, F32>,
                                                                      F16,
                                                                      PassThrough,
                                                                      PassThrough,
                                                                      MultiplyMultiply>>>& instances);

-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_padding_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2(
     std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
                                                                      Col,
                                                                      Tuple<Row, Col>,
                                                                      Row,
                                                                      F8,
                                                                      F8,
                                                                      Tuple<F32, F32>,
                                                                      F16,
                                                                      PassThrough,
                                                                      PassThrough,
                                                                      MultiplyMultiply>>>& instances);
-#endif
 #endif

 #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
@@ -145,8 +145,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
     PassThrough,
     MultiplyMultiply>>>& instances);

-#if 0
-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2(
     std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
                                                                      Col,
                                                                      Tuple<Row, Col>,

@@ -160,7 +160,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
     MultiplyMultiply>>>& instances);

-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2(
     std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
                                                                      Col,
                                                                      Tuple<Row, Col>,

@@ -174,7 +174,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
     MultiplyMultiply>>>& instances);

-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2(
     std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
                                                                      Col,
                                                                      Tuple<Row, Col>,

@@ -188,7 +188,6 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
     MultiplyMultiply>>>& instances);
-#endif
 #endif

 template <typename ADataType,
           typename BDataType,
@@ -240,14 +239,13 @@ struct DeviceOperationInstanceFactory<
                 op_ptrs);
             add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances(
                 op_ptrs);
-#if 0
-            add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_padding_instances(
+            add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2(
                 op_ptrs);
-            add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_padding_instances(
+            add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2(
                 op_ptrs);
-            add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_padding_instances(
+            add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2(
                 op_ptrs);
-#endif
         }
     }
 #endif

@@ -265,14 +263,13 @@ struct DeviceOperationInstanceFactory<
                 op_ptrs);
             add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances(
                 op_ptrs);
-#if 0
-            add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances(
+            add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2(
                 op_ptrs);
-            add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances(
+            add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2(
                 op_ptrs);
-            add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances(
+            add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2(
                 op_ptrs);
-#endif
         }
     }
 #endif
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/CMakeLists.txt   View file @ cee23c47
@@ -5,30 +5,30 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES
   device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp
   device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp
   device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp
-  # device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instance.cpp
-  # device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instance.cpp
-  # device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instance.cpp
+  device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp
+  device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp
+  device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp
   device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp
   device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp
   device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp
-  # device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_padding_instance.cpp
-  # device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_padding_instance.cpp
-  # device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_padding_instance.cpp
+  device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp
+  device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp
+  device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp
 )

 set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
 set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
 set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
-# set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
-# set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
-# set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
 set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
 set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
 set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
-# set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
-# set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
-# set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
+set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")

 add_instance_library(device_gemm_multiply_multiply_weight_preshuffle_instance ${GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES})
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp   View file @ cee23c47
This diff is collapsed.
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp   View file @ cee23c47

@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
     add_device_operation_instances(
         instances,
         device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_instances<
             v1,
             GemmDefault>{});
 }
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instance.cpp → library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp   View file @ cee23c47

@@ -8,7 +8,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {

-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2(
     std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
                                                                      Col,
                                                                      Tuple<Row, Col>,

@@ -24,7 +24,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
     add_device_operation_instances(
         instances,
-        device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_instances<GemmKPadding>{});
+        device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_instances<
+            v2,
+            GemmDefault>{});
 }
 } // namespace instance
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp   View file @ cee23c47

@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
     add_device_operation_instances(
         instances,
         device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_instances<
             v1,
             GemmDefault>{});
 }
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instance.cpp → library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp   View file @ cee23c47

@@ -8,7 +8,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {

-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2(
     std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
                                                                      Col,
                                                                      Tuple<Row, Col>,

@@ -24,7 +24,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
     add_device_operation_instances(
         instances,
-        device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_instances<GemmKPadding>{});
+        device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_instances<
+            v2,
+            GemmDefault>{});
 }
 } // namespace instance
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp   View file @ cee23c47

@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
     add_device_operation_instances(
         instances,
         device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_instances<
             v1,
             GemmDefault>{});
 }
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instance.cpp → library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp   View file @ cee23c47

@@ -8,7 +8,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {

-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2(
     std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
                                                                      Col,
                                                                      Tuple<Row, Col>,

@@ -24,7 +24,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
     add_device_operation_instances(
         instances,
-        device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_instances<GemmKPadding>{});
+        device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_instances<
+            v2,
+            GemmDefault>{});
 }
 } // namespace instance
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp   View file @ cee23c47
This diff is collapsed.
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp   View file @ cee23c47

@@ -8,7 +8,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {

-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2(
     std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
                                                                      Col,
                                                                      Tuple<Row, Col>,

@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
     add_device_operation_instances(
         instances,
         device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_instances<
             v2,
             GemmDefault>{});
 }
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_padding_instance.cpp → library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp   View file @ cee23c47

@@ -8,7 +8,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {

-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_padding_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances(
     std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
                                                                      Col,
                                                                      Tuple<Row, Col>,

@@ -24,7 +24,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
     add_device_operation_instances(
         instances,
-        device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_instances<GemmKPadding>{});
+        device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_instances<
+            v1,
+            GemmDefault>{});
 }
 } // namespace instance
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp   View file @ cee23c47

@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
     add_device_operation_instances(
         instances,
         device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_instances<
             v1,
             GemmDefault>{});
 }
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_padding_instance.cpp → library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp   View file @ cee23c47

@@ -8,7 +8,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {

-void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_padding_instances(
+void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2(
     std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
                                                                      Col,
                                                                      Tuple<Row, Col>,

@@ -24,7 +24,8 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
     add_device_operation_instances(
         instances,
-        device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_instances<GemmKPadding>{});
+        device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_instances<
+            v2,
+            GemmDefault>{});
 }
 } // namespace instance
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp   View file @ cee23c47

@@ -24,6 +24,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
     add_device_operation_instances(
         instances,
         device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_instances<
             v1,
             GemmDefault>{});
 }