Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3de7bd67
Commit
3de7bd67
authored
Jan 12, 2025
by
Aleksander Dudek
Browse files
[CK_TILE] Use the GEMM example prec input arg
parent
e1e8e1ad
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
32 deletions
+44
-32
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+10
-5
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+34
-27
No files found.
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
3de7bd67
...
...
@@ -109,7 +109,9 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
}
template
<
typename
DataType
>
float
gemm_type_
(
const
gemm_traits
&
t
,
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
gemm_type_
(
const
gemm_traits
&
t
,
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
if
(
t
.
is_a_rowmajor
&&
t
.
is_b_rowmajor
&&
t
.
is_c_rowmajor
)
{
...
...
@@ -135,13 +137,16 @@ float gemm_type_(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const
float
gemm
(
const
gemm_traits
&
t
,
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
if
(
t
.
data_type
==
"fp16"
)
{
if
(
t
.
data_type
==
"fp16"
)
{
return
gemm_type_
<
GemmFp16
>
(
t
,
args
,
s
);
}
else
if
(
t
.
data_type
==
"bf16"
)
{
else
if
(
t
.
data_type
==
"bf16"
)
{
return
gemm_type_
<
GemmBf16
>
(
t
,
args
,
s
);
}
else
{
else
{
throw
std
::
runtime_error
(
"Wrong! Data type not supported!
\n
"
);
}
}
...
...
@@ -159,7 +164,7 @@ auto create_args(int argc, char* argv[])
.
insert
(
"stride_b"
,
"0"
,
"Tensor B stride"
)
.
insert
(
"stride_c"
,
"0"
,
"Tensor C stride"
)
.
insert
(
"v"
,
"2"
,
"0. No validation, 1. Validation on CPU, 2. Validation on GPU"
)
.
insert
(
"prec"
,
"
b
f16"
,
"data type. fp16/bf16/fp8/bf8"
)
.
insert
(
"prec"
,
"f
p
16"
,
"data type. fp16/bf16/fp8/bf8"
)
.
insert
(
"warmup"
,
"50"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"repeat"
,
"100"
,
"number of iterations to benchmark the kernel"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
)
...
...
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
3de7bd67
...
...
@@ -2,7 +2,7 @@
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
DataType
T
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
DataType
>
float
invoke_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
ck_tile
::
DeviceMem
&
c_m_n_dev_buf
,
...
...
@@ -16,10 +16,10 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int
n_warmup
,
int
n_repeat
)
{
using
Types
=
GemmBasicTypeConfig
<
DataType
T
>
;
using
ADataType
=
typename
Types
::
ADataType
;
using
BDataType
=
typename
Types
::
BDataType
;
using
CDataType
=
typename
Types
::
CDataType
;
using
Types
=
GemmBasicTypeConfig
<
DataType
>
;
using
ADataType
=
typename
Types
::
ADataType
;
using
BDataType
=
typename
Types
::
BDataType
;
using
CDataType
=
typename
Types
::
CDataType
;
ck_tile
::
GemmHostArgs
args
;
args
.
a_ptr
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
...
...
@@ -55,7 +55,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
return
ave_time
;
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
DataType
T
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
DataType
>
int
run_gemm_example_with_layouts
(
int
argc
,
char
*
argv
[],
const
ALayout
a_layout
=
ALayout
{},
...
...
@@ -66,7 +66,7 @@ int run_gemm_example_with_layouts(int argc,
if
(
!
result
)
return
-
1
;
using
Types
=
GemmBasicTypeConfig
<
DataType
T
>
;
using
Types
=
GemmBasicTypeConfig
<
DataType
>
;
using
ADataType
=
typename
Types
::
ADataType
;
using
BDataType
=
typename
Types
::
BDataType
;
using
AccDataType
=
typename
Types
::
AccDataType
;
...
...
@@ -140,18 +140,18 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
invoke_gemm
<
ALayout
,
BLayout
,
CLayout
,
DataType
T
>
(
a_m_k_dev_buf
,
b_k_n_dev_buf
,
c_m_n_dev_buf
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
,
kbatch
,
n_warmup
,
n_repeat
);
invoke_gemm
<
ALayout
,
BLayout
,
CLayout
,
DataType
>
(
a_m_k_dev_buf
,
b_k_n_dev_buf
,
c_m_n_dev_buf
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
,
kbatch
,
n_warmup
,
n_repeat
);
c_m_n_dev_buf
.
FromDevice
(
c_m_n_dev_result
.
data
());
bool
pass
=
true
;
...
...
@@ -221,7 +221,10 @@ int run_gemm_example_with_layouts(int argc,
}
template
<
typename
DataType
>
int
run_gemm_example_with_datatype
(
int
argc
,
char
*
argv
[],
const
std
::
string
&
a_layout
,
const
std
::
string
&
b_layout
)
int
run_gemm_example_with_datatype
(
int
argc
,
char
*
argv
[],
const
std
::
string
&
a_layout
,
const
std
::
string
&
b_layout
)
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
...
...
@@ -229,19 +232,23 @@ int run_gemm_example_with_datatype(int argc, char* argv[], const std::string& a_
if
(
a_layout
==
"R"
&&
b_layout
==
"R"
)
{
return
run_gemm_example_with_layouts
<
Row
,
Row
,
Row
,
DataType
>
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
return
run_gemm_example_with_layouts
<
Row
,
Row
,
Row
,
DataType
>
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
}
else
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
{
return
run_gemm_example_with_layouts
<
Row
,
Col
,
Row
,
DataType
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
return
run_gemm_example_with_layouts
<
Row
,
Col
,
Row
,
DataType
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
a_layout
==
"C"
&&
b_layout
==
"C"
)
{
return
run_gemm_example_with_layouts
<
Col
,
Col
,
Row
,
DataType
>
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
return
run_gemm_example_with_layouts
<
Col
,
Col
,
Row
,
DataType
>
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
}
else
if
(
a_layout
==
"C"
&&
b_layout
==
"R"
)
{
return
run_gemm_example_with_layouts
<
Col
,
Row
,
Row
,
DataType
>
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
return
run_gemm_example_with_layouts
<
Col
,
Row
,
Row
,
DataType
>
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
}
else
{
...
...
@@ -257,9 +264,10 @@ int run_gemm_example(int argc, char* argv[])
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
std
::
string
prec
=
arg_parser
.
get_str
(
"prec"
);
std
::
string
prec
=
arg_parser
.
get_str
(
"prec"
);
if
(
prec
==
"fp16"
)
{
if
(
prec
==
"fp16"
)
{
return
run_gemm_example_with_datatype
<
GemmFp16
>
(
argc
,
argv
,
a_layout
,
b_layout
);
}
else
if
(
prec
==
"bf16"
)
...
...
@@ -270,5 +278,4 @@ int run_gemm_example(int argc, char* argv[])
{
throw
std
::
runtime_error
(
"Unsupported data type!"
);
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment