Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
32380a27
Commit
32380a27
authored
Jul 31, 2024
by
Jing Zhang
Browse files
add ckProfiler
parent
1675a341
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
37 additions
and
31 deletions
+37
-31
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
.../tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
+12
-12
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp
...ultiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp
+1
-1
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
...ltiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
+1
-1
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp
...iply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp
+1
-1
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp
...tiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp
+1
-1
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
...tiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
+1
-1
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
...iply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
+1
-1
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
...ly_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
+1
-1
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
...tiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
+1
-1
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
...iply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
+1
-1
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
...ly_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
+1
-1
profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp
.../include/profiler/profile_gemm_multiply_multiply_impl.hpp
+3
-1
profiler/src/profile_gemm_multiply_multiply.cpp
profiler/src/profile_gemm_multiply_multiply.cpp
+12
-8
No files found.
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
View file @
32380a27
...
@@ -18,7 +18,7 @@ namespace device {
...
@@ -18,7 +18,7 @@ namespace device {
namespace
instance
{
namespace
instance
{
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
...
@@ -31,7 +31,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_inst
...
@@ -31,7 +31,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_inst
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
...
@@ -44,7 +44,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_ins
...
@@ -44,7 +44,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_ins
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
...
@@ -57,7 +57,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_in
...
@@ -57,7 +57,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_in
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
...
@@ -70,7 +70,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_i
...
@@ -70,7 +70,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_i
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
...
@@ -83,7 +83,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_in
...
@@ -83,7 +83,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_in
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
...
@@ -96,7 +96,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_i
...
@@ -96,7 +96,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_i
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
...
@@ -109,7 +109,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding
...
@@ -109,7 +109,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
...
@@ -122,7 +122,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_in
...
@@ -122,7 +122,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_in
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
...
@@ -135,7 +135,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_i
...
@@ -135,7 +135,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_i
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
...
@@ -154,7 +154,7 @@ template <typename ADataType,
...
@@ -154,7 +154,7 @@ template <typename ADataType,
typename
ALayout
,
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
CLayout
>
typename
CLayout
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD
<
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD
SplitK
<
ALayout
,
ALayout
,
BLayout
,
BLayout
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
...
@@ -167,7 +167,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
...
@@ -167,7 +167,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
>>
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
>>
{
{
using
DeviceOp
=
DeviceGemmMultipleD
<
ALayout
,
using
DeviceOp
=
DeviceGemmMultipleD
SplitK
<
ALayout
,
BLayout
,
BLayout
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
CLayout
,
CLayout
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp
View file @
32380a27
...
@@ -9,7 +9,7 @@ namespace device {
...
@@ -9,7 +9,7 @@ namespace device {
namespace
instance
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
View file @
32380a27
...
@@ -9,7 +9,7 @@ namespace device {
...
@@ -9,7 +9,7 @@ namespace device {
namespace
instance
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp
View file @
32380a27
...
@@ -9,7 +9,7 @@ namespace device {
...
@@ -9,7 +9,7 @@ namespace device {
namespace
instance
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp
View file @
32380a27
...
@@ -9,7 +9,7 @@ namespace device {
...
@@ -9,7 +9,7 @@ namespace device {
namespace
instance
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
View file @
32380a27
...
@@ -9,7 +9,7 @@ namespace device {
...
@@ -9,7 +9,7 @@ namespace device {
namespace
instance
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
View file @
32380a27
...
@@ -9,7 +9,7 @@ namespace device {
...
@@ -9,7 +9,7 @@ namespace device {
namespace
instance
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
View file @
32380a27
...
@@ -9,7 +9,7 @@ namespace device {
...
@@ -9,7 +9,7 @@ namespace device {
namespace
instance
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
View file @
32380a27
...
@@ -9,7 +9,7 @@ namespace device {
...
@@ -9,7 +9,7 @@ namespace device {
namespace
instance
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
View file @
32380a27
...
@@ -9,7 +9,7 @@ namespace device {
...
@@ -9,7 +9,7 @@ namespace device {
namespace
instance
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
View file @
32380a27
...
@@ -9,7 +9,7 @@ namespace device {
...
@@ -9,7 +9,7 @@ namespace device {
namespace
instance
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
SplitK
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
...
...
profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp
View file @
32380a27
...
@@ -48,6 +48,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
...
@@ -48,6 +48,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
int
StrideD0
,
int
StrideD0
,
int
StrideD1
,
int
StrideD1
,
int
StrideE
,
int
StrideE
,
int
KBatch
,
int
n_warmup
,
int
n_warmup
,
int
n_iter
,
int
n_iter
,
uint64_t
rotating
=
0
)
uint64_t
rotating
=
0
)
...
@@ -129,7 +130,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
...
@@ -129,7 +130,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
d1_device_buf
.
ToDevice
(
d1_m_n
.
mData
.
data
());
d1_device_buf
.
ToDevice
(
d1_m_n
.
mData
.
data
());
using
DeviceOp
=
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD
<
ALayout
,
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD
SplitK
<
ALayout
,
BLayout
,
BLayout
,
ck
::
Tuple
<
D0Layout
,
D1Layout
>
,
ck
::
Tuple
<
D0Layout
,
D1Layout
>
,
ELayout
,
ELayout
,
...
@@ -199,6 +200,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
...
@@ -199,6 +200,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
StrideB
,
StrideB
,
std
::
array
<
ck
::
index_t
,
2
>
{
StrideD0
,
StrideD1
},
std
::
array
<
ck
::
index_t
,
2
>
{
StrideD0
,
StrideD1
},
StrideE
,
StrideE
,
KBatch
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
);
c_element_op
);
...
...
profiler/src/profile_gemm_multiply_multiply.cpp
View file @
32380a27
...
@@ -34,7 +34,7 @@ enum struct GemmDataType
...
@@ -34,7 +34,7 @@ enum struct GemmDataType
int
profile_gemm_multiply_multiply
(
int
argc
,
char
*
argv
[])
int
profile_gemm_multiply_multiply
(
int
argc
,
char
*
argv
[])
{
{
if
(
argc
!=
16
&&
argc
!=
19
)
if
(
argc
!=
16
&&
argc
!=
20
)
{
{
printf
(
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
);
printf
(
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
...
@@ -50,9 +50,10 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
...
@@ -50,9 +50,10 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
printf
(
"arg7: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg7: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg8 to 15: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE
\n
"
);
printf
(
"arg8 to 15: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE
\n
"
);
printf
(
"optional:
\n
"
);
printf
(
"optional:
\n
"
);
printf
(
"arg16: number of warm-up cycles (default 1)
\n
"
);
printf
(
"arg16: number of kbatch (default 1)
\n
"
);
printf
(
"arg17: number of iterations (default 10)
\n
"
);
printf
(
"arg17: number of warm-up cycles (default 1)
\n
"
);
printf
(
"arg18: memory for rotating buffer (default 0, size in MB)
\n
"
);
printf
(
"arg18: number of iterations (default 10)
\n
"
);
printf
(
"arg19: memory for rotating buffer (default 0, size in MB)
\n
"
);
exit
(
1
);
exit
(
1
);
}
}
...
@@ -76,11 +77,13 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
...
@@ -76,11 +77,13 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
int
n_warmup
=
1
;
int
n_warmup
=
1
;
int
n_iter
=
10
;
int
n_iter
=
10
;
uint64_t
rotating
=
0
;
uint64_t
rotating
=
0
;
if
(
argc
==
18
)
int
KBatch
=
1
;
if
(
argc
==
20
)
{
{
n_warmup
=
std
::
stoi
(
argv
[
16
]);
KBatch
=
std
::
stoi
(
argv
[
16
]);
n_iter
=
std
::
stoi
(
argv
[
17
]);
n_warmup
=
std
::
stoi
(
argv
[
17
]);
rotating
=
std
::
stoull
(
argv
[
18
])
*
1024
*
1024
;
n_iter
=
std
::
stoi
(
argv
[
18
]);
rotating
=
std
::
stoull
(
argv
[
19
])
*
1024
*
1024
;
}
}
using
F32
=
float
;
using
F32
=
float
;
...
@@ -146,6 +149,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
...
@@ -146,6 +149,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
(
StrideD0
<
0
)
?
DefaultStrideD0
:
StrideD0
,
(
StrideD0
<
0
)
?
DefaultStrideD0
:
StrideD0
,
(
StrideD1
<
0
)
?
DefaultStrideD1
:
StrideD1
,
(
StrideD1
<
0
)
?
DefaultStrideD1
:
StrideD1
,
(
StrideE
<
0
)
?
DefaultStrideE
:
StrideE
,
(
StrideE
<
0
)
?
DefaultStrideE
:
StrideE
,
KBatch
,
n_warmup
,
n_warmup
,
n_iter
,
n_iter
,
rotating
);
rotating
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment