Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
18d2bb1b
"...pyexps/git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "ced8b8d141709741b03af7d6ffe44650c656efe7"
Commit
18d2bb1b
authored
May 18, 2022
by
Chao Liu
Browse files
add gelu and fast_gelu
parent
9f71ff48
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
42 additions
and
0 deletions
+42
-0
example/01_gemm/gemm_xdl_fp16.cpp
example/01_gemm/gemm_xdl_fp16.cpp
+42
-0
No files found.
example/01_gemm/gemm_xdl_fp16.cpp
View file @
18d2bb1b
...
...
@@ -27,6 +27,44 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
// File-local shorthand for ck::tensor_operation::element_wise::PassThrough.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Element-wise GELU functor based on the tanh approximation
//
//     y = 0.5 * x * (1 + tanh(0.797885 * x + 0.035677 * x * x * x))
//
// The call operator is marked for both host and device use. It reads `x`
// and stores the result in `y`.
struct Gelu
{
    __host__ __device__ void operator()(float& y, const float& x) const
    {
        // Constants of the formula above, narrowed from double to float once.
        constexpr float kCubicCoeff  = static_cast<float>(0.035677);
        constexpr float kLinearCoeff = static_cast<float>(0.797885);
        constexpr float kOne         = 1.0f;
        constexpr float kHalf        = 0.5f;

        // tanh argument, built as (0.797885 + 0.035677 * x * x) * x.
        // One operation per statement keeps the evaluation order explicit.
        const float cubic_part = kCubicCoeff * x * x;
        const float factor     = kLinearCoeff + cubic_part;
        const float tanh_arg   = factor * x;

        const float tanh_val      = tanh(tanh_arg);
        const float one_plus_tanh = kOne + tanh_val;

        y = kHalf * x * one_plus_tanh;
    }
};
// GELU functor with separate host and device implementations of
//
//     y = 0.5 * x * (1 + tanh(0.797885 * x + 0.035677 * x * x * x))
//
// Both overloads read `x` and store the result in `y`.
struct FastGelu
{
    // Host overload: evaluates the tanh formula directly.
    __host__ void operator()(float& y, const float& x) const
    {
        // tanh argument, built as (0.797885 + 0.035677 * x * x) * x.
        const float a = static_cast<float>(0.035677) * x * x;
        const float b = static_cast<float>(0.797885) + a;
        const float c = b * x;

        const float d = tanh(c);
        const float e = 1.0f + d;

        y = 0.5f * x * e;
    }

    // Device overload: uses the identity
    //     0.5 * (1 + tanh(z)) == 1 / (1 + exp(-2 * z))
    // so a single exp and a single divide replace the tanh call.
    __device__ void operator()(float& y, const float& x) const
    {
        // u = 2 * z, where z is the tanh argument of the formula above.
        const float u =
            2.0f * x * (static_cast<float>(0.035677) * x * x + static_cast<float>(0.797885));
        const float emu = exp(-u);

        // Divide directly instead of computing 0.5 + 0.5 * (2 / (1 + emu) - 1).
        // The two are algebraically equal, but the longer form subtracts a
        // tiny value from 1 and rounds the result to exactly zero once
        // 2 / (1 + emu) drops below float resolution, losing the negative tail.
        y = x / (1.0f + emu);
    }
};
// Element types for the A, B and C operands of this example; all three use ck::half_t.
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
...
...
@@ -38,7 +76,11 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
// Element-wise functors applied to the A and B operands: both are PassThrough.
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;

// Element-wise functor applied to C. As written, the #else branch is active,
// so FastGelu is selected; change the condition to 1 to use PassThrough instead.
#if 0
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
#else
using CElementOp = FastGelu;
#endif

// GEMM specialization constant used by this example (the Default enumerator).
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment