Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
1f306024
Unverified
Commit
1f306024
authored
Feb 07, 2024
by
Lakhinder Walia
Committed by
GitHub
Feb 07, 2024
Browse files
fast_gelu: minor code reorg to enhance ref & gpu performance (#1162)
parent
1b0fbaeb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
11 deletions
+13
-11
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+13
-11
No files found.
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
1f306024
...
@@ -458,27 +458,29 @@ struct FastGelu
...
@@ -458,27 +458,29 @@ struct FastGelu
template
<
>
template
<
>
__host__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
__host__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
{
const
float
u
=
2.
f
*
x
*
(
0.035677
f
*
x
*
x
+
0.797885
f
);
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
const
float
emu
=
exp
(
-
u
);
const
float
c1
=
-
2.0
*
0.035677
f
;
const
float
cdf
=
0.5
f
+
0.5
f
*
(
2.
f
/
(
1.
f
+
emu
)
-
1.
f
);
const
float
c2
=
-
2.0
*
0.797885
f
;
const
float
u
=
x
*
(
c1
*
x
*
x
+
c2
);
y
=
x
*
cdf
;
const
float
emu
=
exp
(
u
);
y
=
x
/
(
1.
f
+
emu
);
}
}
// device code, use lower precision "__expf" and "rcp"
// device code, use lower precision "__expf" and "rcp"
template
<
>
template
<
>
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
{
const
float
u
=
2.
f
*
x
*
(
0.035677
f
*
x
*
x
+
0.797885
f
);
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
const
float
emu
=
__expf
(
-
u
);
const
float
c1
=
-
2.0
*
0.035677
f
;
const
float
c2
=
-
2.0
*
0.797885
f
;
const
float
u
=
x
*
(
c1
*
x
*
x
+
c2
);
const
float
emu
=
__expf
(
u
);
#if !CK_WORKAROUND_SWDEV_383542
#if !CK_WORKAROUND_SWDEV_383542
const
float
cdf
=
0.5
f
+
0.5
f
*
(
2.
f
*
__frcp_rn
(
1.
f
+
emu
)
-
1.
f
)
;
y
=
x
*
__frcp_rn
(
1.
f
+
emu
);
#else
#else
const
float
cdf
=
0.5
f
+
0.5
f
*
(
2.
f
*
__ocml_native_recip_f32
(
1.
f
+
emu
)
-
1.
f
)
;
y
=
x
*
__ocml_native_recip_f32
(
1.
f
+
emu
);
#endif
#endif
y
=
x
*
cdf
;
}
}
template
<
>
template
<
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment