Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
896f8b4c
Commit
896f8b4c
authored
Jan 10, 2025
by
Jakub Piasecki
Browse files
add gemm_api and instances
parent
73a076ee
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1204 additions
and
10 deletions
+1204
-10
example/ck_tile/03_gemm/CMakeLists.txt
example/ck_tile/03_gemm/CMakeLists.txt
+26
-1
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+26
-5
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+63
-4
example/ck_tile/03_gemm/instances/gemm_api.cpp
example/ck_tile/03_gemm/instances/gemm_api.cpp
+482
-0
example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_kn_mn.cpp
...instances/gemm_universal_comp_bf16_bf16_bf16_km_kn_mn.cpp
+27
-0
example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_nk_mn.cpp
...instances/gemm_universal_comp_bf16_bf16_bf16_km_nk_mn.cpp
+27
-0
example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp
...instances/gemm_universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp
+26
-0
example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp
...instances/gemm_universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp
+27
-0
example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_kn_mn.cpp
...mm/instances/gemm_universal_comp_f16_f16_f16_km_kn_mn.cpp
+27
-0
example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_nk_mn.cpp
...mm/instances/gemm_universal_comp_f16_f16_f16_km_nk_mn.cpp
+27
-0
example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_kn_mn.cpp
...mm/instances/gemm_universal_comp_f16_f16_f16_mk_kn_mn.cpp
+26
-0
example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_nk_mn.cpp
...mm/instances/gemm_universal_comp_f16_f16_f16_mk_nk_mn.cpp
+27
-0
example/ck_tile/03_gemm/instances/gemm_universal_comp_instance_common.hpp
...03_gemm/instances/gemm_universal_comp_instance_common.hpp
+206
-0
example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_kn_mn.cpp
.../instances/gemm_universal_mem_bf16_bf16_bf16_km_kn_mn.cpp
+27
-0
example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_nk_mn.cpp
.../instances/gemm_universal_mem_bf16_bf16_bf16_km_nk_mn.cpp
+27
-0
example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp
.../instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp
+26
-0
example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp
.../instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp
+27
-0
example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_kn_mn.cpp
...emm/instances/gemm_universal_mem_f16_f16_f16_km_kn_mn.cpp
+27
-0
example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_nk_mn.cpp
...emm/instances/gemm_universal_mem_f16_f16_f16_km_nk_mn.cpp
+27
-0
example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp
...emm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp
+26
-0
No files found.
example/ck_tile/03_gemm/CMakeLists.txt
View file @
896f8b4c
# add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
# add_executable(tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp)
# add_gemm_example(<TARGET_NAME> <MAIN_SRC> [<extra source>...])
#
# Creates the executable <TARGET_NAME> from <MAIN_SRC> plus every additional
# source file passed after the two named arguments. The target is excluded
# from the default build, gets this directory on its include path, and is
# compiled with the warning suppressions listed below.
function(add_gemm_example TARGET_NAME MAIN_SRC)
    message("adding ${TARGET_NAME}")
    # not using add_example_executable() to add target, since we don't want this to have
    # to be included in "make all/install/check"
    add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC})
    target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR})

    # Everything after MAIN_SRC is an extra source file. ARGN is forwarded
    # directly instead of being appended to a variable: a function starts with
    # a copy of the caller's variables, so appending to a list the caller has
    # already filled (e.g. INSTANCE_SRCS) would add every file twice.
    target_sources(${TARGET_NAME} PRIVATE ${ARGN})

    # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
    set(GEMM_EXAMPLE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
    target_compile_options(${TARGET_NAME} PRIVATE ${GEMM_EXAMPLE_COMPILE_OPTIONS})
endfunction()
# Gather every .cpp file under instances/ into INSTANCE_SRCS.
file(GLOB INSTANCE_SRCS instances/*.cpp)
# Build the universal GEMM example from universal_gemm.cpp plus all globbed instance sources.
add_gemm_example(tile_example_gemm_universal universal_gemm.cpp ${INSTANCE_SRCS})
add_executable
(
tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp
)
add_executable
(
tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp
)
add_executable
(
tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp
)
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
896f8b4c
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
...
@@ -9,13 +9,10 @@
...
@@ -9,13 +9,10 @@
#include <string>
#include <string>
#include <tuple>
#include <tuple>
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
#include "gemm_basic.hpp"
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
gemm_
calc
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
gemm_
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadM
=
false
;
...
@@ -103,6 +100,30 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -103,6 +100,30 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
return
ave_time
;
return
ave_time
;
}
}
/// Runtime-to-compile-time layout dispatch for the basic GEMM example.
///
/// Picks the gemm_<ALayout, BLayout, CLayout> instantiation that matches the
/// row-major flags in `t` and forwards `args` and `s` to it.
///
/// @param t    Layout flags for A, B and C.
/// @param args Host-side GEMM arguments, passed through unchanged.
/// @param s    Stream configuration, passed through unchanged.
/// @return Whatever the selected gemm_ instantiation returns.
/// @throws std::runtime_error if C is not row-major (no such instance is dispatched here).
float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
    // Every dispatched combination has a row-major C, so reject the rest up front.
    if(!t.is_c_rowmajor)
    {
        throw std::runtime_error("Wrong! Layouts not supported!\n");
    }

    if(t.is_a_rowmajor)
    {
        return t.is_b_rowmajor ? gemm_<Row, Row, Row>(args, s) : gemm_<Row, Col, Row>(args, s);
    }

    return t.is_b_rowmajor ? gemm_<Col, Row, Row>(args, s) : gemm_<Col, Col, Row>(args, s);
}
#include "run_gemm_example.inc"
#include "run_gemm_example.inc"
// Program entry point: runs the GEMM example and returns the logical negation
// of its result, so a truthy result becomes exit status 0.
int main(int argc, char* argv[])
{
    const auto example_result = run_gemm_example(argc, argv);
    return !example_result;
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
896f8b4c
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include <string>
#include <string>
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/epilogue.hpp"
template
<
typename
DataType
>
template
<
typename
DataType
>
struct
GemmBasicTypeConfig
;
struct
GemmBasicTypeConfig
;
...
@@ -51,6 +52,59 @@ using BDataType = Types::BDataType;
...
@@ -51,6 +52,59 @@ using BDataType = Types::BDataType;
using
AccDataType
=
Types
::
AccDataType
;
using
AccDataType
=
Types
::
AccDataType
;
using
CDataType
=
Types
::
CDataType
;
using
CDataType
=
Types
::
CDataType
;
// Short names for the ck_tile GEMM layout tag types.
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
/// Runtime description of a GEMM request, consumed by gemm() to choose a
/// kernel instance.
///
/// The members carry default initializers so a default-constructed object
/// never holds indeterminate values. The struct remains an aggregate, so
/// `gemm_traits{"fp16", true, false, true}` still works.
struct gemm_traits
{
    std::string data_type{};   // element type name, compared against strings such as "fp16" / "bf16"
    bool is_a_rowmajor{false}; // true when A is row-major
    bool is_b_rowmajor{false}; // true when B is row-major
    bool is_c_rowmajor{false}; // true when C is row-major
};
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
AccDataType_
,
typename
CDataType_
,
typename
ALayout_
,
typename
BLayout_
,
typename
CLayout_
,
ck_tile
::
index_t
M_Tile_
,
ck_tile
::
index_t
N_Tile_
,
ck_tile
::
index_t
K_Tile_
,
ck_tile
::
index_t
M_Warp_
,
ck_tile
::
index_t
N_Warp_
,
ck_tile
::
index_t
K_Warp_
,
ck_tile
::
index_t
M_Warp_Tile_
,
ck_tile
::
index_t
N_Warp_Tile_
,
ck_tile
::
index_t
K_Warp_Tile_
,
bool
kPadM_
,
bool
kPadN_
,
bool
kPadK_
>
struct
gemm_traits_
{
using
ADataType
=
ck_tile
::
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
ck_tile
::
remove_cvref_t
<
BDataType_
>
;
using
AccDataType
=
ck_tile
::
remove_cvref_t
<
AccDataType_
>
;
using
CDataType
=
ck_tile
::
remove_cvref_t
<
CDataType_
>
;
using
ALayout
=
ck_tile
::
remove_cvref_t
<
ALayout_
>
;
using
BLayout
=
ck_tile
::
remove_cvref_t
<
BLayout_
>
;
using
CLayout
=
ck_tile
::
remove_cvref_t
<
CLayout_
>
;
static
constexpr
ck_tile
::
index_t
M_Tile
=
M_Tile_
;
static
constexpr
ck_tile
::
index_t
N_Tile
=
N_Tile_
;
static
constexpr
ck_tile
::
index_t
K_Tile
=
K_Tile_
;
static
constexpr
ck_tile
::
index_t
M_Warp
=
M_Warp_
;
static
constexpr
ck_tile
::
index_t
N_Warp
=
N_Warp_
;
static
constexpr
ck_tile
::
index_t
K_Warp
=
K_Warp_
;
static
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
M_Warp_Tile_
;
static
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
N_Warp_Tile_
;
static
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
K_Warp_Tile_
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadK
=
kPadK_
;
};
auto
create_args
(
int
argc
,
char
*
argv
[])
auto
create_args
(
int
argc
,
char
*
argv
[])
{
{
ck_tile
::
ArgParser
arg_parser
;
ck_tile
::
ArgParser
arg_parser
;
...
@@ -75,4 +129,9 @@ auto create_args(int argc, char* argv[])
...
@@ -75,4 +129,9 @@ auto create_args(int argc, char* argv[])
}
}
// host API
// host API
float
gemm_calc
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
);
// Runs the GEMM kernel configuration described by the compile-time bundle
// Traits_. Only declared here; a translation unit that calls it relies on a
// definition being provided elsewhere.
template <typename Traits_>
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);

// Runtime entry point: selects a kernel configuration from `traits` and runs
// it with `args` on the stream configuration `s`.
float gemm(const gemm_traits& traits,
           const ck_tile::GemmHostArgs& args,
           const ck_tile::stream_config& s);
example/ck_tile/03_gemm/instances/gemm_api.cpp
0 → 100644
View file @
896f8b4c
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_basic.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
FP32
=
float
;
using
FP16
=
ck_tile
::
half_t
;
using
BF16
=
ck_tile
::
bf16_t
;
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
AccDataType_
,
typename
CDataType_
,
typename
ALayout_
,
typename
BLayout_
,
typename
CLayout_
,
ck_tile
::
index_t
M_Tile_
,
ck_tile
::
index_t
N_Tile_
,
ck_tile
::
index_t
K_Tile_
,
ck_tile
::
index_t
M_Warp_
,
ck_tile
::
index_t
N_Warp_
,
ck_tile
::
index_t
K_Warp_
,
ck_tile
::
index_t
M_Warp_Tile_
,
ck_tile
::
index_t
N_Warp_Tile_
,
ck_tile
::
index_t
K_Warp_Tile_
,
bool
kPadM_
,
bool
kPadN_
,
bool
kPadK_
>
using
trait_
=
gemm_traits_
<
ADataType_
,
BDataType_
,
AccDataType_
,
CDataType_
,
ALayout_
,
BLayout_
,
CLayout_
,
M_Tile_
,
N_Tile_
,
K_Tile_
,
M_Warp_
,
N_Warp_
,
K_Warp_
,
M_Warp_Tile_
,
N_Warp_Tile_
,
K_Warp_Tile_
,
kPadM_
,
kPadN_
,
kPadK_
>
;
float
gemm
(
const
gemm_traits
&
t
,
const
ck_tile
::
GemmHostArgs
&
a
,
const
ck_tile
::
stream_config
&
s
)
{
if
(
t
.
data_type
.
compare
(
"fp16"
)
==
0
)
{
if
(
t
.
is_a_rowmajor
&&
t
.
is_b_rowmajor
&&
t
.
is_c_rowmajor
)
{
if
(
a
.
M
>
512
)
{
// universal gemm compute bound RR
std
::
cout
<<
"fp16 comp
\n
"
;
return
gemm_
<
trait_
<
FP16
,
FP16
,
FP32
,
FP16
,
Row
,
Row
,
Row
,
256
,
256
,
32
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>>
(
a
,
s
);
}
else
{
// universal gemm memory bound RR
std
::
cout
<<
"fp16 mem
\n
"
;
return
gemm_
<
trait_
<
FP16
,
FP16
,
FP32
,
FP16
,
Row
,
Row
,
Row
,
128
,
32
,
64
,
4
,
1
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>>
(
a
,
s
);
}
}
else
if
(
t
.
is_a_rowmajor
&&
!
t
.
is_b_rowmajor
&&
t
.
is_c_rowmajor
)
{
if
(
a
.
M
>
512
)
{
// universal gemm compute bound RC
std
::
cout
<<
"fp16 comp RC
\n
"
;
return
gemm_
<
trait_
<
FP16
,
FP16
,
FP32
,
FP16
,
Row
,
Col
,
Row
,
256
,
256
,
32
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>>
(
a
,
s
);
}
else
{
// universal gemm memory bound RC
std
::
cout
<<
"fp16 mem RC
\n
"
;
return
gemm_
<
trait_
<
FP16
,
FP16
,
FP32
,
FP16
,
Row
,
Col
,
Row
,
128
,
32
,
64
,
4
,
1
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>>
(
a
,
s
);
}
}
else
if
(
!
t
.
is_a_rowmajor
&&
t
.
is_b_rowmajor
&&
t
.
is_c_rowmajor
)
{
if
(
a
.
M
>
512
)
{
// universal gemm compute bound CR
std
::
cout
<<
"fp16 comp CR
\n
"
;
return
gemm_
<
trait_
<
FP16
,
FP16
,
FP32
,
FP16
,
Col
,
Row
,
Row
,
256
,
256
,
32
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>>
(
a
,
s
);
}
else
{
// universal gemm memory bound CR
std
::
cout
<<
"fp16 mem CR
\n
"
;
return
gemm_
<
trait_
<
FP16
,
FP16
,
FP32
,
FP16
,
Col
,
Row
,
Row
,
128
,
128
,
32
,
2
,
2
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>>
(
a
,
s
);
}
}
else
if
(
!
t
.
is_a_rowmajor
&&
!
t
.
is_b_rowmajor
&&
t
.
is_c_rowmajor
)
{
if
(
a
.
M
>
512
)
{
// universal gemm compute bound CC
std
::
cout
<<
"fp16 comp CC
\n
"
;
return
gemm_
<
trait_
<
FP16
,
FP16
,
FP32
,
FP16
,
Col
,
Col
,
Row
,
256
,
256
,
32
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>>
(
a
,
s
);
}
else
{
// universal gemm memory bound CC
std
::
cout
<<
"fp16 mem CC
\n
"
;
return
gemm_
<
trait_
<
FP16
,
FP16
,
FP32
,
FP16
,
Col
,
Col
,
Row
,
128
,
128
,
32
,
2
,
2
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>>
(
a
,
s
);
}
}
else
{
throw
std
::
runtime_error
(
"Wrong! ColumnMajor layout not supported for C Matrix!
\n
"
);
}
}
else
if
(
t
.
data_type
.
compare
(
"bf16"
)
==
0
)
{
if
(
t
.
is_a_rowmajor
&&
t
.
is_b_rowmajor
&&
t
.
is_c_rowmajor
)
{
if
(
a
.
M
>
512
)
{
// universal gemm compute bound RR
std
::
cout
<<
"bf16 comp
\n
"
;
return
gemm_
<
trait_
<
BF16
,
BF16
,
FP32
,
BF16
,
Row
,
Row
,
Row
,
256
,
256
,
32
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>>
(
a
,
s
);
}
else
{
// universal gemm memory bound RR
std
::
cout
<<
"bf16 mem
\n
"
;
return
gemm_
<
trait_
<
BF16
,
BF16
,
FP32
,
BF16
,
Row
,
Row
,
Row
,
128
,
32
,
64
,
4
,
1
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>>
(
a
,
s
);
}
}
else
if
(
t
.
is_a_rowmajor
&&
!
t
.
is_b_rowmajor
&&
t
.
is_c_rowmajor
)
{
if
(
a
.
M
>
512
)
{
// universal gemm compute bound RC
std
::
cout
<<
"bf16 comp RC
\n
"
;
return
gemm_
<
trait_
<
BF16
,
BF16
,
FP32
,
BF16
,
Row
,
Col
,
Row
,
256
,
256
,
32
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>>
(
a
,
s
);
}
else
{
// universal gemm memory bound RC
std
::
cout
<<
"bf16 mem RC
\n
"
;
return
gemm_
<
trait_
<
BF16
,
BF16
,
FP32
,
BF16
,
Row
,
Col
,
Row
,
128
,
32
,
64
,
4
,
1
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>>
(
a
,
s
);
}
}
else
if
(
!
t
.
is_a_rowmajor
&&
t
.
is_b_rowmajor
&&
t
.
is_c_rowmajor
)
{
if
(
a
.
M
>
512
)
{
// universal gemm compute bound CR
std
::
cout
<<
"bf16 comp CR
\n
"
;
return
gemm_
<
trait_
<
BF16
,
BF16
,
FP32
,
BF16
,
Col
,
Row
,
Row
,
256
,
256
,
32
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>>
(
a
,
s
);
}
else
{
// universal gemm memory bound CR
std
::
cout
<<
"bf16 mem CR
\n
"
;
return
gemm_
<
trait_
<
BF16
,
BF16
,
FP32
,
BF16
,
Col
,
Row
,
Row
,
128
,
128
,
32
,
2
,
2
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>>
(
a
,
s
);
}
}
else
if
(
!
t
.
is_a_rowmajor
&&
!
t
.
is_b_rowmajor
&&
t
.
is_c_rowmajor
)
{
if
(
a
.
M
>
512
)
{
// universal gemm compute bound CC
std
::
cout
<<
"bf16 comp CC
\n
"
;
return
gemm_
<
trait_
<
BF16
,
BF16
,
FP32
,
BF16
,
Col
,
Col
,
Row
,
256
,
256
,
32
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>>
(
a
,
s
);
}
else
{
// universal gemm memory bound CC
std
::
cout
<<
"bf16 mem CC
\n
"
;
return
gemm_
<
trait_
<
BF16
,
BF16
,
FP32
,
BF16
,
Col
,
Col
,
Row
,
128
,
128
,
32
,
2
,
2
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>>
(
a
,
s
);
}
}
else
{
throw
std
::
runtime_error
(
"Wrong! ColumnMajor layout not supported for C Matrix!
\n
"
);
}
}
else
{
throw
std
::
runtime_error
(
"Wrong! DataTypes not supported!
\n
"
);
}
return
1.0
f
;
}
example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_kn_mn.cpp
0 → 100644
View file @
896f8b4c
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
float
gemm_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
ck_tile
::
bf16_t
,
Col
,
Row
,
Row
,
256
,
256
,
32
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_nk_mn.cpp
0 → 100644
View file @
896f8b4c
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
float
gemm_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
ck_tile
::
bf16_t
,
Col
,
Col
,
Row
,
256
,
256
,
32
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp
0 → 100644
View file @
896f8b4c
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
template
float
gemm_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
ck_tile
::
bf16_t
,
Row
,
Row
,
Row
,
256
,
256
,
32
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp
0 → 100644
View file @
896f8b4c
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
float
gemm_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
ck_tile
::
bf16_t
,
Row
,
Col
,
Row
,
256
,
256
,
32
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_kn_mn.cpp
0 → 100644
View file @
896f8b4c
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
float
gemm_
<
trait_
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
ck_tile
::
half_t
,
Col
,
Row
,
Row
,
256
,
256
,
32
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_nk_mn.cpp
0 → 100644
View file @
896f8b4c
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
float
gemm_
<
trait_
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
ck_tile
::
half_t
,
Col
,
Col
,
Row
,
256
,
256
,
32
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_kn_mn.cpp
0 → 100644
View file @
896f8b4c
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
template
float
gemm_
<
trait_
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
ck_tile
::
half_t
,
Row
,
Row
,
Row
,
256
,
256
,
32
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_nk_mn.cpp
0 → 100644
View file @
896f8b4c
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
float
gemm_
<
trait_
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
ck_tile
::
half_t
,
Row
,
Col
,
Row
,
256
,
256
,
32
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
example/ck_tile/03_gemm/instances/gemm_universal_comp_instance_common.hpp
0 → 100644
View file @
896f8b4c
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include <iostream>
#include "gemm_basic.hpp"
using
A
=
ck_tile
::
GemmHostArgs
;
using
S
=
ck_tile
::
stream_config
;
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
AccDataType_
,
typename
CDataType_
,
typename
ALayout_
,
typename
BLayout_
,
typename
CLayout_
,
ck_tile
::
index_t
M_Tile_
,
ck_tile
::
index_t
N_Tile_
,
ck_tile
::
index_t
K_Tile_
,
ck_tile
::
index_t
M_Warp_
,
ck_tile
::
index_t
N_Warp_
,
ck_tile
::
index_t
K_Warp_
,
ck_tile
::
index_t
M_Warp_Tile_
,
ck_tile
::
index_t
N_Warp_Tile_
,
ck_tile
::
index_t
K_Warp_Tile_
,
bool
kPadM_
,
bool
kPadN_
,
bool
kPadK_
>
using
trait_
=
gemm_traits_
<
ADataType_
,
BDataType_
,
AccDataType_
,
CDataType_
,
ALayout_
,
BLayout_
,
CLayout_
,
M_Tile_
,
N_Tile_
,
K_Tile_
,
M_Warp_
,
N_Warp_
,
K_Warp_
,
M_Warp_Tile_
,
N_Warp_Tile_
,
K_Warp_Tile_
,
kPadM_
,
kPadN_
,
kPadK_
>
;
/// Builds and launches the CompV3-pipeline GEMM kernel described by Traits_.
///
/// The tile shape, epilogue and pipeline types are assembled from the members
/// of Traits_. The K-loop count is derived from args.K rounded up to a
/// multiple of (args.k_batch * Traits_::K_Tile); from it the pipeline reports
/// whether there is a hot loop and which tail number applies, and the matching
/// kernel specialization is launched.
///
/// @param args Host-side GEMM arguments (reads M, N, K, k_batch and forwards
///             the whole object to Kernel::MakeKernelArgs).
/// @param s    Stream configuration; a log_level_ above 0 prints the launch
///             dimensions to stdout.
/// @return The value returned by ck_tile::launch_kernel for the launched kernel.
/// @throws std::runtime_error if the kernel rejects the arguments, or if no
///         kernel specialization exists for the computed hot-loop / tail-number
///         combination.
template <typename Traits_>
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
    using GemmShape = ck_tile::TileGemmShape<
        ck_tile::sequence<Traits_::M_Tile, Traits_::N_Tile, Traits_::K_Tile>,
        ck_tile::sequence<Traits_::M_Warp, Traits_::N_Warp, Traits_::K_Warp>,
        ck_tile::sequence<Traits_::M_Warp_Tile, Traits_::N_Warp_Tile, Traits_::K_Warp_Tile>>;

    using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;

    using GemmEpilogue = ck_tile::Default2DEpilogue<
        ck_tile::Default2DEpilogueProblem<typename Traits_::AccDataType,
                                          typename Traits_::CDataType,
                                          Traits_::kPadM,
                                          Traits_::kPadN>>;

    using GemmTraits = ck_tile::TileGemmTraits<Traits_::kPadM,
                                               Traits_::kPadN,
                                               Traits_::kPadK,
                                               typename Traits_::ALayout,
                                               typename Traits_::BLayout,
                                               typename Traits_::CLayout>;

    using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<
        ck_tile::GemmPipelineProblem<typename Traits_::ADataType,
                                     typename Traits_::BDataType,
                                     typename Traits_::AccDataType,
                                     GemmShape,
                                     GemmTraits>>;

    constexpr int kBlockPerCu = 1;

    // Round K up to a whole number of (k_batch * K_Tile) grains, then express
    // the per-batch extent in elements before asking for the loop count.
    const ck_tile::index_t k_grain  = args.k_batch * Traits_::K_Tile;
    const ck_tile::index_t K_split  = (args.K + k_grain - 1) / k_grain * Traits_::K_Tile;
    const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
    const bool has_hot_loop         = BaseGemmPipeline::BlockHasHotloop(num_loop);
    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

    float ave_time{0};
    // Set once a kernel specialization has been selected, so that an
    // unmatched tail number is reported instead of silently returning 0.
    bool launched = false;

    // Instantiates the kernel for one (hot loop, tail number) pair and launches it.
    const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
        constexpr bool has_hot_loop_v = has_hot_loop_.value;
        constexpr auto tail_number_v  = tail_number_.value;

        using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<
            ck_tile::UniversalGemmPipelineProblem<typename Traits_::ADataType,
                                                  typename Traits_::BDataType,
                                                  typename Traits_::AccDataType,
                                                  GemmShape,
                                                  GemmTraits,
                                                  ck_tile::GemmPipelineScheduler::Intrawave,
                                                  has_hot_loop_v,
                                                  tail_number_v>>;
        using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;

        launched   = true;
        auto kargs = Kernel::MakeKernelArgs(args);

        const dim3 grids      = Kernel::GridSize(args.M, args.N, args.k_batch);
        constexpr dim3 blocks = Kernel::BlockSize();

        if(!Kernel::IsSupportedArgument(kargs))
        {
            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
        }

        if(s.log_level_ > 0)
        {
            std::cout << "Launching kernel with args:"
                      << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                      << std::endl;
        }

        ave_time = ck_tile::launch_kernel(
            s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        return ave_time;
    };

    // Launches the hot-loop specialization for `tail_number_` if it is the
    // tail number computed at run time; otherwise does nothing.
    const auto RunHotIfTail = [&](const auto tail_number_) {
        if(tail_num == tail_number_.value)
        {
            Run(ck_tile::bool_constant<true>{}, tail_number_);
        }
    };

    using TailNumber = ck_tile::TailNumber;

    if(has_hot_loop)
    {
        // Tail pipeline One to Seven
        RunHotIfTail(ck_tile::integral_constant<TailNumber, TailNumber::One>{});
        RunHotIfTail(ck_tile::integral_constant<TailNumber, TailNumber::Full>{});

        // Larger tail numbers are only instantiated when the pipeline has
        // enough prefetch stages for them.
        if constexpr(BaseGemmPipeline::PrefetchStages > 2)
        {
            RunHotIfTail(ck_tile::integral_constant<TailNumber, TailNumber::Two>{});
        }
        if constexpr(BaseGemmPipeline::PrefetchStages > 3)
        {
            RunHotIfTail(ck_tile::integral_constant<TailNumber, TailNumber::Three>{});
        }
        if constexpr(BaseGemmPipeline::PrefetchStages > 4)
        {
            RunHotIfTail(ck_tile::integral_constant<TailNumber, TailNumber::Four>{});
        }
        if constexpr(BaseGemmPipeline::PrefetchStages > 5)
        {
            RunHotIfTail(ck_tile::integral_constant<TailNumber, TailNumber::Five>{});
        }
        if constexpr(BaseGemmPipeline::PrefetchStages > 6)
        {
            RunHotIfTail(ck_tile::integral_constant<TailNumber, TailNumber::Six>{});
        }
        if constexpr(BaseGemmPipeline::PrefetchStages > 7)
        {
            RunHotIfTail(ck_tile::integral_constant<TailNumber, TailNumber::Seven>{});
        }

        if(!launched)
        {
            // Mirror the no-hot-loop branch: an unmatched tail number is an
            // error, not a zero-time success.
            std::ostringstream err;
            err << "When there's a hot loop, this tail number \"" << tail_num
                << "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
                << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
            throw std::runtime_error(err.str());
        }
    }
    else
    {
        // Tail number always Full - #PrefetchStages
        if(tail_num == TailNumber::Full)
        {
            Run(ck_tile::bool_constant<false>{},
                ck_tile::integral_constant<TailNumber, TailNumber::Full>{});
        }
        else
        {
            std::ostringstream err;
            err << "When there's no hot loop, this tail number \"" << tail_num
                << "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
                << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
            throw std::runtime_error(err.str());
        }
    }

    return ave_time;
}
example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_kn_mn.cpp
0 → 100644
View file @
896f8b4c
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
float
gemm_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
ck_tile
::
bf16_t
,
Col
,
Row
,
Row
,
128
,
128
,
32
,
2
,
2
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_nk_mn.cpp
0 → 100644
View file @
896f8b4c
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
float
gemm_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
ck_tile
::
bf16_t
,
Col
,
Col
,
Row
,
128
,
128
,
32
,
2
,
2
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp
0 → 100644
View file @
896f8b4c
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
template
float
gemm_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
ck_tile
::
bf16_t
,
Row
,
Row
,
Row
,
128
,
32
,
64
,
4
,
1
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp
0 → 100644
View file @
896f8b4c
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
float
gemm_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
ck_tile
::
bf16_t
,
Row
,
Col
,
Row
,
128
,
32
,
64
,
4
,
1
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_kn_mn.cpp
0 → 100644
View file @
896f8b4c
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
float
gemm_
<
trait_
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
ck_tile
::
half_t
,
Col
,
Row
,
Row
,
128
,
128
,
32
,
2
,
2
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_nk_mn.cpp
0 → 100644
View file @
896f8b4c
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
float
gemm_
<
trait_
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
ck_tile
::
half_t
,
Col
,
Col
,
Row
,
128
,
128
,
32
,
2
,
2
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp
0 → 100644
View file @
896f8b4c
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
template
float
gemm_
<
trait_
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
ck_tile
::
half_t
,
Row
,
Row
,
Row
,
128
,
32
,
64
,
4
,
1
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment