gaoqiong / composable_kernel / Commits / 87f44ead

Commit 87f44ead authored Feb 15, 2023 by Chao Liu
Parent: 4198de5a

    clean

Showing 4 changed files with 3 additions and 128 deletions (+3, -128)
include/ck/ck.hpp                                                          +1  -1
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp  +0  -42
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp         +0  -40
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp   +2  -45
include/ck/ck.hpp

@@ -169,7 +169,7 @@
 #define CK_WORKAROUND_SWDEV_325164 0

 // workaround: compiler not emiting reciprocal instruction frm __frcp_rn()
-#define CK_WORKAROUND_SWDEV_XXXXXX_FRCP_RN 1
+#define CK_WORKAROUND_SWDEV_383542 1

 // flag to enable (1) or disable (0) the debugging output in some kernels
 #define DEBUG_LOG 0
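The only change in ck.hpp renames the workaround flag from the placeholder CK_WORKAROUND_SWDEV_XXXXXX_FRCP_RN to the real ticket number, CK_WORKAROUND_SWDEV_383542; the flag guards code that avoids __frcp_rn() because the compiler was not emitting a reciprocal instruction for it. A minimal sketch of how a device function might consume the renamed flag, mirroring the pattern in unary_element_wise_operation.hpp further down (fast_reciprocal is a hypothetical helper name for illustration, not part of this commit; it assumes ck/ck.hpp is included so the flag is defined):

#if CK_WORKAROUND_SWDEV_383542
// declaration taken verbatim from unary_element_wise_operation.hpp in this commit
extern "C" __device__ float __ocml_native_recip_f32(float);
#endif

__device__ inline float fast_reciprocal(float x)
{
#if !CK_WORKAROUND_SWDEV_383542
    return __frcp_rn(x);               // normal path: round-to-nearest reciprocal intrinsic
#else
    return __ocml_native_recip_f32(x); // workaround: call the OCML native reciprocal directly
#endif
}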
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp

@@ -280,47 +280,6 @@ struct AddHardswish
 };

-#if 0
-// C = A * B
-// E = FastGelu(C + D)
-struct AddFastGelu
-{
-    // Fast GeLU
-    // https://paperswithcode.com/method/gelu
-    // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
-    __host__ __device__ static constexpr float GetFastGeLU(float x)
-    {
-        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
-        const float emu = exp(-u);
-        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
-        return x * cdf;
-    }
-
-    template <typename T>
-    static inline constexpr bool is_valid_param_type_v =
-        std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
-        std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
-
-    template <typename E, typename C, typename D>
-    __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const
-    {
-        static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
-                      is_valid_param_type_v<D>);
-
-        const float y = GetFastGeLU(type_convert<float>(c) + type_convert<float>(d));
-
-        e = type_convert<E>(y);
-    }
-
-    template <typename D>
-    __host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const
-    {
-        static_assert(is_valid_param_type_v<D>);
-
-        e = GetFastGeLU(c + type_convert<float>(d));
-    }
-};
-#else
 // E = FastGelu(C + D)
 struct AddFastGelu
 {

@@ -345,7 +304,6 @@ struct AddFastGelu
         ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
     }
 };
-#endif

 } // namespace element_wise
 } // namespace tensor_operation
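With the #if 0 branch deleted, AddFastGelu no longer carries a private GetFastGeLU copy: per the retained hunk it delegates to the shared unary FastGelu functor, presumably applied to the sum c + d as the "E = FastGelu(C + D)" comment states. A hedged, float-only sketch of that delegation in plain C++ (the real code works on ck::half_t via type_convert; FastGeluSketch and AddFastGeluSketch below are stand-in names, not composable_kernel types):

#include <cmath>

// Stand-in for the unary FastGelu functor, using the documented approximation
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))).
struct FastGeluSketch
{
    void operator()(float& y, const float& x) const
    {
        y = 0.5f * x * (1.f + std::tanh(0.797885f * (x + 0.044715f * x * x * x)));
    }
};

// Stand-in for AddFastGelu: accumulate the GEMM output C with the extra tensor D,
// then hand the sum to the shared FastGelu functor, as the surviving code does.
struct AddFastGeluSketch
{
    void operator()(float& e, const float& c, const float& d) const
    {
        const float x = c + d;
        FastGeluSketch{}(e, x);
    }
};

Usage is element-wise, e.g. float e; AddFastGeluSketch{}(e, 0.5f, 0.25f); writes FastGelu(0.75f) into e.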
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp

@@ -195,45 +195,6 @@ struct AddMultiply
     }
 };

-#if 0
-// C = A * B
-// E = FastGelu(C + D0 + D1)
-struct AddAddFastGelu
-{
-    // Fast GeLU
-    // https://paperswithcode.com/method/gelu
-    // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
-    __host__ __device__ static constexpr float GetFastGeLU(float x)
-    {
-        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
-        const float emu = exp(-u);
-        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
-        return x * cdf;
-    }
-
-    template <typename T>
-    static inline constexpr bool is_valid_param_type_v =
-        std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
-        std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
-#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
-        || std::is_same_v<T, ck::int4_t>
-#endif
-        ;
-
-    template <typename E, typename C, typename D0, typename D1>
-    __host__ __device__ constexpr void
-    operator()(E& e, const C& c, const D0& d0, const D1& d1) const
-    {
-        static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
-                      is_valid_param_type_v<D0> && is_valid_param_type_v<D1>);
-
-        const float y =
-            GetFastGeLU(type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1));
-
-        e = type_convert<E>(y);
-    }
-};
-#else
 // E = FastGelu(C + D0 + D1)
 struct AddAddFastGelu
 {

@@ -261,7 +222,6 @@ struct AddAddFastGelu
         ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
     }
 };
-#endif

 struct Normalize
 {
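Both deleted blocks document the approximation as y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) yet compute it through an exponential; the two forms agree via the identity tanh(z) = 2/(1+e^{-2z}) - 1. As a short check of the hard-coded constants (this derivation is explanatory, not part of the commit):

\[
u \;=\; 2x\,(0.035677\,x^{2} + 0.797885) \;=\; 2\sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr) \;=\; 2z,
\qquad \sqrt{2/\pi} \approx 0.797885,\quad 0.044715\cdot\sqrt{2/\pi} \approx 0.035677,
\]
\[
\text{cdf} \;=\; 0.5 + 0.5\Bigl(\tfrac{2}{1+e^{-u}} - 1\Bigr) \;=\; 0.5\,\bigl(1 + \tanh z\bigr),
\qquad y \;=\; x\cdot\text{cdf} \;=\; 0.5\,x\,\bigl(1+\tanh\bigl(\sqrt{2/\pi}\,(x + 0.044715\,x^{3})\bigr)\bigr).
\]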
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -11,7 +11,7 @@ namespace ck {
 namespace tensor_operation {
 namespace element_wise {

-#if CK_WORKAROUND_SWDEV_XXXXXX_FRCP_RN
+#if CK_WORKAROUND_SWDEV_383542
 extern "C" __device__ float __ocml_native_recip_f32(float);
 #endif

@@ -204,40 +204,6 @@ struct Relu
     }
 };

-#if 0
-// Y = FastGelu(X)
-struct FastGelu
-{
-    // Fast GeLU
-    // https://paperswithcode.com/method/gelu
-    // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
-    __host__ __device__ static constexpr float GetFastGeLU(float x)
-    {
-        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
-        const float emu = exp(-u);
-        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
-        return x * cdf;
-    }
-
-    template <typename T>
-    static inline constexpr bool is_valid_param_type_v =
-        std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
-        std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
-#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
-        || std::is_same_v<T, ck::int4_t>
-#endif
-        ;
-
-    template <typename Y, typename X>
-    __host__ __device__ void operator()(Y& y, const X& x) const
-    {
-        static_assert(is_valid_param_type_v<Y> && is_valid_param_type_v<X>);
-
-        const float tmp_y = GetFastGeLU(type_convert<float>(x));
-
-        y = type_convert<Y>(tmp_y);
-    }
-};
-#else
 // Fast GeLU
 // https://paperswithcode.com/method/gelu
 // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))

@@ -275,24 +241,16 @@ struct FastGelu
     template <>
     __device__ void operator()<float, float>(float& y, const float& x) const
     {
-#if 0
-        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
-        const float emu = exp(-u);
-        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
-        y = x * cdf;
-#else
         const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
         const float emu = __expf(-u);
-#if !CK_WORKAROUND_SWDEV_XXXXXX_FRCP_RN
+#if !CK_WORKAROUND_SWDEV_383542
         const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f);
 #else
         const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
 #endif
         y = x * cdf;
-#endif
     }

     // device code, use lower precision "__expf" and "rcp"

@@ -306,7 +264,6 @@ struct FastGelu
         y = type_convert<half_t>(y_f);
     }
 };
-#endif

 // https://paperswithcode.com/method/gelu
 // y = 0.5*x*(1+erf(x/sqrt(2)))
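The retained FastGelu float specialization trades accuracy for speed on device: it uses __expf plus a reciprocal intrinsic (__frcp_rn, or __ocml_native_recip_f32 under the workaround flag), while the comment kept at the end of the hunk points to an exact erf-based GeLU, y = 0.5*x*(1+erf(x/sqrt(2))). A hedged host-side sketch that contrasts the two, with std::exp standing in for __expf and plain division for the reciprocal intrinsics (the printed values are illustrative, not measurements from this commit):

#include <cmath>
#include <cstdio>

// Fast GeLU as implemented in FastGelu's float specialization (host stand-ins for intrinsics).
static float fast_gelu(float x)
{
    const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
    const float emu = std::exp(-u);
    const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
    return x * cdf;
}

// Reference GeLU from the erf-based formula referenced later in the same header.
static float ref_gelu(float x)
{
    return 0.5f * x * (1.f + std::erf(x * 0.70710678f)); // 1/sqrt(2) ~= 0.70710678
}

int main()
{
    for (float x = -4.f; x <= 4.f; x += 1.f)
    {
        std::printf("x=%5.1f  fast=% .6f  ref=% .6f  diff=% .2e\n",
                    x, fast_gelu(x), ref_gelu(x), fast_gelu(x) - ref_gelu(x));
    }
    return 0;
}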