Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
536c5458
Commit
536c5458
authored
Aug 25, 2024
by
ThomasNing
Browse files
fix with better naming convention
parent
04006d5f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
34 deletions
+35
-34
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+32
-31
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+3
-3
No files found.
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
536c5458
...
...
@@ -68,10 +68,10 @@ float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s)
// ===============================================
using
Shape
=
ck_tile
::
TileGemmShapeNewGemm
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
Shape
=
ck_tile
::
TileGemmShapeNewGemm
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
Shape
>
;
using
PipelineProblem
=
ck_tile
::
BlockGemmPipelineProblem
<
XDataType
,
YDataType
,
AccDataType
,
Shape
,
kPadA
,
kPadB
,
kPadC
>
;
...
...
@@ -83,9 +83,9 @@ float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s)
// For now, only the BlockGemmASmemBSmemCRegV1DefaultPolicy is used.
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
,
Layouts
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_
x
,
args
.
p_
y
,
args
.
p_
z
,
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_
a
,
args
.
p_
b
,
args
.
p_
c
,
args
.
batch_size
,
args
.
epsilon
,
args
.
M
,
...
...
@@ -105,9 +105,9 @@ float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s)
}
template
<
typename
DataType
,
typename
Layouts
>
float
OperatorExecution
(
ck_tile
::
DeviceMem
&
x
_buf
,
ck_tile
::
DeviceMem
&
y
_buf
,
ck_tile
::
DeviceMem
&
z
_buf
,
float
OperatorExecution
(
ck_tile
::
DeviceMem
&
a
_buf
,
ck_tile
::
DeviceMem
&
b
_buf
,
ck_tile
::
DeviceMem
&
c
_buf
,
const
ck_tile
::
ArgParser
&
arg_parser
)
{
...
...
@@ -131,9 +131,9 @@ float OperatorExecution(ck_tile::DeviceMem& x_buf,
ck_tile
::
index_t
stride_c
=
arg_parser
.
get_int
(
"stride_c"
);
gemm_basic_args
args
;
args
.
p_
x
=
x
_buf
.
GetDeviceBuffer
();
args
.
p_
y
=
y
_buf
.
GetDeviceBuffer
();
args
.
p_
z
=
z
_buf
.
GetDeviceBuffer
();
args
.
p_
a
=
a
_buf
.
GetDeviceBuffer
();
args
.
p_
b
=
b
_buf
.
GetDeviceBuffer
();
args
.
p_
c
=
c
_buf
.
GetDeviceBuffer
();
args
.
epsilon
=
epsilon
;
args
.
batch_size
=
batch_size
;
args
.
M
=
M
;
...
...
@@ -222,37 +222,38 @@ int main(int argc, char* argv[])
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
K
=
arg_parser
.
get_int
(
"k"
);
// The matrix multiplication takes Matrix A (M, K) and Matrix B (N, K) and produces Matrix C (M, N).
constexpr
ck_tile
::
MatrixALayout
matrix_a_layout
=
ck_tile
::
MatrixALayout
::
MK
;
constexpr
ck_tile
::
MatrixBLayout
matrix_b_layout
=
ck_tile
::
MatrixBLayout
::
NK
;
constexpr
ck_tile
::
MatrixCLayout
matrix_c_layout
=
ck_tile
::
MatrixCLayout
::
MN
;
using
Layouts
=
LayoutConfig
<
matrix_a_layout
,
matrix_b_layout
,
matrix_c_layout
>
;
// host verify
std
::
vector
<
int
>
x
_dimensions
=
(
matrix_a_layout
==
ck_tile
::
MatrixALayout
::
MK
)
std
::
vector
<
int
>
a
_dimensions
=
(
matrix_a_layout
==
ck_tile
::
MatrixALayout
::
MK
)
?
std
::
vector
<
int
>
{
M
,
K
}
:
std
::
vector
<
int
>
{
K
,
M
};
std
::
vector
<
int
>
y
_dimensions
=
(
matrix_b_layout
==
ck_tile
::
MatrixBLayout
::
NK
)
std
::
vector
<
int
>
b
_dimensions
=
(
matrix_b_layout
==
ck_tile
::
MatrixBLayout
::
NK
)
?
std
::
vector
<
int
>
{
N
,
K
}
:
std
::
vector
<
int
>
{
K
,
N
};
std
::
vector
<
int
>
z
_dimensions
=
(
matrix_c_layout
==
ck_tile
::
MatrixCLayout
::
MN
)
std
::
vector
<
int
>
c
_dimensions
=
(
matrix_c_layout
==
ck_tile
::
MatrixCLayout
::
MN
)
?
std
::
vector
<
int
>
{
M
,
N
}
:
std
::
vector
<
int
>
{
N
,
M
};
ck_tile
::
HostTensor
<
XDataType
>
x
_host
(
x
_dimensions
);
ck_tile
::
HostTensor
<
YDataType
>
y
_host
(
y
_dimensions
);
ck_tile
::
HostTensor
<
XDataType
>
a
_host
(
a
_dimensions
);
ck_tile
::
HostTensor
<
YDataType
>
b
_host
(
b
_dimensions
);
ck_tile
::
HostTensor
<
ODataType
>
z
_host_ref
(
z
_dimensions
);
ck_tile
::
HostTensor
<
ODataType
>
z
_host_dev
(
z
_dimensions
);
ck_tile
::
HostTensor
<
ODataType
>
c
_host_ref
(
c
_dimensions
);
ck_tile
::
HostTensor
<
ODataType
>
c
_host_dev
(
c
_dimensions
);
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
5.
f
,
5.
f
}(
x
_host
);
ck_tile
::
FillUniformDistribution
<
YDataType
>
{
-
5.
f
,
5.
f
}(
y
_host
);
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
5.
f
,
5.
f
}(
a
_host
);
ck_tile
::
FillUniformDistribution
<
YDataType
>
{
-
5.
f
,
5.
f
}(
b
_host
);
ck_tile
::
DeviceMem
x
_buf
(
x
_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y
_buf
(
y
_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
z
_buf
(
z
_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
a
_buf
(
a
_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
b
_buf
(
b
_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
c
_buf
(
c
_host_dev
.
get_element_space_size_in_bytes
());
x
_buf
.
ToDevice
(
x
_host
.
data
());
y
_buf
.
ToDevice
(
y
_host
.
data
());
a
_buf
.
ToDevice
(
a
_host
.
data
());
b
_buf
.
ToDevice
(
b
_host
.
data
());
if
(
grouped_enable
||
following_op_descrp
!=
"no"
)
{
...
...
@@ -260,7 +261,7 @@ int main(int argc, char* argv[])
return
-
1
;
}
OperatorExecution
<
ck_tile
::
half_t
,
Layouts
>
(
x
_buf
,
y
_buf
,
z
_buf
,
arg_parser
);
OperatorExecution
<
ck_tile
::
half_t
,
Layouts
>
(
a
_buf
,
b
_buf
,
c
_buf
,
arg_parser
);
bool
pass
=
true
;
...
...
@@ -268,11 +269,11 @@ int main(int argc, char* argv[])
{
// TODO: Add verification of the element op (bias) in the future.
ck_tile
::
reference_gemm
<
XDataType
,
YDataType
,
AccDataType
,
ODataType
>
(
x
_host
,
y
_host
,
z
_host_ref
,
matrix_a_layout
);
a
_host
,
b
_host
,
c
_host_ref
,
matrix_a_layout
);
z
_buf
.
FromDevice
(
z
_host_dev
.
data
());
c
_buf
.
FromDevice
(
c
_host_dev
.
data
());
pass
=
ck_tile
::
check_err
(
z
_host_dev
,
z
_host_ref
);
pass
=
ck_tile
::
check_err
(
c
_host_dev
,
c
_host_ref
);
std
::
cout
<<
"The veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
flush
;
}
...
...
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
536c5458
...
...
@@ -62,9 +62,9 @@ using ODataType = Types::ODataType;
struct
gemm_basic_args
{
const
void
*
p_
x
;
const
void
*
p_
y
;
void
*
p_
z
;
const
void
*
p_
a
;
const
void
*
p_
b
;
void
*
p_
c
;
float
epsilon
;
ck_tile
::
index_t
batch_size
;
ck_tile
::
index_t
M
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment