Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e1e8e1ad
"driver/vscode:/vscode.git/clone" did not exist on "9d99a5807298c3f263d39a08328c3c68c930a900"
Commit
e1e8e1ad
authored
Jan 12, 2025
by
Aleksander Dudek
Browse files
[CK_TILE] Use the GEMM prec input arg
parent
3cad16c4
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
96 additions
and
26 deletions
+96
-26
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+29
-7
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+23
-8
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+44
-11
No files found.
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
e1e8e1ad
...
...
@@ -11,9 +11,17 @@
#include "gemm_basic.hpp"
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
DataTypeConfig
>
float
gemm_
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
using
Types
=
GemmBasicTypeConfig
<
DataTypeConfig
>
;
// Specific type aliases for easy access
using
ADataType
=
typename
Types
::
ADataType
;
using
BDataType
=
typename
Types
::
BDataType
;
using
AccDataType
=
typename
Types
::
AccDataType
;
using
CDataType
=
typename
Types
::
CDataType
;
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadN
=
false
;
...
...
@@ -100,23 +108,24 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
return
ave_time
;
}
float
gemm
(
const
gemm_traits
&
t
,
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
template
<
typename
DataType
>
float
gemm_type_
(
const
gemm_traits
&
t
,
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
if
(
t
.
is_a_rowmajor
&&
t
.
is_b_rowmajor
&&
t
.
is_c_rowmajor
)
{
return
gemm_
<
Row
,
Row
,
Row
>
(
args
,
s
);
return
gemm_
<
Row
,
Row
,
Row
,
DataType
>
(
args
,
s
);
}
else
if
(
t
.
is_a_rowmajor
&&
!
t
.
is_b_rowmajor
&&
t
.
is_c_rowmajor
)
{
return
gemm_
<
Row
,
Col
,
Row
>
(
args
,
s
);
return
gemm_
<
Row
,
Col
,
Row
,
DataType
>
(
args
,
s
);
}
else
if
(
!
t
.
is_a_rowmajor
&&
t
.
is_b_rowmajor
&&
t
.
is_c_rowmajor
)
{
return
gemm_
<
Col
,
Row
,
Row
>
(
args
,
s
);
return
gemm_
<
Col
,
Row
,
Row
,
DataType
>
(
args
,
s
);
}
else
if
(
!
t
.
is_a_rowmajor
&&
!
t
.
is_b_rowmajor
&&
t
.
is_c_rowmajor
)
{
return
gemm_
<
Col
,
Col
,
Row
>
(
args
,
s
);
return
gemm_
<
Col
,
Col
,
Row
,
DataType
>
(
args
,
s
);
}
else
{
...
...
@@ -124,6 +133,19 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_til
}
}
float
gemm
(
const
gemm_traits
&
t
,
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
if
(
t
.
data_type
==
"fp16"
)
{
return
gemm_type_
<
GemmFp16
>
(
t
,
args
,
s
);
}
else
if
(
t
.
data_type
==
"bf16"
)
{
return
gemm_type_
<
GemmBf16
>
(
t
,
args
,
s
);
}
else
{
throw
std
::
runtime_error
(
"Wrong! Data type not supported!
\n
"
);
}
}
auto
create_args
(
int
argc
,
char
*
argv
[])
{
ck_tile
::
ArgParser
arg_parser
;
...
...
@@ -137,7 +159,7 @@ auto create_args(int argc, char* argv[])
.
insert
(
"stride_b"
,
"0"
,
"Tensor B stride"
)
.
insert
(
"stride_c"
,
"0"
,
"Tensor C stride"
)
.
insert
(
"v"
,
"2"
,
"0. No validation, 1. Validation on CPU, 2. Validation on GPU"
)
.
insert
(
"prec"
,
"f
p
16"
,
"data type. fp16/bf16/fp8/bf8"
)
.
insert
(
"prec"
,
"
b
f16"
,
"data type. fp16/bf16/fp8/bf8"
)
.
insert
(
"warmup"
,
"50"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"repeat"
,
"100"
,
"number of iterations to benchmark the kernel"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
)
...
...
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
e1e8e1ad
...
...
@@ -10,11 +10,19 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/epilogue.hpp"
struct
GemmFp16
{
};
struct
GemmBf16
{
};
template
<
typename
DataType
>
struct
GemmBasicTypeConfig
;
template
<
>
struct
GemmBasicTypeConfig
<
ck_tile
::
half_t
>
struct
GemmBasicTypeConfig
<
GemmFp16
>
{
using
ADataType
=
ck_tile
::
half_t
;
using
BDataType
=
ck_tile
::
half_t
;
...
...
@@ -23,6 +31,15 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
// ToDo: Add more bias config to support different categories of GEMM.
};
template
<
>
struct
GemmBasicTypeConfig
<
GemmBf16
>
{
using
ADataType
=
ck_tile
::
bf16_t
;
using
BDataType
=
ck_tile
::
bf16_t
;
using
AccDataType
=
float
;
using
CDataType
=
ck_tile
::
bf16_t
;
};
template
<
typename
T
>
struct
DataTypeTraits
;
...
...
@@ -44,13 +61,11 @@ struct DataTypeTraits<ck_tile::half_t>
static
constexpr
const
char
*
name
=
"fp16"
;
};
using
Types
=
GemmBasicTypeConfig
<
ck_tile
::
half_t
>
;
// Specific type aliases for easy access
using
ADataType
=
Types
::
ADataType
;
using
BDataType
=
Types
::
BDataType
;
using
AccDataType
=
Types
::
AccDataType
;
using
CDataType
=
Types
::
CDataType
;
template
<
>
struct
DataTypeTraits
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
e1e8e1ad
...
...
@@ -2,7 +2,7 @@
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
DataTypeT
>
float
invoke_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
ck_tile
::
DeviceMem
&
c_m_n_dev_buf
,
...
...
@@ -16,6 +16,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int
n_warmup
,
int
n_repeat
)
{
using
Types
=
GemmBasicTypeConfig
<
DataTypeT
>
;
using
ADataType
=
typename
Types
::
ADataType
;
using
BDataType
=
typename
Types
::
BDataType
;
using
CDataType
=
typename
Types
::
CDataType
;
ck_tile
::
GemmHostArgs
args
;
args
.
a_ptr
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
b_ptr
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
...
...
@@ -50,7 +55,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
return
ave_time
;
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
DataTypeT
>
int
run_gemm_example_with_layouts
(
int
argc
,
char
*
argv
[],
const
ALayout
a_layout
=
ALayout
{},
...
...
@@ -61,6 +66,12 @@ int run_gemm_example_with_layouts(int argc,
if
(
!
result
)
return
-
1
;
using
Types
=
GemmBasicTypeConfig
<
DataTypeT
>
;
using
ADataType
=
typename
Types
::
ADataType
;
using
BDataType
=
typename
Types
::
BDataType
;
using
AccDataType
=
typename
Types
::
AccDataType
;
using
CDataType
=
typename
Types
::
CDataType
;
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
K
=
arg_parser
.
get_int
(
"k"
);
...
...
@@ -129,7 +140,7 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
invoke_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
a_m_k_dev_buf
,
invoke_gemm
<
ALayout
,
BLayout
,
CLayout
,
DataTypeT
>
(
a_m_k_dev_buf
,
b_k_n_dev_buf
,
c_m_n_dev_buf
,
M
,
...
...
@@ -209,33 +220,55 @@ int run_gemm_example_with_layouts(int argc,
return
pass
;
}
int
run_gemm_example
(
int
argc
,
char
*
argv
[])
template
<
typename
DataType
>
int
run_gemm_example_with_datatype
(
int
argc
,
char
*
argv
[],
const
std
::
string
&
a_layout
,
const
std
::
string
&
b_layout
)
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
return
-
1
;
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
if
(
a_layout
==
"R"
&&
b_layout
==
"R"
)
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
return
run_gemm_example_with_layouts
<
Row
,
Row
,
Row
,
DataType
>
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
}
else
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
return
run_gemm_example_with_layouts
<
Row
,
Col
,
Row
,
DataType
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
a_layout
==
"C"
&&
b_layout
==
"C"
)
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
return
run_gemm_example_with_layouts
<
Col
,
Col
,
Row
,
DataType
>
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
}
else
if
(
a_layout
==
"C"
&&
b_layout
==
"R"
)
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
return
run_gemm_example_with_layouts
<
Col
,
Row
,
Row
,
DataType
>
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
}
else
{
throw
std
::
runtime_error
(
"Unsupported data layout configuration for A,B and C tensors!"
);
}
}
int
run_gemm_example
(
int
argc
,
char
*
argv
[])
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
return
-
1
;
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
std
::
string
prec
=
arg_parser
.
get_str
(
"prec"
);
if
(
prec
==
"fp16"
)
{
return
run_gemm_example_with_datatype
<
GemmFp16
>
(
argc
,
argv
,
a_layout
,
b_layout
);
}
else
if
(
prec
==
"bf16"
)
{
return
run_gemm_example_with_datatype
<
GemmBf16
>
(
argc
,
argv
,
a_layout
,
b_layout
);
}
else
{
throw
std
::
runtime_error
(
"Unsupported data type!"
);
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment