Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
980b8835
Commit
980b8835
authored
Oct 25, 2023
by
Junhao Zhang
Browse files
compiling time RTN
parent
5fb5335e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
56 additions
and
14 deletions
+56
-14
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+10
-14
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+46
-0
No files found.
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
980b8835
...
...
@@ -51,19 +51,6 @@ struct PassThrough
y
=
x
;
}
#if FLASH_ATTENTION_INTERNAL_USE_RTN
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
float
>
(
bhalf_t
&
y
,
const
float
&
x
)
const
{
y
=
bf16_convert_rtn
<
bhalf_t
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
half_t
>
(
bhalf_t
&
y
,
const
half_t
&
x
)
const
{
y
=
bf16_convert_rtn
<
bhalf_t
>
(
x
);
}
#else
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
float
>
(
bhalf_t
&
y
,
const
float
&
x
)
const
{
...
...
@@ -75,7 +62,6 @@ struct PassThrough
{
y
=
type_convert
<
bhalf_t
>
(
x
);
}
#endif
template
<
>
__host__
__device__
void
operator
()
<
int8_t
,
int8_t
>
(
int8_t
&
y
,
const
int8_t
&
x
)
const
...
...
@@ -128,6 +114,16 @@ struct PassThrough
}
};
// struct PassThroughRTN : public PassThrough
// {
// __host__ __device__ void operator()(bhalf_t& y, const float& x) const
// {
// y = bf16_convert_rtn<bhalf_t>(x);
// }
// using PassThrough::operator();
// };
struct
UnaryConvert
{
template
<
typename
Y
,
typename
X
>
...
...
include/ck/utility/type_convert.hpp
View file @
980b8835
...
...
@@ -31,6 +31,51 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
return
u
.
fp32
;
}
#if FLASH_ATTENTION_INTERNAL_USE_RTN
// Convert fp32 to bf16 with RTN if higher precision is needed
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_convert
<
bhalf_t
,
float
>
(
float
x
)
{
union
{
float
fp32
;
uint32_t
int32
;
}
u
=
{
x
};
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
bool
flag0
=
~
u
.
int32
&
0x7f800000
;
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
bool
flag1
=
!
flag0
&&
(
u
.
int32
&
0xffff
);
u
.
int32
+=
flag0
?
0x7fff
+
((
u
.
int32
>>
16
)
&
1
)
:
0
;
// Round to nearest, round to even
u
.
int32
|=
flag1
?
0x10000
:
0x0
;
// Preserve signaling NaN
return
uint16_t
(
u
.
int32
>>
16
);
}
#else
// convert fp32 to bfp16
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_convert
<
bhalf_t
,
float
>
(
float
x
)
...
...
@@ -43,6 +88,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
return
uint16_t
(
u
.
int32
>>
16
);
}
#endif
// convert bfp16 to fp16 via fp32
template
<
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment