gaoqiong / composable_kernel · Commit b4f96931

improve fastgelu

Authored Sep 23, 2022 by Chao Liu
Parent e9d4e893
Showing 5 changed files with 133 additions and 7 deletions.
example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp                                          +42 -1
example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp                            +61 -1
example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc                         +2  -2
example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp  +1  -1
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp                       +27 -2
example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp

@@ -33,13 +33,54 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 // E = Relu(C + D);
 struct AddRelu
 {
+    template <typename E, typename C, typename D>
+    __host__ __device__ void operator()(E& e, const C& c, const D& d) const;
+
+#if 0
+    template <>
     __host__ __device__ void
-    operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const
+    operator()<ck::half_t, ck::half_t, ck::half_t>(ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const
     {
         const ck::half_t x = c + d;
         e = x > 0 ? x : 0;
     }
+#else
+    // AddFastGeLU
+    template <>
+    __host__ __device__ void
+    operator()<ck::half_t, ck::half_t, ck::half_t>(ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const
+    {
+        const ck::half_t x = c + d;
+        e = x > 0 ? x : 0;
+    }
+#endif
 };
+
+struct FastGelu
+{
+    template <typename Y, typename X>
+    __host__ __device__ void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
+    {
+        const float u   = float(2) * x * (float(0.035677) * x * x + float(0.797885));
+        const float emu = exp(-u);
+        const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
+        y = x * cdf;
+    }
+};
 
 using ADataType = F16;
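The FastGelu functor introduced above evaluates GeLU through a sigmoid identity rather than the tanh form quoted in the comments. Since tanh(z) = 2/(1 + e^(-2z)) - 1, the approximation y = 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))) can be rewritten with u = 2*x*(0.035677*x*x + 0.797885), which is exactly the u/emu/cdf sequence in the diff; one exp and one reciprocal then suffice. A minimal host-only sketch checking that the two forms agree (plain standard C++, no CK or HIP dependency; the function names are illustrative, not from the commit):

    // fastgelu_check.cpp - build with: g++ -std=c++17 fastgelu_check.cpp
    #include <cmath>
    #include <cstdio>

    // Tanh form from the comment in the diff:
    // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
    static float fastgelu_tanh(float x)
    {
        return 0.5f * x * (1.f + std::tanh(0.797885f * (x + 0.044715f * x * x * x)));
    }

    // Sigmoid form used by the commit: tanh(z) = 2/(1+exp(-2z)) - 1, with u = 2z
    static float fastgelu_sigmoid(float x)
    {
        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
        const float emu = std::exp(-u);
        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
        return x * cdf;
    }

    int main()
    {
        for(float x = -5.f; x <= 5.f; x += 0.5f)
        {
            const float a = fastgelu_tanh(x);
            const float b = fastgelu_sigmoid(x);
            std::printf("x=%+5.2f  tanh=%+9.6f  sigmoid=%+9.6f  |diff|=%.2e\n", x, a, b, std::fabs(a - b));
        }
        return 0;
    }

The two forms differ only by float rounding; the sigmoid form is what lets the later changes in this commit swap the division for a fast device reciprocal.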
example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp

@@ -3,6 +3,8 @@
 #include "common.hpp"
 
+extern "C" __device__ float __ocml_native_recip_f32(float);
+
 using ADataType   = F16;
 using BDataType   = F16;
 using AccDataType = F32;
@@ -19,9 +21,67 @@ using D1Layout = Row;
 using DsLayout = ck::Tuple<D0Layout, D1Layout>;
 using ELayout  = Row;
 
+// C = A * B
+// E = FastGelu(C + D0 + D1)
+struct EleFastGeLU
+{
+    // Fast GeLU
+    // https://paperswithcode.com/method/gelu
+    // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
+    __host__ static constexpr float GetFastGeLU(float x)
+    {
+        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
+        const float emu = exp(-u);
+        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
+        return x * cdf;
+    }
+
+#if 0
+    __device__ static constexpr float GetFastGeLU(float x)
+    {
+        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
+        const float emu = __expf(-u);
+        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
+        return x * cdf;
+    }
+#elif 0
+    __device__ static constexpr float GetFastGeLU(float x)
+    {
+        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
+        const float emu = __expf(-u);
+        const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f);
+        return x * cdf;
+    }
+#else
+    __device__ static constexpr float GetFastGeLU(float x)
+    {
+        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
+        const float emu = __expf(-u);
+        const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
+        return x * cdf;
+    }
+#endif
+
+    template <typename E, typename C, typename D0, typename D1>
+    __host__ __device__ constexpr void operator()(E& e, const C& c, const D0& d0, const D1& d1) const
+    {
+#if 0
+        const float y =
+            GetFastGeLU(ck::type_convert<float>(c) + ck::type_convert<float>(d0) + ck::type_convert<float>(d1));
+#else
+        const float a =
+            ck::type_convert<float>(c) + ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
+        const float y = a > 0 ? a : 0;
+#endif
+        e = ck::type_convert<E>(y);
+    }
+};
+
 using AElementOp   = PassThrough;
 using BElementOp   = PassThrough;
-using CDEElementOp = AddAddFastGelu;
+using CDEElementOp = EleFastGeLU;
 
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
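The magic numbers in GetFastGeLU fall out of folding sqrt(2/pi) into the cubic term: sqrt(2/pi) ≈ 0.797885 and 0.044715*sqrt(2/pi) ≈ 0.035677. The three #if branches of the __device__ variant then trade accuracy for speed: exact exp plus division, __expf plus the CUDA-style round-to-nearest reciprocal __frcp_rn, and __expf plus __ocml_native_recip_f32, which appears to be the fast reduced-accuracy reciprocal from ROCm's OCML device library (hence the extern "C" declaration added at the top of the file). A quick host-side check of the constants (plain C++, illustrative only, not part of the commit):

    #include <cmath>
    #include <cstdio>

    int main()
    {
        const double pi             = std::acos(-1.0);
        const double sqrt_2_over_pi = std::sqrt(2.0 / pi);

        // Constants used by GetFastGeLU above:
        std::printf("sqrt(2/pi)            = %.6f\n", sqrt_2_over_pi);            // 0.797885
        std::printf("0.044715 * sqrt(2/pi) = %.6f\n", 0.044715 * sqrt_2_over_pi); // 0.035677
        return 0;
    }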
example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc

@@ -41,7 +41,7 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
     std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
     std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
 
-    switch(config.init_method)
+    switch(2)
     {
     case 0: break;
     case 1:
@@ -108,7 +108,7 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
        throw std::runtime_error("wrong! this device_op instance does not support this problem");
    }
 
-    float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
 
    std::size_t flop = 2_uz * M * N * K;
    std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp

@@ -105,7 +105,7 @@ struct AddAddFastGelu
 using A0ElementOp   = PassThrough;
 using B0ElementOp   = PassThrough;
-using CDE0ElementOp = AddAddRelu;
+using CDE0ElementOp = AddAddFastGelu;
 using A1ElementOp   = PassThrough;
 using B1ElementOp   = PassThrough;
 using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -10,6 +10,8 @@ namespace ck {
 namespace tensor_operation {
 namespace element_wise {
 
+extern "C" __device__ float __ocml_native_recip_f32(float);
+
 struct PassThrough
 {
     template <typename Y, typename X>
@@ -198,16 +200,39 @@ struct Relu
 struct FastGelu
 {
     template <typename Y, typename X>
-    __host__ __device__ void operator()(Y& y, const X& x) const;
+    __host__ void operator()(Y& y, const X& x) const;
+
+    template <typename Y, typename X>
+    __device__ void operator()(Y& y, const X& x) const;
 
+    // host code, use higher precision "exp" and "div"
     template <>
-    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
+    __host__ void operator()<float, float>(float& y, const float& x) const
+    {
+        const float u   = float(2) * x * (float(0.035677) * x * x + float(0.797885));
+        const float emu = exp(-u);
+        const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
+        y = x * cdf;
+    }
+
+    // device code, use lower precision "__expf" and "rcp"
+    template <>
+    __device__ void operator()<float, float>(float& y, const float& x) const
     {
+#if 0
+        const float u   = float(2) * x * (float(0.035677) * x * x + float(0.797885));
+        const float emu = __expf(-u);
+        const float cdf = float(0.5) + float(0.5) * (float(2) * __ocml_native_recip_f32(float(1) + emu) - float(1));
+        y = x * cdf;
+#else
         const float u   = float(2) * x * (float(0.035677) * x * x + float(0.797885));
         const float emu = exp(-u);
         const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
         y = x * cdf;
+#endif
     }
 };
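Splitting operator() into separate __host__ and __device__ overloads lets the host reference path keep the exact exp and division while the device path can switch (via the #if 0 above) to __expf and the OCML native reciprocal, so host and GPU results may then differ by the accuracy of those approximations. A minimal host-side usage sketch of the functor as declared in this header (assumptions: a HIP-capable compiler, since the header uses __host__/__device__ qualifiers, and the CK include directory on the include path):

    #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

    #include <cstdio>

    int main()
    {
        ck::tensor_operation::element_wise::FastGelu fastgelu{};

        const float x = 1.5f;
        float y       = 0.f;
        fastgelu(y, x); // resolves to the __host__ <float, float> specialization

        std::printf("FastGelu(%.1f) = %f\n", x, y); // ~1.3997 under the tanh approximation
        return 0;
    }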