Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6f24c2d8
Commit
6f24c2d8
authored
Dec 31, 2024
by
aska-0096
Browse files
disable N, K Padding, splitk enabled
parent
f60f9d59
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
34 additions
and
64 deletions
+34
-64
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle.hpp
...gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle.hpp
+4
-42
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
...l/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
+5
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
...id/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
+5
-5
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle.hpp
...instance/gpu/gemm_multiply_multiply_weight_preshuffle.hpp
+11
-10
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/CMakeLists.txt
...u/gemm_multiply_multiply_weight_preshuffle/CMakeLists.txt
+6
-6
profiler/include/profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp
...profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp
+3
-1
No files found.
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle.hpp
View file @
6f24c2d8
...
@@ -164,13 +164,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -164,13 +164,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
HotLoopInstList
::
A_LDS_Read_Width
*
sizeof
(
ADataType
)
==
16
HotLoopInstList
::
A_LDS_Read_Width
*
sizeof
(
ADataType
)
==
16
?
HotLoopInstList
::
A_LDS_Read_Inst_Num
?
HotLoopInstList
::
A_LDS_Read_Inst_Num
:
HotLoopInstList
::
A_LDS_Read_Inst_Num
/
2
;
:
HotLoopInstList
::
A_LDS_Read_Inst_Num
/
2
;
constexpr
auto
num_ds_read_inst_b
=
HotLoopInstList
::
B_LDS_Read_Width
*
sizeof
(
BDataType
)
==
16
?
HotLoopInstList
::
B_LDS_Read_Inst_Num
:
HotLoopInstList
::
B_LDS_Read_Inst_Num
/
2
;
constexpr
auto
num_ds_write_inst_a
=
HotLoopInstList
::
A_LDS_Write_Inst_Num
;
constexpr
auto
num_ds_write_inst_a
=
HotLoopInstList
::
A_LDS_Write_Inst_Num
;
constexpr
auto
num_ds_write_inst_b
=
HotLoopInstList
::
B_LDS_Write_Inst_Num
;
constexpr
auto
num_buffer_load_inst_a
=
HotLoopInstList
::
A_Buffer_Load_Inst_Num
;
constexpr
auto
num_buffer_load_inst_a
=
HotLoopInstList
::
A_Buffer_Load_Inst_Num
;
constexpr
auto
num_buffer_load_inst_b
=
HotLoopInstList
::
B_Buffer_Load_Inst_Num
;
constexpr
auto
num_buffer_load_inst_b
=
HotLoopInstList
::
B_Buffer_Load_Inst_Num
;
...
@@ -180,29 +175,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -180,29 +175,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
constexpr
auto
mfma_cycle
=
NPerXDL
==
16
?
16
:
32
;
constexpr
auto
mfma_cycle
=
NPerXDL
==
16
?
16
:
32
;
constexpr
auto
ds_read_a_issue_cycle
=
constexpr
auto
ds_read_a_issue_cycle
=
HotLoopInstList
::
A_LDS_Read_Width
*
sizeof
(
ADataType
)
==
16
?
8
:
4
;
HotLoopInstList
::
A_LDS_Read_Width
*
sizeof
(
ADataType
)
==
16
?
8
:
4
;
constexpr
auto
ds_read_b_issue_cycle
=
HotLoopInstList
::
B_LDS_Read_Width
*
sizeof
(
BDataType
)
==
16
?
8
:
4
;
constexpr
auto
ds_read_a_mfma_rate
=
constexpr
auto
ds_read_a_mfma_rate
=
(
mfma_cycle
-
4
+
2
*
ds_read_a_issue_cycle
-
1
)
/
(
2
*
ds_read_a_issue_cycle
);
(
mfma_cycle
-
4
+
2
*
ds_read_a_issue_cycle
-
1
)
/
(
2
*
ds_read_a_issue_cycle
);
constexpr
auto
ds_read_b_mfma_rate
=
(
mfma_cycle
-
4
+
2
*
ds_read_b_issue_cycle
-
1
)
/
(
2
*
ds_read_b_issue_cycle
);
constexpr
auto
num_dsread_a_mfma
=
constexpr
auto
num_dsread_a_mfma
=
(
num_ds_read_inst_a
+
ds_read_a_mfma_rate
-
1
)
/
ds_read_a_mfma_rate
;
(
num_ds_read_inst_a
+
ds_read_a_mfma_rate
-
1
)
/
ds_read_a_mfma_rate
;
constexpr
auto
num_dsread_b_mfma
=
(
num_ds_read_inst_b
+
ds_read_b_mfma_rate
-
1
)
/
ds_read_b_mfma_rate
;
// stage 1
// stage 1
// Separate this part?
constexpr
auto
num_mfma_stage1
=
num_mfma_inst
-
num_dsread_a_mfma
;
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
// sizeof(ComputeDataType) / sizeof(BDataType)
// ? sizeof(ComputeDataType) / sizeof(ADataType)
// : sizeof(ComputeDataType) / sizeof(BDataType);
constexpr
auto
num_mfma_stage1
=
num_mfma_inst
-
(
num_dsread_a_mfma
+
num_dsread_b_mfma
);
constexpr
auto
num_mfma_per_issue
=
constexpr
auto
num_mfma_per_issue
=
num_mfma_stage1
/
(
num_buffer_load_inst_a
+
num_buffer_load_inst_b
);
num_mfma_stage1
/
(
num_buffer_load_inst_a
+
num_buffer_load_inst_b
);
constexpr
auto
num_dswrite_per_issue_a
=
num_ds_write_inst_a
/
num_buffer_load_inst_a
;
constexpr
auto
num_dswrite_per_issue_a
=
num_ds_write_inst_a
/
num_buffer_load_inst_a
;
constexpr
auto
num_dswrite_per_issue_b
=
num_ds_write_inst_b
/
num_buffer_load_inst_b
;
static_for
<
0
,
num_buffer_load_inst_a
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
num_buffer_load_inst_a
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
ignore
=
i
;
...
@@ -215,16 +198,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -215,16 +198,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
__builtin_amdgcn_sched_group_barrier
(
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_issue
-
num_dswrite_per_issue_a
,
0
);
// MFMA
0x008
,
num_mfma_per_issue
-
num_dswrite_per_issue_a
,
0
);
// MFMA
});
});
static_for
<
0
,
num_buffer_load_inst_b
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
num_buffer_load_inst_b
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
ignore
=
i
;
static_for
<
0
,
num_dswrite_per_issue_b
,
1
>
{}([
&
](
auto
idswrite
)
{
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
ignore
=
idswrite
;
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_issue
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
0
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_issue
-
num_dswrite_per_issue_b
,
0
);
// MFMA
});
});
// stage 2
// stage 2
...
@@ -243,22 +221,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -243,22 +221,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
}
}
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
});
static_for
<
0
,
num_dsread_b_mfma
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
((
num_ds_read_inst_b
-
(
i
+
1
)
*
ds_read_b_mfma_rate
)
>=
ds_read_b_mfma_rate
)
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
ds_read_b_mfma_rate
,
0
);
// DS read
}
else
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
num_ds_read_inst_b
-
(
num_dsread_b_mfma
-
1
)
*
ds_read_b_mfma_rate
,
0
);
// DS read
}
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
}
}
template
<
bool
HasMainLoop
,
template
<
bool
HasMainLoop
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
View file @
6f24c2d8
...
@@ -362,6 +362,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
@@ -362,6 +362,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
return
false
;
return
false
;
}
}
if
(
arg
.
N
%
NPerBlock
!=
0
||
arg
.
K
%
KPerBlock
!=
0
)
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
);
return
GridwiseGemm
::
CheckValidity
(
arg
);
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
View file @
6f24c2d8
...
@@ -178,9 +178,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -178,9 +178,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
{
{
return
math
::
integer_divide_ceil
(
N
,
NLane
*
NWave
);
return
math
::
integer_divide_ceil
(
N
,
NLane
*
NWave
);
}
}
__host__
__device__
static
auto
CalculateBK0Shuffled
(
index_t
K
,
index_t
KBatch
)
__host__
__device__
static
auto
CalculateBK0Shuffled
(
index_t
K
)
{
{
return
math
::
integer_divide_ceil
(
K
,
KLane
*
KPack
*
KBatch
);
return
math
::
integer_divide_ceil
(
K
,
KLane
*
KPack
);
}
}
__host__
__device__
static
auto
CalculateKPadded
(
index_t
K
)
__host__
__device__
static
auto
CalculateKPadded
(
index_t
K
)
...
@@ -539,7 +539,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -539,7 +539,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
MBlock
{
CalculateMBlock
(
M_
)},
MBlock
{
CalculateMBlock
(
M_
)},
NBlock
{
CalculateNBlock
(
N_
)},
NBlock
{
CalculateNBlock
(
N_
)},
BN0Shuffled
{
CalculateBN0Shuffled
(
N_
)},
BN0Shuffled
{
CalculateBN0Shuffled
(
N_
)},
BK0Shuffled
{
CalculateBK0Shuffled
(
K_
,
KBatch_
)}
BK0Shuffled
{
CalculateBK0Shuffled
(
K_
)}
{
{
}
}
...
@@ -649,8 +649,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -649,8 +649,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
}
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
{
{
// KPack * NLane * KLane * NWave * KRepeat * NRepeat *
K0*
N0
// KPack * NLane * KLane * NWave * KRepeat *
K0*
NRepeat * N0
b_k_split_offset
=
k_id
*
karg
.
KRead
*
NLane
*
NWave
*
NXdlPerWave
;
b_k_split_offset
=
k_id
*
karg
.
KRead
*
NLane
*
NWave
;
}
}
if
(
k_id
<
karg
.
KBatch
-
1
)
if
(
k_id
<
karg
.
KBatch
-
1
)
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle.hpp
View file @
6f24c2d8
...
@@ -114,7 +114,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
...
@@ -114,7 +114,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
MultiplyMultiply
>>>&
MultiplyMultiply
>>>&
instances
);
instances
);
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p
1_padding
_instances
(
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p
2_default
_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitKBPreShuffle
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitKBPreShuffle
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
...
@@ -128,7 +128,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
...
@@ -128,7 +128,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
MultiplyMultiply
>>>&
MultiplyMultiply
>>>&
instances
);
instances
);
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p
2
_default_instances
(
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p
3
_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitKBPreShuffle
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitKBPreShuffle
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
...
@@ -141,7 +141,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
...
@@ -141,7 +141,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
MultiplyMultiply
>>>&
instances
);
instances
);
#if 0
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances(
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col,
Col,
...
@@ -156,7 +156,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
...
@@ -156,7 +156,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
MultiplyMultiply>>>&
MultiplyMultiply>>>&
instances);
instances);
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p
3_default
_instances
(
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p
1_padding
_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col,
Col,
Tuple<Row, Col>,
Tuple<Row, Col>,
...
@@ -184,6 +184,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
...
@@ -184,6 +184,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
MultiplyMultiply>>>&
MultiplyMultiply>>>&
instances);
instances);
#endif
#endif
#endif
template
<
typename
ADataType
,
template
<
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
...
@@ -258,18 +259,18 @@ struct DeviceOperationInstanceFactory<
...
@@ -258,18 +259,18 @@ struct DeviceOperationInstanceFactory<
{
{
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances
(
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances
(
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances
(
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances
(
op_ptrs
);
op_ptrs
);
#if 0
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances(
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances(
op_ptrs);
op_ptrs);
#endif
}
}
}
}
#endif
#endif
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/CMakeLists.txt
View file @
6f24c2d8
...
@@ -3,18 +3,18 @@ set(GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES)
...
@@ -3,18 +3,18 @@ set(GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES)
list
(
APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES
list
(
APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instance.cpp
#
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instance.cpp
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instance.cpp
#
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instance.cpp
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instance.cpp
#
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instance.cpp
)
)
set_source_files_properties
(
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
#
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties
(
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
#
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties
(
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
#
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
add_instance_library
(
device_gemm_multiply_multiply_weight_preshuffle_instance
${
GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES
}
)
add_instance_library
(
device_gemm_multiply_multiply_weight_preshuffle_instance
${
GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES
}
)
profiler/include/profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp
View file @
6f24c2d8
...
@@ -249,7 +249,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
...
@@ -249,7 +249,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
b_device_buf
.
ToDevice
(
b_preshuffled
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_preshuffled
.
mData
.
data
());
std
::
vector
<
int
>
kbatch_list
=
{
1
,
2
,
4
,
8
,
16
};
std
::
vector
<
int
>
kbatch_list
=
{
1
,
2
,
4
,
8
};
if
(
KBatch
>
0
)
if
(
KBatch
>
0
)
{
{
...
@@ -282,6 +282,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
...
@@ -282,6 +282,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
{
c_device_buf
.
SetZero
();
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
if
(
do_verification
)
if
(
do_verification
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment