gaoqiong / composable_kernel

Commit 8a913c22
authored May 18, 2022 by Chao Liu

added GeLU and fast GeLU

parent 18d2bb1b
Showing 3 changed files with 27 additions and 42 deletions.
example/01_gemm/gemm_xdl_fp16.cpp                                    +0  -41
example/CMakeLists.txt                                               +1   -1
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp  +26   -0
example/01_gemm/gemm_xdl_fp16.cpp
@@ -27,43 +27,6 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 
-struct Gelu
-{
-    __host__ __device__ void operator()(float& y, const float& x) const
-    {
-        // Y=0.5*X*(1+tanh(0.797885*X+0.035677*X*X*X))
-        const float a = float(0.035677) * x * x;
-        const float b = float(0.797885) + a;
-        const float c = b * x;
-        const float d = tanh(c);
-        const float e = float(1.0) + d;
-        y = float(0.5) * x * e;
-    }
-};
-
-struct FastGelu
-{
-    __host__ void operator()(float& y, const float& x) const
-    {
-        // Y=0.5*X*(1+tanh(0.797885*X+0.035677*X*X*X))
-        const float a = float(0.035677) * x * x;
-        const float b = float(0.797885) + a;
-        const float c = b * x;
-        const float d = tanh(c);
-        const float e = float(1.0) + d;
-        y = float(0.5) * x * e;
-    }
-
-    __device__ void operator()(float& y, const float& x) const
-    {
-        // const T cdf = a + a * _Tanh(in * (c * in * in + b));
-        const float u   = float(2) * x * (float(0.035677) * x * x + float(0.797885));
-        const float emu = exp(-u);
-        const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
-
-        y = x * cdf;
-    }
-};
-
 using ADataType = ck::half_t;
 using BDataType = ck::half_t;
@@ -76,11 +39,7 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
 using AElementOp = ck::tensor_operation::element_wise::PassThrough;
 using BElementOp = ck::tensor_operation::element_wise::PassThrough;
 
-#if 0
 using CElementOp = ck::tensor_operation::element_wise::PassThrough;
-#else
-using CElementOp = FastGelu;
-#endif
 
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
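Note on the functors removed above: Gelu and FastGelu compute the same tanh approximation of GELU; the __device__ overload of FastGelu only rewrites tanh(c) through the identity tanh(u/2) = 2/(1 + exp(-u)) - 1, with u = 2*x*(0.035677*x*x + 0.797885) = 2c, so any numerical difference between the two is float rounding noise. A minimal host-only sketch (not part of this commit; plain <cmath> calls stand in for the __host__/__device__ overloads) that checks the two forms agree:

// Host-only check that the tanh form (Gelu) and the exp form (the __device__
// overload of FastGelu) of the GELU approximation produce the same values.
#include <cmath>
#include <cstdio>

static float gelu_tanh(float x)
{
    // Y = 0.5*X*(1 + tanh(0.797885*X + 0.035677*X*X*X))
    const float c = (0.797885f + 0.035677f * x * x) * x;
    return 0.5f * x * (1.0f + std::tanh(c));
}

static float gelu_exp(float x)
{
    // Same value written with exp(), as in the removed __device__ overload.
    const float u   = 2.0f * x * (0.035677f * x * x + 0.797885f);
    const float cdf = 0.5f + 0.5f * (2.0f / (1.0f + std::exp(-u)) - 1.0f);
    return x * cdf;
}

int main()
{
    for(float x = -4.0f; x <= 4.0f; x += 0.5f)
    {
        const float a = gelu_tanh(x);
        const float b = gelu_exp(x);
        std::printf("x=%5.2f  tanh-form=%9.6f  exp-form=%9.6f  |diff|=%g\n",
                    x, a, b, std::fabs(a - b));
    }
    return 0;
}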
example/CMakeLists.txt
@@ -33,7 +33,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
     add_executable(${EXAMPLE_NAME} ${FILE_NAME})
     target_link_libraries(${EXAMPLE_NAME} PRIVATE host_tensor)
     add_dependencies(examples ${EXAMPLE_NAME})
-endfunction(add_example_executable EXAMPLE_NAME)
+endfunction(add_example_executable_no_testing EXAMPLE_NAME)
 
 add_subdirectory(01_gemm)
 add_subdirectory(02_gemm_alpha_beta)
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
@@ -20,6 +20,32 @@ struct PassThrough
     __host__ __device__ void operator()(double& y, const double& x) const { y = x; }
 };
 
+struct Gelu
+{
+    __host__ __device__ void operator()(float& y, const float& x) const
+    {
+        // Y=0.5*X*(1+tanh(0.797885*X+0.035677*X*X*X))
+        const float a = float(0.035677) * x * x;
+        const float b = float(0.797885) + a;
+        const float c = b * x;
+        const float d = tanh(c);
+        const float e = float(1.0) + d;
+        y = float(0.5) * x * e;
+    }
+};
+
+struct FastGelu
+{
+    __host__ __device__ void operator()(float& y, const float& x) const
+    {
+        const float u   = float(2) * x * (float(0.035677) * x * x + float(0.797885));
+        const float emu = exp(-u);
+        const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
+
+        y = x * cdf;
+    }
+};
+
 struct Add
 {
     __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
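The new element_wise::Gelu and element_wise::FastGelu follow the same convention as PassThrough above: operator()(y, x) writes the transformed value of x into y, and the GEMM example passes such a functor as CElementOp, to be applied to each C output element. A rough host-only sketch of that per-element application (not from this commit: the local FastGelu copy drops the __host__/__device__ qualifiers so it builds without HIP, and apply_elementwise is an illustrative helper, not a CK API):

// Applies an element-wise functor with the operator()(float& y, const float& x)
// convention to every element of a buffer, as a host-side reference.
#include <cmath>
#include <cstdio>
#include <vector>

struct FastGelu // host-only copy of the functor added in this header
{
    void operator()(float& y, const float& x) const
    {
        const float u   = 2.0f * x * (0.035677f * x * x + 0.797885f);
        const float emu = std::exp(-u);
        const float cdf = 0.5f + 0.5f * (2.0f / (1.0f + emu) - 1.0f);
        y = x * cdf;
    }
};

template <typename ElementOp>
void apply_elementwise(std::vector<float>& data, const ElementOp& op)
{
    for(float& v : data)
    {
        op(v, v); // aliasing y and x is fine: y is written only after x is read
    }
}

int main()
{
    std::vector<float> c = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
    apply_elementwise(c, FastGelu{});
    for(const float v : c)
        std::printf("%f\n", v);
    return 0;
}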