Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ec959387
Unverified
Commit
ec959387
authored
Feb 13, 2025
by
rocking
Committed by
GitHub
Feb 13, 2025
Browse files
Merge branch 'develop' into ck_tile/fmha_receipt_aiter
parents
c1e2fef7
0e5e29c4
Changes
393
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
327 additions
and
173 deletions
+327
-173
example/ck_tile/01_fmha/CMakeLists.txt
example/ck_tile/01_fmha/CMakeLists.txt
+5
-0
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+9
-1
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
+8
-1
example/ck_tile/01_fmha/generate.py
example/ck_tile/01_fmha/generate.py
+5
-3
example/ck_tile/02_layernorm2d/CMakeLists.txt
example/ck_tile/02_layernorm2d/CMakeLists.txt
+1
-1
example/ck_tile/02_layernorm2d/generate.py
example/ck_tile/02_layernorm2d/generate.py
+5
-3
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
+28
-4
example/ck_tile/02_layernorm2d/script/smoke_test.sh
example/ck_tile/02_layernorm2d/script/smoke_test.sh
+1
-1
example/ck_tile/03_gemm/CMakeLists.txt
example/ck_tile/03_gemm/CMakeLists.txt
+3
-0
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+71
-33
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+55
-12
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+75
-106
example/ck_tile/03_gemm/script/benchmark_basic.sh
example/ck_tile/03_gemm/script/benchmark_basic.sh
+3
-3
example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh
example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh
+0
-0
example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh
example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh
+0
-0
example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh
example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh
+14
-0
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
+5
-5
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh
...ple/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh
+13
-0
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh
+13
-0
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh
+13
-0
No files found.
example/ck_tile/01_fmha/CMakeLists.txt
View file @
ec959387
...
...
@@ -102,6 +102,11 @@ else()
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0
)
endif
()
# conditionally specify the use of OCP_FP8
if
(
CK_USE_OCP_FP8
)
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8
)
endif
()
# Allow comparing floating points directly in order to check sentinel values
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal
)
list
(
APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal
)
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
ec959387
...
...
@@ -506,6 +506,14 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond
&=
deterministic
==
"f"
if
not
cond
:
continue
elif
receipt
==
4
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
bias
in
[
'no'
,
'bias'
]
cond
&=
dropout
in
[
'no'
,
'dropout_wg32'
,
'dropout_wg16'
]
cond
&=
dpad
==
dvpad
cond
&=
deterministic
==
"f"
if
not
cond
:
continue
elif
receipt
==
10
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
mode
==
"batch"
...
...
@@ -818,4 +826,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
_
,
kernels
=
get_bwd_dq_dk_dv_blobs
(
kernel_filter
,
receipt
,
mask_impl
)
for
kernel
in
kernels
:
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
kernel
.
filename
)
+
"
\n
"
)
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
FMHA_BWD_API_FILENAME
)
+
"
\n
"
)
\ No newline at end of file
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
FMHA_BWD_API_FILENAME
)
+
"
\n
"
)
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
View file @
ec959387
...
...
@@ -487,13 +487,20 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
if
kernel_filter
!=
None
:
if
not
fnmatch
.
fnmatch
(
k
.
name
,
kernel_filter
):
continue
if
receipt
==
2
:
if
receipt
in
(
2
,
3
)
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
pipeline
.
F_vlayout
==
'row'
cond
&=
pipeline
.
F_bias
in
[
'no'
,
'alibi'
]
cond
&=
pipeline
.
F_squant
==
'f'
if
not
cond
:
continue
elif
receipt
==
4
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
pipeline
.
F_vlayout
==
'row'
cond
&=
pipeline
.
F_bias
in
[
'no'
,
'bias'
]
cond
&=
pipeline
.
F_squant
==
'f'
if
not
cond
:
continue
elif
receipt
==
10
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
mode
==
"batch"
...
...
example/ck_tile/01_fmha/generate.py
View file @
ec959387
...
...
@@ -103,10 +103,12 @@ if __name__ == "__main__":
required
=
False
,
help
=
"codegen receipt. 0: generate only 8xhdim coverage
\n
"
+
\
" 1: generate more instance to cover all hdim
\n
"
+
\
" 2: Only generate instance for Flash attention integration"
+
\
" 10: Only generate instance for Aiter(mha_fwd, mha_bwd) integration"
" 11: Only generate instance for Aiter(mha_varlen_fwd, mha_varlen_bwd) integration"
" 2: Only generate instance for Flash attention integration
\n
"
+
\
" 4: Only generate instance for PyTorch integration
\n
"
+
\
" 10: Only generate instance for Aiter(mha_fwd, mha_bwd) integration
\n
"
+
\
" 11: Only generate instance for Aiter(mha_varlen_fwd, mha_varlen_bwd) integration
\n
"
+
\
" 12: Only generate instance for Aiter(mha_fwd_kvcache) integration"
)
args
=
parser
.
parse_args
()
...
...
example/ck_tile/02_layernorm2d/CMakeLists.txt
View file @
ec959387
...
...
@@ -33,7 +33,7 @@ target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS})
set
(
EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS
)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list
(
APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal
)
list
(
APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal
--offload-compress
)
target_compile_options
(
${
EXAMPLE_LAYERNORM2D_FWD
}
PRIVATE
${
EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS
}
)
...
...
example/ck_tile/02_layernorm2d/generate.py
View file @
ec959387
...
...
@@ -39,7 +39,8 @@ FUSED_FUSED_SWEEP_STR_MAP = [
DATA_TYPE_MAP
=
{
'fp32'
:
'float'
,
'fp16'
:
'ck_tile::fp16_t'
,
'bf16'
:
'ck_tile::bf16_t'
,
'int8'
:
'ck_tile::int8_t'
}
'int8'
:
'ck_tile::int8_t'
,
'fp8'
:
'ck_tile::fp8_t'
}
def
BOOL_MAP
(
b_
)
->
str
:
if
b_
:
...
...
@@ -504,12 +505,13 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
h_traits
=
layernorm_fwd_codegen
.
h_traits
h_instance
=
layernorm_fwd_codegen
.
h_instance
dynamic_quant_out_dtype
=
[
'int8'
]
dynamic_quant_out_dtype
=
[
'int8'
,
'fp8'
]
# some predefined support range
# (prec_i,prec_o) for simplicity this string will be used as key for dict
scale_list
=
[(
'fp32,fp32'
)]
dtype_list
=
[(
'fp16,fp16'
),
(
'bf16,bf16'
),
(
'fp16,int8'
),
(
'bf16,int8'
)]
# NOTE: only fused-dynamic-quant use int8 out
(
'fp16,int8'
),
(
'bf16,int8'
),
(
'fp16,fp8'
),
(
'bf16,fp8'
)]
# NOTE: only fused-dynamic-quant use int8 or fp8 out
types_8bit
=
(
'int8'
,
'fp8'
)
types_16bit
=
(
'int16'
,
'fp16'
,
'bf16'
)
#fused_add_list = [0, 1, 2]
...
...
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
View file @
ec959387
...
...
@@ -20,6 +20,14 @@ auto get_elimit<ck_tile::bf16_t>()
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
}
template
<
>
auto
get_elimit
<
ck_tile
::
int8_t
>
()
{
double
rtol
=
1e-2
;
double
atol
=
1.0
;
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
}
auto
create_args
(
int
argc
,
char
*
argv
[])
{
ck_tile
::
ArgParser
arg_parser
;
...
...
@@ -97,9 +105,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
int
xbias
=
arg_parser
.
get_int
(
"xbias"
);
int
fused_add
=
arg_parser
.
get_int
(
"fadd"
);
int
fused_quant
=
arg_parser
.
get_int
(
"fquant"
);
if
(
fused_quant
==
1
&&
prec_o
!=
"int8"
)
if
(
fused_quant
==
1
&&
prec_o
!=
"int8"
&&
prec_o
!=
"fp8"
)
{
std
::
cout
<<
"if fused_quant is 1, only support
\"
-prec_o=int8
\"
case"
<<
std
::
endl
;
std
::
cout
<<
"if fused_quant is 1 or 2, only support
\"
-prec_o=int8
\"
or
\"
-prec_o=fp8
\"
cases."
<<
std
::
endl
;
return
false
;
}
...
...
@@ -291,7 +301,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
absmax
=
a
>
absmax
?
a
:
absmax
;
}
// printf("cpu:absmax:%f\n", absmax);
ComputeDataType
y_scale
=
absmax
/
static_cast
<
ComputeDataType
>
(
127.0
);
constexpr
ComputeDataType
kMaxY
=
std
::
is_same
<
YDataType
,
ck_tile
::
fp8_t
>::
value
?
240.0
:
std
::
is_same
<
YDataType
,
ck_tile
::
int8_t
>::
value
?
127.0
:
0.0
;
ComputeDataType
y_scale
=
absmax
/
kMaxY
;
y_scale_host_ref
(
m_
)
=
ck_tile
::
type_convert
<
YScaleDataType
>
(
y_scale
);
for
(
int
n_
=
0
;
n_
<
N_
;
n_
++
)
{
...
...
@@ -334,7 +348,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
y_residual_buf
.
FromDevice
(
y_residual_host_dev
.
data
());
}
auto
[
rtol
,
atol
]
=
get_elimit
<
In
DataType
>
();
auto
[
rtol
,
atol
]
=
get_elimit
<
Out
DataType
>
();
if
(
x_stride
==
n
)
{
...
...
@@ -452,6 +466,16 @@ int main(int argc, char* argv[])
{
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
float
,
float
,
false
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
prec_i
==
"fp16"
&&
prec_o
==
"fp8"
&&
prec_sm
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_mv
)
{
return
run
<
ck_tile
::
half_t
,
ck_tile
::
fp8_t
,
float
,
float
,
false
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
prec_i
==
"bf16"
&&
prec_o
==
"fp8"
&&
prec_sm
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_mv
)
{
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
float
,
float
,
false
>
(
arg_parser
)
?
0
:
-
2
;
}
return
-
3
;
}
example/ck_tile/02_layernorm2d/script/smoke_test.sh
View file @
ec959387
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_layernorm2d_fwd
-type
f |
head
-n
1
)
"
for
fquant
in
""
"-fquant=1 -prec_o=int8"
;
do
for
fquant
in
""
"-fquant=1 -prec_o=int8"
"-fquant=1 -prec_o=fp8"
;
do
for
pr_i
in
"fp16"
"bf16"
;
do
for
fadd
in
"0"
"1"
;
do
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
99
-n
=
13
...
...
example/ck_tile/03_gemm/CMakeLists.txt
View file @
ec959387
add_executable
(
tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp
)
add_executable
(
tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp
)
target_compile_options
(
tile_example_gemm_universal PRIVATE
-mllvm -enable-noalias-to-md-conversion=0
)
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
ec959387
...
...
@@ -12,7 +12,13 @@
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
gemm_calc
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
...
...
@@ -20,16 +26,12 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kTilePermute
=
false
;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
constexpr
int
kBlockPerCu
=
1
;
// This part comes from the Codegen
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
constexpr
ck_tile
::
index_t
N_Tile
=
128
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Tile
=
64
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
...
...
@@ -37,42 +39,33 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr
bool
CShuffleEpilogue
=
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
using
CodegenGemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTile2DPartitioner
<
CodegenGemmShape
>
;
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
,
kTilePermute
,
kOutputRank
,
1
,
0
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>>
;
using
TilePartitioner
=
ck_tile
::
GemmTile1DPartitioner
<
CodegenGemmShape
>
;
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPolicy
=
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
,
CodegenGemmPolicy
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CLayout
,
CodegenPipelineProblem
::
kBlockSize
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
,
M_Warp
,
N_Warp
,
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
,
CodegenPipelineProblem
::
TransposeC
>>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
...
...
@@ -89,8 +82,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
if
(
s
.
log_level_
>
0
)
{
std
::
cout
<<
"Launching kernel with args:"
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
std
::
cout
<<
"Launching kernel with args: "
<<
Kernel
::
GetName
()
<<
'\n'
<<
"shape: "
<<
CodegenGemmShape
::
GetName
()
<<
'\n'
<<
"problem: "
<<
CodegenPipelineProblem
::
GetName
()
<<
'\n'
<<
"pipeline: "
<<
CodegenGemmPipeline
::
GetName
()
<<
'\n'
<<
"grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
std
::
endl
;
}
...
...
@@ -103,4 +99,46 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
#include "run_gemm_example.inc"
int
run_gemm_example
(
int
argc
,
char
*
argv
[])
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
return
-
1
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
{
if
(
data_type
==
"fp16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
half_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"bf16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf16_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"fp8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
fp8_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"bf8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf8_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
{
throw
std
::
runtime_error
(
"Unsupported data_type!"
);
}
}
else
{
throw
std
::
runtime_error
(
"Unsupported data layout configuration for A,B and C tensors!"
);
}
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
ec959387
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -11,21 +11,26 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_COMPUTE
_V3
1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
_V3
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE
_V3
)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
...
...
@@ -43,6 +48,33 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
// ToDo: Add more bias config to support different categories of GEMM.
};
template
<
>
struct
GemmBasicTypeConfig
<
ck_tile
::
bf16_t
>
{
using
ADataType
=
ck_tile
::
bf16_t
;
using
BDataType
=
ck_tile
::
bf16_t
;
using
AccDataType
=
float
;
using
CDataType
=
ck_tile
::
bf16_t
;
};
template
<
>
struct
GemmBasicTypeConfig
<
ck_tile
::
fp8_t
>
{
using
ADataType
=
ck_tile
::
fp8_t
;
using
BDataType
=
ck_tile
::
fp8_t
;
using
AccDataType
=
float
;
using
CDataType
=
ck_tile
::
half_t
;
};
template
<
>
struct
GemmBasicTypeConfig
<
ck_tile
::
bf8_t
>
{
using
ADataType
=
ck_tile
::
bf8_t
;
using
BDataType
=
ck_tile
::
bf8_t
;
using
AccDataType
=
float
;
using
CDataType
=
ck_tile
::
half_t
;
};
template
<
typename
T
>
struct
DataTypeTraits
;
...
...
@@ -64,13 +96,23 @@ struct DataTypeTraits<ck_tile::half_t>
static
constexpr
const
char
*
name
=
"fp16"
;
};
using
Types
=
GemmBasicTypeConfig
<
ck_tile
::
half_t
>
;
template
<
>
struct
DataTypeTraits
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
template
<
>
struct
DataTypeTraits
<
ck_tile
::
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
// Specific type aliases for easy access
using
A
DataType
=
Types
::
ADataType
;
using
BDataType
=
Types
::
BDataType
;
using
AccDataType
=
Types
::
AccDataType
;
using
CDataType
=
Types
::
CDataType
;
template
<
>
struct
DataType
Traits
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
}
;
auto
create_args
(
int
argc
,
char
*
argv
[])
{
...
...
@@ -79,7 +121,7 @@ auto create_args(int argc, char* argv[])
.
insert
(
"n"
,
"4096"
,
"n dimension"
)
.
insert
(
"k"
,
"2048"
,
"k dimension"
)
.
insert
(
"a_layout"
,
"R"
,
"A tensor data layout - Row by default"
)
.
insert
(
"b_layout"
,
"
R
"
,
"B tensor data layout -
Row
by default"
)
.
insert
(
"b_layout"
,
"
C
"
,
"B tensor data layout -
Column
by default"
)
.
insert
(
"c_layout"
,
"R"
,
"C tensor data layout - Row by default"
)
.
insert
(
"stride_a"
,
"0"
,
"Tensor A stride"
)
.
insert
(
"stride_b"
,
"0"
,
"Tensor B stride"
)
...
...
@@ -89,7 +131,8 @@ auto create_args(int argc, char* argv[])
.
insert
(
"warmup"
,
"50"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"repeat"
,
"100"
,
"number of iterations to benchmark the kernel"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
)
.
insert
(
"split_k"
,
"1"
,
"splitK value"
);
.
insert
(
"split_k"
,
"1"
,
"splitK value"
)
.
insert
(
"init"
,
"0"
,
"0:random, 1:linear, 2:constant(1)"
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
return
std
::
make_tuple
(
result
,
arg_parser
);
...
...
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
ec959387
...
...
@@ -2,6 +2,14 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template
<
typename
Layout
>
static
constexpr
inline
auto
is_row_major
(
Layout
layout_
)
{
return
ck_tile
::
bool_constant
<
std
::
is_same_v
<
ck_tile
::
remove_cvref_t
<
decltype
(
layout_
)
>
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>>
{};
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
>
auto
calculate_rtol_atol
(
const
ck_tile
::
index_t
K
,
const
ck_tile
::
index_t
kbatch
,
const
float
max_accumulated_value
)
...
...
@@ -22,7 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return
ck_tile
::
make_tuple
(
std
::
max
(
rtol
,
rtol_split_k
),
std
::
max
(
atol
,
atol_split_k
));
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
invoke_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
ck_tile
::
DeviceMem
&
c_m_n_dev_buf
,
...
...
@@ -48,8 +62,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args
.
stride_B
=
stride_B
;
args
.
stride_C
=
stride_C
;
float
ave_time
=
gemm_calc
<
ALayout
,
BLayout
,
CLayout
>
(
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
1
,
n_warmup
,
n_repeat
});
float
ave_time
=
gemm_calc
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
1
,
n_warmup
,
n_repeat
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_byte
=
...
...
@@ -59,13 +74,16 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
std
::
cout
<<
"Run Gemm kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
<<
" StrideA ="
<<
stride_A
<<
" StrideB ="
<<
stride_B
<<
" StrideC ="
<<
stride_C
<<
" : "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
<<
" A_Layout ="
<<
ALayout
::
name
<<
" B_Layout ="
<<
BLayout
::
name
<<
" C_Layout ="
<<
CLayout
::
name
<<
" A Type = "
<<
DataTypeTraits
<
ADataType
>::
name
<<
" B Type = "
<<
DataTypeTraits
<
BDataType
>::
name
<<
" C Type = "
<<
DataTypeTraits
<
CDataType
>::
name
<<
" : "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
return
ave_time
;
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
PrecType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
int
run_gemm_example_with_layouts
(
int
argc
,
char
*
argv
[],
const
ALayout
a_layout
=
ALayout
{},
...
...
@@ -76,6 +94,11 @@ int run_gemm_example_with_layouts(int argc,
if
(
!
result
)
return
-
1
;
using
ADataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
ADataType
;
using
BDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
BDataType
;
using
CDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
CDataType
;
using
AccDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
AccDataType
;
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
K
=
arg_parser
.
get_int
(
"k"
);
...
...
@@ -87,53 +110,32 @@ int run_gemm_example_with_layouts(int argc,
ck_tile
::
index_t
kbatch
=
arg_parser
.
get_int
(
"split_k"
);
int
n_warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
n_repeat
=
arg_parser
.
get_int
(
"repeat"
);
ck_tile
::
index_t
init_method
=
arg_parser
.
get_int
(
"init"
);
using
namespace
ck_tile
::
literals
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
ck_tile
::
HostTensorDescriptor
({
row
,
col
},
{
stride
,
1_
uz
});
}
else
{
return
ck_tile
::
HostTensorDescriptor
({
row
,
col
},
{
1_
uz
,
stride
});
}
};
auto
f_get_default_stride
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
stride
==
0
)
{
// give a chance if stride is zero, return a default packed stride
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
col
;
}
else
{
return
row
;
}
}
else
return
stride
;
};
stride_A
=
ck_tile
::
get_default_stride
(
M
,
K
,
stride_A
,
is_row_major
(
a_layout
));
stride_B
=
ck_tile
::
get_default_stride
(
K
,
N
,
stride_B
,
is_row_major
(
b_layout
));
stride_C
=
ck_tile
::
get_default_stride
(
M
,
N
,
stride_C
,
is_row_major
(
CLayout
{}));
stride_A
=
f_get_default_stride
(
M
,
K
,
stride_A
,
a_layout
);
stride_B
=
f_get_default_stride
(
K
,
N
,
stride_B
,
b_layout
);
stride_C
=
f_get_default_stride
(
M
,
N
,
stride_C
,
CLayout
{});
ck_tile
::
HostTensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
stride_A
,
a_layout
));
ck_tile
::
HostTensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
stride_B
,
b_layout
));
ck_tile
::
HostTensor
<
ADataType
>
a_m_k
(
ck_tile
::
host_tensor_descriptor
(
M
,
K
,
stride_A
,
is_row_major
(
a_layout
)));
ck_tile
::
HostTensor
<
BDataType
>
b_k_n
(
ck_tile
::
host_tensor_descriptor
(
K
,
N
,
stride_B
,
is_row_major
(
b_layout
)));
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_dev_result
(
f_host_tensor_descriptor
(
M
,
N
,
stride_C
,
CLayout
{}));
// TODO: add different init types
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
ck_tile
::
host_tensor_descriptor
(
M
,
N
,
stride_C
,
is_row_major
(
CLayout
{})));
if
(
init_method
==
0
)
{
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
1.
f
,
1.
f
}(
a_m_k
);
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
1.
f
,
1.
f
}(
b_k_n
);
}
else
if
(
init_method
==
1
)
{
ck_tile
::
FillMonotonicSeq
<
ADataType
>
{}(
a_m_k
);
ck_tile
::
FillMonotonicSeq
<
BDataType
>
{}(
b_k_n
);
}
else
if
(
init_method
==
2
)
{
ck_tile
::
FillConstant
<
ADataType
>
{
static_cast
<
ADataType
>
(
1
)}(
a_m_k
);
ck_tile
::
FillConstant
<
BDataType
>
{
static_cast
<
BDataType
>
(
1
)}(
b_k_n
);
}
else
{
a_m_k
.
SetZero
();
b_k_n
.
SetZero
();
}
ck_tile
::
DeviceMem
a_m_k_dev_buf
(
a_m_k
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
b_k_n_dev_buf
(
b_k_n
.
get_element_space_size_in_bytes
());
...
...
@@ -144,18 +146,19 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
invoke_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
a_m_k_dev_buf
,
b_k_n_dev_buf
,
c_m_n_dev_buf
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
,
kbatch
,
n_warmup
,
n_repeat
);
invoke_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
a_m_k_dev_buf
,
b_k_n_dev_buf
,
c_m_n_dev_buf
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
,
kbatch
,
n_warmup
,
n_repeat
);
c_m_n_dev_buf
.
FromDevice
(
c_m_n_dev_result
.
data
());
bool
pass
=
true
;
...
...
@@ -163,15 +166,16 @@ int run_gemm_example_with_layouts(int argc,
if
(
arg_parser
.
get_int
(
"v"
)
==
1
)
{
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_host_ref
(
f_
host_tensor_descriptor
(
M
,
N
,
stride_C
,
CLayout
{}));
ck_tile
::
host_tensor_descriptor
(
M
,
N
,
stride_C
,
is_row_major
(
CLayout
{}))
)
;
c_m_n_host_ref
.
SetZero
();
ck_tile
::
reference_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
a_m_k
,
b_k_n
,
c_m_n_host_ref
);
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_host_ref
.
mData
.
begin
(),
c_m_n_host_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
...
...
@@ -180,12 +184,12 @@ int run_gemm_example_with_layouts(int argc,
std
::
cout
<<
"Relative error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The CPU ve
r
ification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
else
if
(
arg_parser
.
get_int
(
"v"
)
==
2
)
{
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_gpu_ref
(
f_
host_tensor_descriptor
(
M
,
N
,
stride_C
,
CLayout
{}));
ck_tile
::
host_tensor_descriptor
(
M
,
N
,
stride_C
,
is_row_major
(
CLayout
{}))
)
;
ck_tile
::
DeviceMem
c_m_n_gpu_buf_ref
(
c_m_n_gpu_ref
.
get_element_space_size_in_bytes
());
c_m_n_gpu_ref
.
SetZero
();
c_m_n_gpu_buf_ref
.
SetZero
();
...
...
@@ -227,8 +231,9 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_gpu_ref
.
mData
.
begin
(),
c_m_n_gpu_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
...
...
@@ -237,44 +242,8 @@ int run_gemm_example_with_layouts(int argc,
std
::
cout
<<
"Relative error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
std
::
cout
<<
"The GPU veification result is: "
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The GPU ve
r
ification result is: "
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
return
pass
;
}
int
run_gemm_example
(
int
argc
,
char
*
argv
[])
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
return
-
1
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
if
(
a_layout
==
"R"
&&
b_layout
==
"R"
)
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
}
else
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
// TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
// work.
// else if(a_layout == "C" && b_layout == "C")
// {
// return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
// }
// else if(a_layout == "C" && b_layout == "R")
// {
// return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
// }
else
{
throw
std
::
runtime_error
(
"Unsupported data layout configuration for A,B and C tensors!"
);
}
}
example/ck_tile/03_gemm/script/benchmark_basic.sh
View file @
ec959387
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_gemm_basic
-type
f |
head
-n
1
)
"
VALID
=
0
VALID
=
1
for
b_matrix_layout
in
"R"
"C"
;
do
for
b_matrix_layout
in
"C"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"64"
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
fp16
-b
=
1
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
$EXE
-prec
=
fp16
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
...
...
example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh
0 → 100644
View file @
ec959387
example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh
0 → 100644
View file @
ec959387
example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh
0 → 100644
View file @
ec959387
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_gemm_basic
-type
f |
head
-n
1
)
"
VALID
=
1
for
b_matrix_layout
in
"C"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"64"
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
fp8
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
done
\ No newline at end of file
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
View file @
ec959387
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_gemm_universal
-type
f |
head
-n
1
)
"
VALID
=
0
VALID
=
1
for
b_matrix_layout
in
"R"
"C"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
b_matrix_layout
in
"C"
;
do
for
m
in
"512"
"1024"
"2048"
"4096"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"64"
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
fp16
-b
=
1
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
for
k
in
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
fp16
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
...
...
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh
0 → 100644
View file @
ec959387
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_gemm_universal
-type
f |
head
-n
1
)
"
VALID
=
1
for
b_matrix_layout
in
"C"
;
do
for
m
in
"512"
"1024"
"2048"
"4096"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
bf16
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
done
\ No newline at end of file
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh
0 → 100644
View file @
ec959387
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_gemm_universal
-type
f |
head
-n
1
)
"
VALID
=
1
for
b_matrix_layout
in
"C"
;
do
for
m
in
"512"
"1024"
"2048"
"4096"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
bf8
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
done
\ No newline at end of file
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh
0 → 100644
View file @
ec959387
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_gemm_universal
-type
f |
head
-n
1
)
"
VALID
=
1
for
b_matrix_layout
in
"C"
;
do
for
m
in
"512"
"1024"
"2048"
"4096"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
fp8
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
done
\ No newline at end of file
Prev
1
2
3
4
5
6
7
8
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment