gaoqiong / composable_kernel

commit fc1f07ac
Author: Chao Liu
Date:   Sep 23, 2022
Parent: e38e61b6

    update fastgelu
Showing 7 changed files with 46 additions and 158 deletions (+46, -158).
Changed files:

  +1  -42   example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp
  +4  -63   example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp
  +2  -2    example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc
  +1  -1    example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp
  +6  -4    include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
  +20 -42   include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
  +12 -4    include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp

@@ -33,54 +33,13 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 // E = Relu(C + D);
 struct AddRelu
 {
     template <typename E, typename C, typename D>
     __host__ __device__ void operator()(E& e, const C& c, const D& d) const;
 
-#if 0
     template <>
     __host__ __device__ void
-    operator()<ck::half_t, ck::half_t, ck::half_t>(ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const
+    operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const
     {
         const ck::half_t x = c + d;
         e = x > 0 ? x : 0;
     }
-#else
-    // AddFastGeLU
-    template <>
-    __host__ __device__ void
-    operator()<ck::half_t, ck::half_t, ck::half_t>(ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const
-    {
-        const ck::half_t x = c + d;
-        e = x > 0 ? x : 0;
-    }
-#endif
 };
-
-struct FastGelu
-{
-    template <typename Y, typename X>
-    __host__ __device__ void operator()(Y& y, const X& x) const;
-
-    template <>
-    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
-    {
-        const float u   = float(2) * x * (float(0.035677) * x * x + float(0.797885));
-        const float emu = exp(-u);
-        const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
-        y = x * cdf;
-    }
-
-    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
-    {
-        const float u   = float(2) * x * (float(0.035677) * x * x + float(0.797885));
-        const float emu = exp(-u);
-        const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
-        y = x * cdf;
-    }
-};
 
 using ADataType = F16;
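Note on the deleted local FastGelu: it is the sigmoid form of the usual tanh
approximation y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))). Since
tanh(v) = 2/(1+exp(-2v)) - 1, setting u = 2v = 2*x*(0.035677*x*x + 0.797885)
(note 0.044715*0.797885 ≈ 0.035677) gives cdf = 0.5 + 0.5*(2/(1+exp(-u)) - 1).
A minimal host-only sketch, not part of the commit, that checks the two forms
agree numerically:

// Sanity check: sigmoid form used by FastGelu vs. the tanh form of GeLU.
#include <cmath>
#include <cstdio>

float gelu_tanh(float x) // reference: 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
{
    return 0.5f * x * (1.f + std::tanh(0.797885f * (x + 0.044715f * x * x * x)));
}

float gelu_sigmoid(float x) // form in the diff, with u == 2*sqrt(2/pi)*(x+0.044715*x^3)
{
    const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
    const float emu = std::exp(-u);
    const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
    return x * cdf;
}

int main()
{
    for(float x = -4.f; x <= 4.f; x += 0.5f)
        std::printf("x=% .2f  tanh=% .6f  sigmoid=% .6f\n", x, gelu_tanh(x), gelu_sigmoid(x));
    return 0;
}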
example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp

@@ -3,12 +3,11 @@
 #include "common.hpp"
 
-extern "C" __device__ float __ocml_native_recip_f32(float);
-
 using ADataType        = F16;
 using BDataType        = F16;
 using AccDataType      = F32;
-using CShuffleDataType = F32;
+using CShuffleDataType = F16;
+using CDataType        = F16; // the C matrix doesn't exist; this type is only used for verification
 using D0DataType       = F16;
 using D1DataType       = F16;
 using DsDataType       = ck::Tuple<D0DataType, D1DataType>;

@@ -21,67 +20,9 @@ using D1Layout = Row;
 using DsLayout = ck::Tuple<D0Layout, D1Layout>;
 using ELayout  = Row;
 
-// C = A * B
-// E = FastGelu(C + D0 + D1)
-struct EleFastGeLU
-{
-    // Fast GeLU
-    // https://paperswithcode.com/method/gelu
-    // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
-    __host__ static constexpr float GetFastGeLU(float x)
-    {
-        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
-        const float emu = exp(-u);
-        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
-        return x * cdf;
-    }
-
-#if 0
-    __device__ static constexpr float GetFastGeLU(float x)
-    {
-        const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
-        const float emu = __expf(-u);
-        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
-        return x * cdf;
-    }
-#elif 0
-    __device__ static constexpr float GetFastGeLU(float x)
-    {
-        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
-        const float emu = __expf(-u);
-        const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f);
-        return x * cdf;
-    }
-#else
-    __device__ static constexpr float GetFastGeLU(float x)
-    {
-        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
-        const float emu = __expf(-u);
-        const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
-        return x * cdf;
-    }
-#endif
-
-    template <typename E, typename C, typename D0, typename D1>
-    __host__ __device__ constexpr void operator()(E& e, const C& c, const D0& d0, const D1& d1) const
-    {
-#if 0
-        const float y =
-            GetFastGeLU(ck::type_convert<float>(c) + ck::type_convert<float>(d0) + ck::type_convert<float>(d1));
-#else
-        const float a =
-            ck::type_convert<float>(c) + ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
-        const float y = a > 0 ? a : 0;
-#endif
-        e = ck::type_convert<E>(y);
-    }
-};
-
 using AElementOp   = PassThrough;
 using BElementOp   = PassThrough;
-using CDEElementOp = EleFastGeLU;
+using CDEElementOp = AddAddFastGelu;
 
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

@@ -96,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
 using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                         BDataType,
-                                                                        AccDataType,
+                                                                        CDataType,
                                                                         AccDataType,
                                                                         AElementOp,
                                                                         BElementOp,
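The example's epilogue now comes from the library (CDEElementOp = AddAddFastGelu)
instead of the local EleFastGeLU, and the host reference path verifies against a
CDataType (F16) C matrix. Below is a simplified host model of the fused epilogue
E = FastGelu(C + D0 + D1); AddAddFastGeluModel is an illustrative stand-in, not
the CK functor:

// Host model of the fused epilogue applied after the C = A*B mainloop.
#include <cmath>
#include <cstddef>
#include <vector>

struct AddAddFastGeluModel // stand-in for ck::tensor_operation::element_wise::AddAddFastGelu
{
    void operator()(float& e, const float& c, const float& d0, const float& d1) const
    {
        const float x   = c + d0 + d1;
        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
        const float cdf = 0.5f + 0.5f * (2.f / (1.f + std::exp(-u)) - 1.f);
        e = x * cdf;
    }
};

// Apply the element-wise op across the whole M x N output, as the device
// kernel does per element in its epilogue.
void apply_epilogue(std::vector<float>& e,
                    const std::vector<float>& c,
                    const std::vector<float>& d0,
                    const std::vector<float>& d1)
{
    AddAddFastGeluModel op;
    for(std::size_t i = 0; i < e.size(); ++i)
        op(e[i], c[i], d0[i], d1[i]);
}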
example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc

@@ -41,7 +41,7 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
     std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
     std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
 
-    switch(2)
+    switch(config.init_method)
     {
     case 0: break;
     case 1:

@@ -124,7 +124,7 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
     if(config.do_verification)
     {
-        Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N});
+        Tensor<CDataType> c_m_n(HostTensorDescriptor{M, N});
 
         auto ref_gemm    = ReferenceGemmInstance{};
         auto ref_invoker = ref_gemm.MakeInvoker();
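The first hunk fixes a debugging leftover: switch(2) ignored the requested init
method and always took branch 2. With config.init_method restored, the example
again honors its command-line argument. A sketch of the dispatch, with
initializer choices patterned on CK examples rather than quoted from this file:

// Pick tensor initialization from the parsed command-line config.
switch(config.init_method)
{
case 0: break; // no initialization (e.g. timing-only runs)
case 1:
    // small random integers: exact in fp16, convenient for verification
    a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
    break;
default:
    // uniform random floats in [0, 1)
    a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
    break;
}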
example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp

@@ -105,7 +105,7 @@ struct AddAddFastGelu
 using A0ElementOp   = PassThrough;
 using B0ElementOp   = PassThrough;
-using CDE0ElementOp = AddAddFastGelu;
+using CDE0ElementOp = AddAddRelu;
 using A1ElementOp   = PassThrough;
 using B1ElementOp   = PassThrough;
 using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
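This one-line swap makes the first GEMM's fused epilogue match the example's
name: E = Relu(C + D0 + D1) instead of E = FastGelu(C + D0 + D1). The shape of
the selected op, as an illustrative stand-in (not the CK definition; the
__host__ __device__ qualifiers assume a HIP/CUDA toolchain as in the
surrounding code):

// E = Relu(C + D0 + D1): accumulate both residual inputs, then clamp at zero.
struct AddAddReluModel
{
    template <typename E, typename C, typename D0, typename D1>
    __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const
    {
        const C x = c + d0 + d1;
        e         = x > 0 ? x : 0;
    }
};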
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp

@@ -232,16 +232,18 @@ struct AddFastGelu
     template <typename E, typename C, typename D>
     __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
 
     template <>
-    __host__ __device__ constexpr void
-    operator<float, float, float>()(float& e, const float& c, const float& d) const
+    __host__ __device__ constexpr void
+    operator()<float, float, float>(
+        float& e, const float& c, const float& d) const
     {
         const float x = c + d;
 
         FastGelu{}.template operator()<float, float>(e, x);
     }
 
     template <>
-    __host__ __device__ constexpr void
-    operator<half_t, half_t, half_t>()(half_t& e, const half_t& c, const half_t& d) const
+    __host__ __device__ constexpr void
+    operator()<half_t, half_t, half_t>(
+        half_t& e, const half_t& c, const half_t& d) const
     {
         const half_t x = c + d;
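The change corrects the specialization syntax: the template argument list
belongs after operator(), i.e. operator()<float, float, float>(...), not
operator<float, float, float>()(...). The surrounding pattern, delegating the
activation to the unary FastGelu functor through .template operator()<...>, is
shown below as a self-contained host-only sketch with simplified names
(FastGeluLike and AddFastGeluLike are illustrative, not CK types):

// Delegation pattern of AddFastGelu: compute x = c + d, then hand the
// activation to a unary functor via an explicit member-template call.
#include <cmath>

struct FastGeluLike
{
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
        y = x * (0.5f + 0.5f * (2.f / (1.f + std::exp(-u)) - 1.f));
    }
};

struct AddFastGeluLike
{
    template <typename E, typename C, typename D>
    void operator()(E& e, const C& c, const D& d) const
    {
        const E x = c + d;
        // "template" keeps the call parseable where the callee is a member
        // template; CK uses the same spelling for uniformity.
        FastGeluLike{}.template operator()<E, E>(e, x);
    }
};

int main()
{
    float e = 0.f;
    AddFastGeluLike{}(e, 1.0f, 0.5f); // e = FastGelu(1.0 + 0.5)
    return 0;
}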
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp

@@ -15,7 +15,7 @@ namespace element_wise {
 // Need to ensure compiler will fail if there is no matching candidate, instead of compiler
 // siliently do implicit type conversion
 //
-// Method 1:
+// Example:
 //
 // struct ExampleElementwiseOp
 // {

@@ -29,19 +29,6 @@ namespace element_wise {
 // {
 // }
 // };
-//
-// Method 2:
-//
-// template <typename Y, typename X>
-// struct ExampleElementwiseOp;
-//
-// template <>
-// struct ExampleElementwiseOp<float, ck::bhalf_t>
-// {
-//     __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
-//     {
-//     }
-// };
 struct AddReluAdd
 {

@@ -142,7 +129,6 @@ struct AddHardswishAdd
     }
 };
 
-// C = A * B
 // E = C + D0 + D1
 struct AddAdd
 {

@@ -171,41 +157,33 @@ struct AddAdd
     }
 };
 
-// C = A * B
 // E = FastGelu(C + D0 + D1)
 struct AddAddFastGelu
 {
-    // Fast GeLU
-    // https://paperswithcode.com/method/gelu
-    // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
-    __host__ __device__ static constexpr float GetFastGeLU(float x)
-    {
-        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
-        const float emu = exp(-u);
-        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
-        return x * cdf;
-    }
-
-    template <typename T>
-    static inline constexpr bool is_valid_param_type_v =
-        std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
-        std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
-#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
-        || std::is_same_v<T, ck::int4_t>
-#endif
-        ;
-
     template <typename E, typename C, typename D0, typename D1>
-    __host__ __device__ constexpr void
-    operator()(E& e, const C& c, const D0& d0, const D1& d1) const
-    {
-        static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
-                      is_valid_param_type_v<D0> && is_valid_param_type_v<D1>);
-
-        const float y =
-            GetFastGeLU(type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1));
-
-        e = type_convert<E>(y);
-    }
+    __host__ __device__ constexpr void
+    operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
+
+    template <>
+    __host__ __device__ constexpr void operator()<float, float, float, float>(
+        float& e, const float& c, const float& d0, const float& d1) const
+    {
+        const float x = c + d0 + d1;
+
+        FastGelu{}.template operator()<float, float>(e, x);
+    }
+
+    template <>
+    __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
+        half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
+    {
+        const half_t x = c + d0 + d1;
+
+        ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
+    }
 };
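The header's remaining guidance ("Example:") describes the pattern these
functors follow: declare the primary operator() template but define only the
supported specializations, so an unsupported type pairing fails to build
instead of silently converting. A namespace-scope sketch of the same idea
(standard C++ requires explicit specializations at namespace scope; CK nests
them in the functor, which its HIP toolchain accepts):

// Declare the primary template but never define it; only supported
// specializations get a definition.
#include <cstdio>

template <typename Y, typename X>
void example_elementwise_op(Y& y, const X& x); // declared, intentionally undefined

template <>
void example_elementwise_op<float, float>(float& y, const float& x)
{
    y = x;
}

int main()
{
    float y       = 0.f;
    const float x = 1.f;
    example_elementwise_op(y, x); // exact match: uses the float specialization

    // const double d = 1.0;
    // example_elementwise_op(y, d); // deduces <float, double>: no definition
    //                               // exists, so the build fails at link time
    std::printf("%f\n", y);
    return 0;
}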
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -210,9 +210,9 @@ struct FastGelu
     template <>
     __host__ void operator()<float, float>(float& y, const float& x) const
     {
-        const float u   = float(2) * x * (float(0.035677) * x * x + float(0.797885));
+        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
         const float emu = exp(-u);
-        const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
+        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
         y = x * cdf;
     }

@@ -231,11 +231,19 @@ struct FastGelu
     template <>
     __device__ void operator()<float, float>(float& y, const float& x) const
     {
-        const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
+#if 0
+        const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
+        const float emu = exp(-u);
+        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
+        y = x * cdf;
+#else
+        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
         const float emu = __expf(-u);
-        const float cdf =
-            float(0.5) + float(0.5) * (float(2) * __ocml_native_recip_f32(float(1) + emu) - float(1));
+        const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
         y = x * cdf;
+#endif
     }
 
 // device code, use lower precision "__expf" and "rcp"
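The net effect in this header: the __host__ specialization keeps precise libm
math (exp, exact division), while the __device__ specialization uses fast
approximations (__expf and the OCML native reciprocal), with the precise
variant parked under #if 0 for comparison. A hedged sketch of the same split
in a single function; fast_gelu is an illustrative name, and the extern "C"
declaration mirrors the one this commit removed from the example:

// Precise math on the host, fast-math intrinsics on the device (HIP-style;
// assumes a HIP toolchain for the qualifiers and intrinsics).
#include <cmath>

extern "C" __device__ float __ocml_native_recip_f32(float); // ROCm OCML builtin

__host__ __device__ inline float fast_gelu(float x)
{
    const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
#if defined(__HIP_DEVICE_COMPILE__)
    // device pass: lower-precision "__expf" and a native reciprocal
    const float emu = __expf(-u);
    const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
#else
    // host pass: standard expf and an exact division
    const float emu = expf(-u);
    const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
#endif
    return x * cdf;
}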