Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
532bbe53
Commit
532bbe53
authored
May 23, 2023
by
Rostyslav Geyyer
Browse files
Add fp16 casting functions
parent
c1ba7c63
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
132 additions
and
74 deletions
+132
-74
include/ck/utility/f8_utils.hpp
include/ck/utility/f8_utils.hpp
+132
-74
No files found.
include/ck/utility/f8_utils.hpp
View file @
532bbe53
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
namespace
ck
{
namespace
ck
{
using
f8_t
=
uint8_t
;
using
f8_t
=
uint8_t
;
using
half_t
=
_Float16
;
// fp8 rounding modes
// fp8 rounding modes
enum
class
f8_rounding_mode
enum
class
f8_rounding_mode
...
@@ -16,66 +17,81 @@ enum class f8_rounding_mode
...
@@ -16,66 +17,81 @@ enum class f8_rounding_mode
stochastic
stochastic
};
};
// cast fp32 to fp8
template
<
typename
T
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
template
<
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
f8_t
run_cast_to_f8
(
T
x
,
uint32_t
rng
)
__host__
__device__
f8_t
cast_to_f8
(
float
x
,
uint32_t
rng
)
{
{
//
fp8 exponent/mantissa layout
//
check data type
constexpr
int
we_f8
=
4
;
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
int
wm_f8
=
3
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
// fp
32
exponent/mantissa layout
// fp
8
exponent/mantissa layout
constexpr
int
we_f32
=
8
;
constexpr
int
f8_exp
=
4
;
constexpr
int
wm_f32
=
2
3
;
constexpr
int
f8_mant
=
3
;
uint32_t
x_bitwise
;
// resulting type exponent/mantissa layout
x_bitwise
=
*
(
reinterpret_cast
<
uint32_t
*>
(
&
x
));
constexpr
int
type_exp
=
is_half
?
5
:
8
;
constexpr
int
type_mant
=
is_half
?
10
:
23
;
// unpack the input
uint32_t
head
,
mantissa
;
int
exponent
;
int
exponent
;
uint32_t
sign
;
uint32_t
head
,
mantissa
,
sign
;
// nan code is same for float and half
constexpr
uint8_t
nan_code
=
0x80
;
constexpr
uint32_t
nan_mask
=
is_half
?
0x7C00
:
0x7F800000
;
head
=
x_bitwise
&
0xFF800000
;
// convert to bitwise
mantissa
=
x_bitwise
&
0x7FFFFF
;
typedef
typename
std
::
conditional
<
std
::
is_same
<
T
,
half_t
>::
value
,
uint16_t
,
uint32_t
>::
type
T_bitwise
;
exponent
=
(
head
>>
wm_f32
)
&
0xFF
;
T_bitwise
x_bitwise
=
*
(
reinterpret_cast
<
T_bitwise
*>
(
&
x
));
sign
=
head
>>
(
we_f32
+
wm_f32
);
uint32_t
signed_inf
=
(
sign
<<
(
we_f8
+
wm_f8
))
+
(((
1
<<
we_f8
)
-
1
)
<<
wm_f8
);
// unpack the input, depends on datatype
uint32_t
drop_mask
=
(
1
<<
(
wm_f32
-
wm_f8
))
-
1
;
if
constexpr
(
is_float
)
int
max_exp
;
{
int
exp_low_cutoff
;
head
=
x_bitwise
&
0xFF800000
;
mantissa
=
x_bitwise
&
0x7FFFFF
;
exponent
=
(
head
>>
type_mant
)
&
0xFF
;
sign
=
head
>>
(
type_exp
+
type_mant
);
}
else
if
constexpr
(
is_half
)
{
head
=
x_bitwise
&
0xFC00
;
mantissa
=
x_bitwise
&
0x3FF
;
exponent
=
(
head
>>
type_mant
)
&
0x1F
;
sign
=
head
>>
(
type_exp
+
type_mant
);
}
uint32_t
signed_inf
=
(
sign
<<
(
type_exp
+
type_mant
))
+
(((
1
<<
type_exp
)
-
1
)
<<
type_mant
);
uint32_t
drop_mask
=
(
1
<<
(
type_mant
-
f8_mant
))
-
1
;
constexpr
int
max_exp
=
(
1
<<
f8_exp
)
-
(
negative_zero_nan
?
1
:
2
);
constexpr
int
exp_low_cutoff
=
(
1
<<
(
type_exp
-
1
))
-
(
1
<<
(
f8_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
if
constexpr
(
negative_zero_nan
)
if
constexpr
(
negative_zero_nan
)
{
{
if
((
x_bitwise
&
0x7F800000
)
==
0x7F800000
)
if
((
x_bitwise
&
nan_mask
)
==
nan_mask
)
return
0x80
;
return
nan_code
;
max_exp
=
(
1
<<
we_f8
)
-
1
;
exp_low_cutoff
=
0x80
-
(
1
<<
(
we_f8
-
1
));
}
}
else
else
{
{
if
((
x_bitwise
&
0x7F800000
)
==
0x7F800000
)
if
((
x_bitwise
&
nan_mask
)
==
nan_mask
)
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
max_exp
=
(
1
<<
we_f8
)
-
2
;
exp_low_cutoff
=
0x80
-
(
1
<<
(
we_f8
-
1
))
+
1
;
}
}
// check if x is 0.0
if
(
x_bitwise
==
0
)
if
(
x_bitwise
==
0
)
return
0
;
return
0
;
exponent
-=
exp_low_cutoff
-
1
;
exponent
-=
exp_low_cutoff
-
1
;
if
(
exponent
<=
0
)
if
(
exponent
<=
0
)
drop_mask
=
(
1
<<
(
wm_f32
-
wm_f8
+
1
-
exponent
))
-
1
;
drop_mask
=
(
1
<<
(
type_mant
-
f8_mant
+
1
-
exponent
))
-
1
;
mantissa
+=
1
<<
wm_f32
;
mantissa
+=
1
<<
type_mant
;
// apply random number if needed
mantissa
+=
(
stoch
?
rng
:
mantissa
)
&
drop_mask
;
mantissa
+=
(
stoch
?
rng
:
mantissa
)
&
drop_mask
;
if
(
mantissa
>=
(
2
<<
wm_f32
))
if
(
mantissa
>=
(
2
<<
type_mant
))
{
{
mantissa
>>=
1
;
mantissa
>>=
1
;
exponent
++
;
exponent
++
;
}
}
mantissa
>>=
(
wm_f32
-
wm_f8
);
mantissa
>>=
(
type_mant
-
f8_mant
);
// check negative exponent
if
(
exponent
<=
0
)
if
(
exponent
<=
0
)
{
{
if
(
x_bitwise
==
0
)
if
(
x_bitwise
==
0
)
...
@@ -93,7 +109,7 @@ __host__ __device__ f8_t cast_to_f8(float x, uint32_t rng)
...
@@ -93,7 +109,7 @@ __host__ __device__ f8_t cast_to_f8(float x, uint32_t rng)
{
{
if
(
clip
)
if
(
clip
)
{
{
mantissa
=
(
1
<<
wm_f8
)
-
1
;
mantissa
=
(
1
<<
f8_mant
)
-
1
;
exponent
=
max_exp
;
exponent
=
max_exp
;
}
}
else
else
...
@@ -101,65 +117,92 @@ __host__ __device__ f8_t cast_to_f8(float x, uint32_t rng)
...
@@ -101,65 +117,92 @@ __host__ __device__ f8_t cast_to_f8(float x, uint32_t rng)
return
signed_inf
;
return
signed_inf
;
}
}
}
}
// check if x is 0.0 or -0.0
if
(
exponent
==
0
&&
mantissa
==
0
)
if
(
exponent
==
0
&&
mantissa
==
0
)
return
negative_zero_nan
?
0
:
(
sign
<<
(
we_f8
+
wm_f8
));
return
negative_zero_nan
?
0
:
(
sign
<<
(
f8_exp
+
f8_mant
));
mantissa
&=
(
1
<<
wm_f8
)
-
1
;
mantissa
&=
(
1
<<
f8_mant
)
-
1
;
return
(
sign
<<
(
we_f8
+
wm_f8
))
|
(
exponent
<<
wm_f8
)
|
mantissa
;
return
(
sign
<<
(
f8_exp
+
f8_mant
))
|
(
exponent
<<
f8_mant
)
|
mantissa
;
}
template
<
typename
T
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
f8_t
cast_to_f8
(
T
x
,
uint32_t
rng
)
{
// check datatype
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"Only half and float can be casted to f8."
);
return
run_cast_to_f8
<
T
,
negative_zero_nan
,
clip
,
stoch
>
(
x
,
rng
);
}
}
// cast fp8 to fp32
template
<
typename
T
,
bool
negative_zero_nan
>
template
<
bool
negative_zero_nan
>
__host__
__device__
T
run_cast_from_f8
(
f8_t
x
)
__host__
__device__
float
cast_from_f8
(
f8_t
x
)
{
{
// check data type
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
// fp8 exponent/mantissa layout
// fp8 exponent/mantissa layout
constexpr
int
we_f8
=
4
;
constexpr
int
f8_exp
=
4
;
constexpr
int
wm_f8
=
3
;
constexpr
int
f8_mant
=
3
;
// fp32 exponent/mantissa layout
constexpr
int
we_f32
=
8
;
constexpr
int
wm_f32
=
23
;
float
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
const
uint32_t
ifInf
=
0x7F800000
;
const
uint32_t
ifNegInf
=
0xFF800000
;
const
uint32_t
ifNaN
=
0x7F800001
;
const
uint32_t
ifNeg0
=
0x80000000
;
fInf
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifInf
));
fNegInf
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNegInf
));
fNaN
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNaN
));
fNeg0
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNeg0
));
if
(
x
==
0
)
// resulting type exponent/mantissa layout
return
static_cast
<
float
>
(
0
);
constexpr
int
type_exp
=
is_half
?
5
:
8
;
constexpr
int
type_mant
=
is_half
?
10
:
23
;
// prepare the codes
constexpr
uint8_t
nan_code
=
0x80
;
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
if
constexpr
(
is_half
)
{
constexpr
uint16_t
ihInf
=
0x7C00
;
constexpr
uint16_t
ihNegInf
=
0xFC00
;
constexpr
uint16_t
ihNaN
=
0x7C01
;
constexpr
uint16_t
ihNeg0
=
0x8000
;
fInf
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihInf
));
fNegInf
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNegInf
));
fNaN
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNaN
));
fNeg0
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNeg0
));
}
else
if
constexpr
(
is_float
)
{
constexpr
uint32_t
ifInf
=
0x7F800000
;
constexpr
uint32_t
ifNegInf
=
0xFF800000
;
constexpr
uint32_t
ifNaN
=
0x7F800001
;
constexpr
uint32_t
ifNeg0
=
0x80000000
;
fInf
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifInf
));
fNegInf
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNegInf
));
fNaN
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNaN
));
fNeg0
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNeg0
));
}
// unpack the input
// unpack the input
uint32_t
sign
=
x
>>
(
we_f8
+
wm_f8
);
uint32_t
sign
=
x
>>
(
f8_exp
+
f8_mant
);
uint32_t
mantissa
=
x
&
((
1
<<
wm_f8
)
-
1
);
uint32_t
mantissa
=
x
&
((
1
<<
f8_mant
)
-
1
);
int
exponent
=
(
x
&
0x7F
)
>>
wm_f8
;
int
exponent
=
(
x
&
0x7F
)
>>
f8_mant
;
int
exp_low_cutoff
;
constexpr
int
exp_low_cutoff
=
(
1
<<
(
type_exp
-
1
))
-
(
1
<<
(
f8_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
)
;
uint32_t
retval
;
typename
std
::
conditional
<
std
::
is_same
<
T
,
half_t
>::
value
,
uint16_t
,
uint32_t
>::
type
retval
;
if
constexpr
(
negative_zero_nan
)
if
constexpr
(
negative_zero_nan
)
{
{
if
(
x
==
0x80
)
if
(
x
==
nan_code
)
return
fNaN
;
return
fNaN
;
exp_low_cutoff
=
(
1
<<
(
we_f32
-
1
))
-
(
1
<<
(
we_f8
-
1
));
}
}
else
else
{
{
if
(
x
==
0x80
)
if
(
x
==
nan_code
)
return
fNeg0
;
return
fNeg0
;
if
(
exponent
==
((
1
<<
we_f8
)
-
1
))
if
(
exponent
==
((
1
<<
f8_exp
)
-
1
))
return
(
mantissa
==
0
)
?
(
sign
?
fNegInf
:
fInf
)
:
fNaN
;
return
(
mantissa
==
0
)
?
(
sign
?
fNegInf
:
fInf
)
:
fNaN
;
exp_low_cutoff
=
(
1
<<
(
we_f32
-
1
))
-
(
1
<<
(
we_f8
-
1
))
+
1
;
}
}
// subnormal input
// subnormal input
if
(
exponent
==
0
)
if
(
exponent
==
0
)
{
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int
sh
=
1
+
__builtin_clz
(
mantissa
)
-
((
1
+
we_f32
+
wm_f32
)
-
wm_f8
);
int
sh
=
1
+
__builtin_clz
(
mantissa
)
-
((
1
+
type_exp
+
type_mant
)
-
f8_mant
);
mantissa
<<=
sh
;
mantissa
<<=
sh
;
exponent
+=
1
-
sh
;
exponent
+=
1
-
sh
;
/*
/*
...
@@ -169,21 +212,36 @@ __host__ __device__ float cast_from_f8(f8_t x)
...
@@ -169,21 +212,36 @@ __host__ __device__ float cast_from_f8(f8_t x)
exponent--;
exponent--;
}
}
*/
*/
mantissa
&=
((
1
<<
wm_f8
)
-
1
);
mantissa
&=
((
1
<<
f8_mant
)
-
1
);
}
}
exponent
+=
exp_low_cutoff
-
1
;
exponent
+=
exp_low_cutoff
-
1
;
mantissa
<<=
wm_f32
-
wm_f8
;
mantissa
<<=
type_mant
-
f8_mant
;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if
(
exponent
<=
0
)
if
(
exponent
<=
0
)
{
{
mantissa
|=
1
<<
wm_f32
;
mantissa
|=
1
<<
type_mant
;
mantissa
>>=
1
-
exponent
;
mantissa
>>=
1
-
exponent
;
exponent
=
0
;
exponent
=
0
;
}
}
retval
=
(
sign
<<
(
we_f32
+
wm_f32
))
|
(
exponent
<<
wm_f32
)
|
mantissa
;
retval
=
(
sign
<<
(
type_exp
+
type_mant
))
|
(
exponent
<<
type_mant
)
|
mantissa
;
return
*
(
reinterpret_cast
<
const
float
*>
(
&
retval
));
return
*
(
reinterpret_cast
<
const
T
*>
(
&
retval
));
}
template
<
typename
T
,
bool
negative_zero_nan
>
__host__
__device__
T
cast_from_f8
(
f8_t
x
)
{
// check datatype
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"only half and float are supported."
);
// check if x is 0.0
if
(
x
==
0
)
return
static_cast
<
T
>
(
0
);
return
run_cast_from_f8
<
T
,
negative_zero_nan
>
(
x
);
}
}
}
// namespace ck
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment