Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
0b5ad335
Commit
0b5ad335
authored
Feb 11, 2025
by
ozturkosu
Browse files
mix precision f8-bf16 streamk gemm
parent
d377b42c
Changes
19
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
613 additions
and
7 deletions
+613
-7
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt
...ration_instance/gpu/gemm_universal_streamk/CMakeLists.txt
+17
-6
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp
...device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp
+99
-0
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp
...sal_streamk_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp
+24
-0
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_kpadding_instance.cpp
...al_streamk_f8_f8_bf16_mk_kn_mn_comp_kpadding_instance.cpp
+24
-0
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp
...l_streamk_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp
+24
-0
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_default_instance.cpp
...l_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_default_instance.cpp
+25
-0
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp
..._streamk_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp
+25
-0
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instance.cpp
...streamk_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instance.cpp
+25
-0
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_default_instance.cpp
...l_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_default_instance.cpp
+25
-0
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp
..._streamk_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp
+25
-0
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instance.cpp
...streamk_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instance.cpp
+25
-0
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn.hpp
...device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn.hpp
+107
-0
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp
...sal_streamk_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp
+24
-0
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
...al_streamk_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
+24
-0
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
...l_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
+25
-0
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
..._streamk_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
+25
-0
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
...l_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
+25
-0
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
..._streamk_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
+25
-0
profiler/src/profile_gemm_universal_streamk.cpp
profiler/src/profile_gemm_universal_streamk.cpp
+20
-1
No files found.
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt
100644 → 100755
View file @
0b5ad335
...
@@ -21,9 +21,7 @@ list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES
...
@@ -21,9 +21,7 @@ list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp
device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp
device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp
...
@@ -44,7 +42,6 @@ list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES
...
@@ -44,7 +42,6 @@ list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES
device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp
...
@@ -65,7 +62,6 @@ list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES
...
@@ -65,7 +62,6 @@ list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instance.cpp
...
@@ -101,6 +97,21 @@ list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES
...
@@ -101,6 +97,21 @@ list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
)
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp
device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_default_instance.cpp
device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp
device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
)
add_instance_library
(
device_gemm_universal_streamk_instance
${
GEMM_UNIVERSAL_STREAMK_INSTANCES
}
)
add_instance_library
(
device_gemm_universal_streamk_instance
${
GEMM_UNIVERSAL_STREAMK_INSTANCES
}
)
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp
0 → 100755
View file @
0b5ad335
This diff is collapsed.
Click to expand it.
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp
0 → 100644
View file @
0b5ad335
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_instances
<
GemmDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_kpadding_instance.cpp
0 → 100644
View file @
0b5ad335
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_instances
<
GemmKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp
0 → 100644
View file @
0b5ad335
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_instances
<
GemmNKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_default_instance.cpp
0 → 100644
View file @
0b5ad335
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_instances
<
Intrawave
,
GemmDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp
0 → 100644
View file @
0b5ad335
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_instances
<
Intrawave
,
GemmKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instance.cpp
0 → 100644
View file @
0b5ad335
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_instances
<
Intrawave
,
GemmNKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_default_instance.cpp
0 → 100644
View file @
0b5ad335
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_instances
<
Interwave
,
GemmDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp
0 → 100644
View file @
0b5ad335
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_instances
<
Interwave
,
GemmKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instance.cpp
0 → 100644
View file @
0b5ad335
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_instances
<
Interwave
,
GemmNKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn.hpp
0 → 100755
View file @
0b5ad335
This diff is collapsed.
Click to expand it.
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp
0 → 100644
View file @
0b5ad335
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_instances
<
GemmDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
0 → 100644
View file @
0b5ad335
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_instances
<
GemmKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
0 → 100644
View file @
0b5ad335
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_instances
<
Intrawave
,
GemmDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
0 → 100644
View file @
0b5ad335
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_instances
<
Intrawave
,
GemmKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
0 → 100644
View file @
0b5ad335
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_instances
<
Interwave
,
GemmDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
0 → 100644
View file @
0b5ad335
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_instances
<
Interwave
,
GemmKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
profiler/src/profile_gemm_universal_streamk.cpp
100644 → 100755
View file @
0b5ad335
...
@@ -26,6 +26,7 @@ enum struct GemmDataType
...
@@ -26,6 +26,7 @@ enum struct GemmDataType
F8_F16_F16
,
// 4
F8_F16_F16
,
// 4
F16_F8_F16
,
// 5
F16_F8_F16
,
// 5
F16_F16_F16_F8
,
// 6
F16_F16_F16_F8
,
// 6
F8_F8_BF16
,
// 7
};
};
#define OP_NAME "gemm_universal_streamk"
#define OP_NAME "gemm_universal_streamk"
...
@@ -37,7 +38,7 @@ int profile_gemm_universal_streamk(int argc, char* argv[])
...
@@ -37,7 +38,7 @@ int profile_gemm_universal_streamk(int argc, char* argv[])
{
{
printf
(
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
);
printf
(
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, "
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, "
"comp f8)
\n
"
);
"comp f8
; 7: f8->bf16,
)
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
printf
(
" 2: A[k, m] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 2: A[k, m] * B[k, n] = C[m, n];
\n
"
);
...
@@ -198,6 +199,24 @@ int profile_gemm_universal_streamk(int argc, char* argv[])
...
@@ -198,6 +199,24 @@ int profile_gemm_universal_streamk(int argc, char* argv[])
{
{
return
profile
(
BF16
{},
BF16
{},
F32
{},
BF16
{},
Col
{},
Col
{},
Row
{});
return
profile
(
BF16
{},
BF16
{},
F32
{},
BF16
{},
Col
{},
Col
{},
Row
{});
}
}
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
else
if
(
data_type
==
GemmDataType
::
F8_F8_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
profile
(
F8
{},
F8
{},
F8
{},
F32
{},
BF16
{},
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
F8_F8_BF16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
F8
{},
F8
{},
F8
{},
F32
{},
BF16
{},
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_I4_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
F16
{},
I4
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
BF16_I4_BF16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
BF16
{},
I4
{},
BF16
{},
F32
{},
BF16
{},
Row
{},
Col
{},
Row
{});
}
#endif
else
else
{
{
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment