Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
502942fe
Commit
502942fe
authored
May 23, 2023
by
Rostyslav Geyyer
Browse files
Format
parent
532bbe53
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
21 deletions
+24
-21
include/ck/utility/f8_utils.hpp
include/ck/utility/f8_utils.hpp
+24
-21
No files found.
include/ck/utility/f8_utils.hpp
View file @
502942fe
...
...
@@ -7,8 +7,8 @@
namespace
ck
{
using
f8_t
=
uint8_t
;
using
half_t
=
_Float16
;
using
f8_t
=
uint8_t
;
using
half_t
=
_Float16
;
// fp8 rounding modes
enum
class
f8_rounding_mode
...
...
@@ -25,21 +25,22 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
// fp8 exponent/mantissa layout
constexpr
int
f8_exp
=
4
;
constexpr
int
f8_exp
=
4
;
constexpr
int
f8_mant
=
3
;
// resulting type exponent/mantissa layout
constexpr
int
type_exp
=
is_half
?
5
:
8
;
constexpr
int
type_exp
=
is_half
?
5
:
8
;
constexpr
int
type_mant
=
is_half
?
10
:
23
;
int
exponent
;
uint32_t
head
,
mantissa
,
sign
;
// nan code is same for float and half
constexpr
uint8_t
nan_code
=
0x80
;
constexpr
uint8_t
nan_code
=
0x80
;
constexpr
uint32_t
nan_mask
=
is_half
?
0x7C00
:
0x7F800000
;
// convert to bitwise
typedef
typename
std
::
conditional
<
std
::
is_same
<
T
,
half_t
>::
value
,
uint16_t
,
uint32_t
>::
type
T_bitwise
;
typedef
typename
std
::
conditional
<
std
::
is_same
<
T
,
half_t
>::
value
,
uint16_t
,
uint32_t
>::
type
T_bitwise
;
T_bitwise
x_bitwise
=
*
(
reinterpret_cast
<
T_bitwise
*>
(
&
x
));
// unpack the input, depends on datatype
...
...
@@ -58,10 +59,11 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
sign
=
head
>>
(
type_exp
+
type_mant
);
}
uint32_t
signed_inf
=
(
sign
<<
(
type_exp
+
type_mant
))
+
(((
1
<<
type_exp
)
-
1
)
<<
type_mant
);
uint32_t
drop_mask
=
(
1
<<
(
type_mant
-
f8_mant
))
-
1
;
constexpr
int
max_exp
=
(
1
<<
f8_exp
)
-
(
negative_zero_nan
?
1
:
2
);
constexpr
int
exp_low_cutoff
=
(
1
<<
(
type_exp
-
1
))
-
(
1
<<
(
f8_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
uint32_t
signed_inf
=
(
sign
<<
(
type_exp
+
type_mant
))
+
(((
1
<<
type_exp
)
-
1
)
<<
type_mant
);
uint32_t
drop_mask
=
(
1
<<
(
type_mant
-
f8_mant
))
-
1
;
constexpr
int
max_exp
=
(
1
<<
f8_exp
)
-
(
negative_zero_nan
?
1
:
2
);
constexpr
int
exp_low_cutoff
=
(
1
<<
(
type_exp
-
1
))
-
(
1
<<
(
f8_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
if
constexpr
(
negative_zero_nan
)
{
...
...
@@ -144,11 +146,11 @@ __host__ __device__ T run_cast_from_f8(f8_t x)
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
// fp8 exponent/mantissa layout
constexpr
int
f8_exp
=
4
;
constexpr
int
f8_exp
=
4
;
constexpr
int
f8_mant
=
3
;
// resulting type exponent/mantissa layout
constexpr
int
type_exp
=
is_half
?
5
:
8
;
constexpr
int
type_exp
=
is_half
?
5
:
8
;
constexpr
int
type_mant
=
is_half
?
10
:
23
;
// prepare the codes
...
...
@@ -160,10 +162,10 @@ __host__ __device__ T run_cast_from_f8(f8_t x)
constexpr
uint16_t
ihNegInf
=
0xFC00
;
constexpr
uint16_t
ihNaN
=
0x7C01
;
constexpr
uint16_t
ihNeg0
=
0x8000
;
fInf
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihInf
));
fNegInf
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNegInf
));
fNaN
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNaN
));
fNeg0
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNeg0
));
fInf
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihInf
));
fNegInf
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNegInf
));
fNaN
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNaN
));
fNeg0
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNeg0
));
}
else
if
constexpr
(
is_float
)
{
...
...
@@ -171,10 +173,10 @@ __host__ __device__ T run_cast_from_f8(f8_t x)
constexpr
uint32_t
ifNegInf
=
0xFF800000
;
constexpr
uint32_t
ifNaN
=
0x7F800001
;
constexpr
uint32_t
ifNeg0
=
0x80000000
;
fInf
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifInf
));
fNegInf
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNegInf
));
fNaN
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNaN
));
fNeg0
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNeg0
));
fInf
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifInf
));
fNegInf
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNegInf
));
fNaN
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNaN
));
fNeg0
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNeg0
));
}
// unpack the input
...
...
@@ -182,7 +184,8 @@ __host__ __device__ T run_cast_from_f8(f8_t x)
uint32_t
mantissa
=
x
&
((
1
<<
f8_mant
)
-
1
);
int
exponent
=
(
x
&
0x7F
)
>>
f8_mant
;
constexpr
int
exp_low_cutoff
=
(
1
<<
(
type_exp
-
1
))
-
(
1
<<
(
f8_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
constexpr
int
exp_low_cutoff
=
(
1
<<
(
type_exp
-
1
))
-
(
1
<<
(
f8_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
typename
std
::
conditional
<
std
::
is_same
<
T
,
half_t
>::
value
,
uint16_t
,
uint32_t
>::
type
retval
;
if
constexpr
(
negative_zero_nan
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment