Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7fd0378a
Commit
7fd0378a
authored
Jun 07, 2023
by
Rostyslav Geyyer
Browse files
Add gemm_fastgelu client example
parent
f852625e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
160 additions
and
1 deletion
+160
-1
client_example/02_gemm_add_add_fastgelu/CMakeLists.txt
client_example/02_gemm_add_add_fastgelu/CMakeLists.txt
+4
-1
client_example/02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp
...xample/02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp
+156
-0
No files found.
client_example/02_gemm_add_add_fastgelu/CMakeLists.txt
View file @
7fd0378a
...
...
@@ -20,5 +20,8 @@ target_link_libraries(client_gemm_add_add_fastgelu_generic PRIVATE composable_ke
add_executable
(
client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp
)
target_link_libraries
(
client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_operations
)
add_executable
(
client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp
)
target_link_libraries
(
client_gemm_fastgelu_generic PRIVATE composable_kernel::device_operations
)
add_dependencies
(
client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic
client_gemm_add_fastgelu_generic
)
client_gemm_add_fastgelu_generic
client_gemm_fastgelu_generic
)
client_example/02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp
0 → 100644
View file @
7fd0378a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <vector>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp"
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
FastGelu
=
ck
::
tensor_operation
::
element_wise
::
FastGelu
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
FastGelu
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
EDataType
=
F16
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
ELayout
=
Row
;
struct
SimpleDeviceMem
{
SimpleDeviceMem
()
=
delete
;
SimpleDeviceMem
(
std
::
size_t
mem_size
)
:
p_mem_
{}
{
(
void
)
hipMalloc
(
static_cast
<
void
**>
(
&
p_mem_
),
mem_size
);
}
void
*
GetDeviceBuffer
()
{
return
p_mem_
;
}
~
SimpleDeviceMem
()
{
(
void
)
hipFree
(
p_mem_
);
}
void
*
p_mem_
;
};
int
main
(
int
argc
,
char
*
argv
[])
{
// GEMM shape
ck
::
index_t
M
=
3840
;
ck
::
index_t
N
=
4096
;
ck
::
index_t
K
=
4096
;
ck
::
index_t
StrideA
=
4096
;
ck
::
index_t
StrideB
=
4096
;
ck
::
index_t
StrideE
=
4096
;
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
7
)
{
M
=
std
::
stoi
(
argv
[
1
]);
N
=
std
::
stoi
(
argv
[
2
]);
K
=
std
::
stoi
(
argv
[
3
]);
StrideA
=
std
::
stoi
(
argv
[
4
]);
StrideB
=
std
::
stoi
(
argv
[
5
]);
StrideE
=
std
::
stoi
(
argv
[
6
]);
}
else
{
printf
(
"arg1 to 6: M, N, K, StrideA, StrideB, StrideE
\n
"
);
exit
(
0
);
}
auto
f_matrix_space_size
=
[](
std
::
size_t
nRow
,
std
::
size_t
nCol
,
std
::
size_t
stride
,
auto
layout
)
{
using
Layout
=
decltype
(
layout
);
if
constexpr
(
std
::
is_same
<
Layout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
(
nRow
-
1
)
*
stride
+
nCol
;
}
else
{
return
(
nCol
-
1
)
*
stride
+
nRow
;
}
};
SimpleDeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
f_matrix_space_size
(
M
,
K
,
StrideA
,
ALayout
{}));
SimpleDeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
f_matrix_space_size
(
K
,
N
,
StrideB
,
BLayout
{}));
SimpleDeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
f_matrix_space_size
(
M
,
N
,
StrideE
,
ELayout
{}));
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD
<
ALayout
,
BLayout
,
ck
::
Tuple
<>
,
ELayout
,
ADataType
,
BDataType
,
ck
::
Tuple
<>
,
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
FastGelu
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
const
auto
a_element_op
=
AElementOp
{};
const
auto
b_element_op
=
BElementOp
{};
const
auto
cde_element_op
=
CDEElementOp
{};
// get generic instance
auto
&
op_ptr
=
op_ptrs
[
0
];
std
::
cout
<<
"Run the generic instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
// run the generic instance
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
a_device_buf
.
GetDeviceBuffer
(),
b_device_buf
.
GetDeviceBuffer
(),
{},
e_device_buf
.
GetDeviceBuffer
(),
M
,
N
,
K
,
StrideA
,
StrideB
,
{},
StrideE
,
a_element_op
,
b_element_op
,
cde_element_op
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
}
std
::
cout
<<
"Done"
<<
std
::
endl
;
return
0
;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment