Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e38e61b6
Commit
e38e61b6
authored
Sep 23, 2022
by
Chao Liu
Browse files
update fastgelu
parent
b4f96931
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
35 deletions
+33
-35
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
...r_operation/gpu/element/binary_element_wise_operation.hpp
+11
-27
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+22
-8
No files found.
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
View file @
e38e61b6
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#pragma once
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -225,43 +226,26 @@ struct AddHardswish
...
@@ -225,43 +226,26 @@ struct AddHardswish
};
};
};
};
// C = A * B
// E = FastGelu(C + D)
// E = FastGelu(C + D)
struct
AddFastGelu
struct
AddFastGelu
{
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__
__device__
static
constexpr
float
GetFastGeLU
(
float
x
)
{
const
float
u
=
2.
f
*
x
*
(
0.035677
f
*
x
*
x
+
0.797885
f
);
const
float
emu
=
exp
(
-
u
);
const
float
cdf
=
0.5
f
+
0.5
f
*
(
2.
f
/
(
1.
f
+
emu
)
-
1.
f
);
return
x
*
cdf
;
}
template
<
typename
T
>
static
inline
constexpr
bool
is_valid_param_type_v
=
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
half_t
>
||
std
::
is_same_v
<
T
,
bhalf_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
;
template
<
typename
E
,
typename
C
,
typename
D
>
template
<
typename
E
,
typename
C
,
typename
D
>
__host__
__device__
constexpr
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D
&
d
)
const
__host__
__device__
constexpr
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D
&
d
)
const
;
{
static_assert
(
is_valid_param_type_v
<
E
>
&&
is_valid_param_type_v
<
C
>
&&
is_valid_param_type_v
<
D
>
);
const
float
y
=
GetFastGeLU
(
type_convert
<
float
>
(
c
)
+
type_convert
<
float
>
(
d
));
template
<
>
__host__
__device__
constexpr
void
operator
<
float
,
float
,
float
>
()(
float
&
e
,
const
float
&
c
,
const
float
&
d
)
const
{
const
float
x
=
c
+
d
;
e
=
type_convert
<
E
>
(
y
);
FastGelu
{}.
template
operator
()
<
float
,
float
>(
e
,
x
);
}
}
template
<
typename
D
>
template
<
>
__host__
__device__
constexpr
void
operator
()(
floa
t
&
e
,
const
floa
t
&
c
,
const
D
&
d
)
const
__host__
__device__
constexpr
void
operator
<
half_t
,
half_t
,
half_t
>
()(
half_
t
&
e
,
const
half_
t
&
c
,
const
half_t
&
d
)
const
{
{
static_assert
(
is_valid_param_type_v
<
D
>
)
;
const
half_t
x
=
c
+
d
;
e
=
GetFastGeLU
(
c
+
type_convert
<
float
>
(
d
)
);
ck
::
tensor_operation
::
element_wise
::
FastGelu
{}.
template
operator
()
<
half_t
,
half_t
>(
e
,
x
);
}
}
};
};
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
e38e61b6
...
@@ -197,6 +197,8 @@ struct Relu
...
@@ -197,6 +197,8 @@ struct Relu
// https://paperswithcode.com/method/gelu
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// host code use higher accuracy "exp" and "div"
// device code use lower accuracy "__expf" and "rcp" function
struct
FastGelu
struct
FastGelu
{
{
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
...
@@ -205,7 +207,6 @@ struct FastGelu
...
@@ -205,7 +207,6 @@ struct FastGelu
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
// host code, use higher precision "exp" and "div"
template
<
>
template
<
>
__host__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
__host__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
{
...
@@ -216,23 +217,36 @@ struct FastGelu
...
@@ -216,23 +217,36 @@ struct FastGelu
y
=
x
*
cdf
;
y
=
x
*
cdf
;
}
}
template
<
>
__host__
void
operator
()
<
half_t
,
half_t
>
(
half_t
&
y
,
const
half_t
&
x
)
const
{
float
y_f
;
this
->
operator
()
<
float
,
float
>
(
y_f
,
type_convert
<
float
>
(
x
));
y
=
type_convert
<
half_t
>
(
y_f
);
}
// device code, use lower precision "__expf" and "rcp"
// device code, use lower precision "__expf" and "rcp"
template
<
>
template
<
>
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
{
#if 0
const
float
u
=
float
(
2
)
*
x
*
(
float
(
0.035677
)
*
x
*
x
+
float
(
0.797885
));
const
float
u
=
float
(
2
)
*
x
*
(
float
(
0.035677
)
*
x
*
x
+
float
(
0.797885
));
const
float
emu
=
__expf
(
-
u
);
const
float
emu
=
__expf
(
-
u
);
const
float
cdf
=
float
(
0.5
)
+
float
(
0.5
)
*
(
float
(
2
)
*
__ocml_native_recip_f32
(
float
(
1
)
+
emu
)
-
float
(
1
));
const
float
cdf
=
float
(
0.5
)
+
float
(
0.5
)
*
(
float
(
2
)
*
__ocml_native_recip_f32
(
float
(
1
)
+
emu
)
-
float
(
1
));
y
=
x
*
cdf
;
y
=
x
*
cdf
;
#else
}
const
float
u
=
float
(
2
)
*
x
*
(
float
(
0.035677
)
*
x
*
x
+
float
(
0.797885
));
const
float
emu
=
exp
(
-
u
);
const
float
cdf
=
float
(
0.5
)
+
float
(
0.5
)
*
(
float
(
2
)
/
(
float
(
1
)
+
emu
)
-
float
(
1
));
y
=
x
*
cdf
;
// device code, use lower precision "__expf" and "rcp"
#endif
template
<
>
__device__
void
operator
()
<
half_t
,
half_t
>
(
half_t
&
y
,
const
half_t
&
x
)
const
{
float
y_f
;
this
->
operator
()
<
float
,
float
>
(
y_f
,
type_convert
<
float
>
(
x
));
y
=
type_convert
<
half_t
>
(
y_f
);
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment