Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
bbb29a9d
Commit
bbb29a9d
authored
Jul 31, 2024
by
Jing Zhang
Browse files
format
parent
32380a27
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
224 additions
and
223 deletions
+224
-223
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
.../tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
+112
-111
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp
...ultiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp
+10
-10
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
...ltiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
+10
-10
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp
...iply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp
+10
-10
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp
...tiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp
+10
-10
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
...tiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
+10
-10
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
...iply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
+10
-10
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
...ly_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
+10
-10
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
...tiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
+10
-10
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
...iply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
+10
-10
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
...ly_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
+10
-10
profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp
.../include/profiler/profile_gemm_multiply_multiply_impl.hpp
+10
-10
profiler/src/profile_gemm_multiply_multiply.cpp
profiler/src/profile_gemm_multiply_multiply.cpp
+2
-2
No files found.
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
View file @
bbb29a9d
...
@@ -19,133 +19,133 @@ namespace instance {
...
@@ -19,133 +19,133 @@ namespace instance {
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
#endif
#endif
template
<
typename
ADataType
,
template
<
typename
ADataType
,
...
@@ -167,17 +167,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
...
@@ -167,17 +167,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
>>
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
>>
{
{
using
DeviceOp
=
DeviceGemmMultipleDSplitK
<
ALayout
,
using
DeviceOp
=
BLayout
,
DeviceGemmMultipleDSplitK
<
ALayout
,
Tuple
<
Row
,
Col
>
,
BLayout
,
CLayout
,
Tuple
<
Row
,
Col
>
,
ADataType
,
CLayout
,
BDataType
,
ADataType
,
Tuple
<
F32
,
F32
>
,
BDataType
,
CDataType
,
Tuple
<
F32
,
F32
>
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
>
;
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
>
;
static
auto
GetInstances
()
static
auto
GetInstances
()
{
{
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp
View file @
bbb29a9d
...
@@ -10,16 +10,16 @@ namespace instance {
...
@@ -10,16 +10,16 @@ namespace instance {
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
MultiplyMultiply
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
View file @
bbb29a9d
...
@@ -10,16 +10,16 @@ namespace instance {
...
@@ -10,16 +10,16 @@ namespace instance {
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
MultiplyMultiply
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp
View file @
bbb29a9d
...
@@ -10,16 +10,16 @@ namespace instance {
...
@@ -10,16 +10,16 @@ namespace instance {
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
MultiplyMultiply
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp
View file @
bbb29a9d
...
@@ -10,16 +10,16 @@ namespace instance {
...
@@ -10,16 +10,16 @@ namespace instance {
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
MultiplyMultiply
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
View file @
bbb29a9d
...
@@ -10,16 +10,16 @@ namespace instance {
...
@@ -10,16 +10,16 @@ namespace instance {
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
MultiplyMultiply
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
View file @
bbb29a9d
...
@@ -10,16 +10,16 @@ namespace instance {
...
@@ -10,16 +10,16 @@ namespace instance {
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
MultiplyMultiply
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
View file @
bbb29a9d
...
@@ -10,16 +10,16 @@ namespace instance {
...
@@ -10,16 +10,16 @@ namespace instance {
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
MultiplyMultiply
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
View file @
bbb29a9d
...
@@ -10,16 +10,16 @@ namespace instance {
...
@@ -10,16 +10,16 @@ namespace instance {
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
MultiplyMultiply
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
View file @
bbb29a9d
...
@@ -10,16 +10,16 @@ namespace instance {
...
@@ -10,16 +10,16 @@ namespace instance {
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
MultiplyMultiply
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
View file @
bbb29a9d
...
@@ -10,16 +10,16 @@ namespace instance {
...
@@ -10,16 +10,16 @@ namespace instance {
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
MultiplyMultiply
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
...
...
profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp
View file @
bbb29a9d
...
@@ -131,16 +131,16 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
...
@@ -131,16 +131,16 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
using
DeviceOp
=
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleDSplitK
<
ALayout
,
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleDSplitK
<
ALayout
,
BLayout
,
BLayout
,
ck
::
Tuple
<
D0Layout
,
D1Layout
>
,
ck
::
Tuple
<
D0Layout
,
D1Layout
>
,
ELayout
,
ELayout
,
ADataType
,
ADataType
,
BDataType
,
BDataType
,
ck
::
Tuple
<
D0DataType
,
D1DataType
>
,
ck
::
Tuple
<
D0DataType
,
D1DataType
>
,
EDataType
,
EDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
CElementOp
>
;
CElementOp
>
;
// get device op instances
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
...
...
profiler/src/profile_gemm_multiply_multiply.cpp
View file @
bbb29a9d
...
@@ -77,10 +77,10 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
...
@@ -77,10 +77,10 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
int
n_warmup
=
1
;
int
n_warmup
=
1
;
int
n_iter
=
10
;
int
n_iter
=
10
;
uint64_t
rotating
=
0
;
uint64_t
rotating
=
0
;
int
KBatch
=
1
;
int
KBatch
=
1
;
if
(
argc
==
20
)
if
(
argc
==
20
)
{
{
KBatch
=
std
::
stoi
(
argv
[
16
]);
KBatch
=
std
::
stoi
(
argv
[
16
]);
n_warmup
=
std
::
stoi
(
argv
[
17
]);
n_warmup
=
std
::
stoi
(
argv
[
17
]);
n_iter
=
std
::
stoi
(
argv
[
18
]);
n_iter
=
std
::
stoi
(
argv
[
18
]);
rotating
=
std
::
stoull
(
argv
[
19
])
*
1024
*
1024
;
rotating
=
std
::
stoull
(
argv
[
19
])
*
1024
*
1024
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment