Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
923c1700
Commit
923c1700
authored
Aug 29, 2023
by
Rostyslav Geyyer
Browse files
Add bf8 conversion methods
parent
2776c177
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
52 additions
and
37 deletions
+52
-37
include/ck/utility/f8_utils.hpp
include/ck/utility/f8_utils.hpp
+52
-37
No files found.
include/ck/utility/f8_utils.hpp
View file @
923c1700
...
...
@@ -22,16 +22,18 @@ namespace ck::utils {
namespace
{
template
<
typename
T
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
uint8_t
run_cast_to_f8
(
T
x
,
uint32_t
rng
)
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
Y
run_cast_to_f8
(
X
x
,
uint32_t
rng
)
{
// check data type
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
constexpr
bool
is_half
=
std
::
is_same
<
X
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
X
,
float
>::
value
;
constexpr
bool
is_f8_t
=
std
::
is_same
<
Y
,
f8_t
>::
value
;
constexpr
bool
is_bf8_t
=
std
::
is_same
<
Y
,
bf8_t
>::
value
;
// fp8 exponent/mantissa layout
constexpr
int
f8_exp
=
4
;
constexpr
int
f8_mant
=
3
;
// fp8
/bf8
exponent/mantissa layout
constexpr
int
f8_exp
=
is_f8_t
?
4
:
5
;
constexpr
int
f8_mant
=
is_f8_t
?
3
:
2
;
// resulting type exponent/mantissa layout
constexpr
int
type_exp
=
is_half
?
5
:
8
;
...
...
@@ -40,11 +42,11 @@ __host__ __device__ uint8_t run_cast_to_f8(T x, uint32_t rng)
int
exponent
;
uint32_t
head
,
mantissa
,
sign
;
// nan code is same for float and half
constexpr
uint8_t
nan_code
=
0x80
;
constexpr
Y
nan_code
=
0x80
;
constexpr
uint32_t
nan_mask
=
is_half
?
0x7C00
:
0x7F800000
;
// convert to bitwise
typedef
typename
std
::
conditional
<
std
::
is_same
<
T
,
half_t
>::
value
,
uint16_t
,
uint32_t
>::
type
typedef
typename
std
::
conditional
<
std
::
is_same
<
X
,
half_t
>::
value
,
uint16_t
,
uint32_t
>::
type
T_bitwise
;
T_bitwise
x_bitwise
=
*
(
reinterpret_cast
<
T_bitwise
*>
(
&
x
));
...
...
@@ -81,6 +83,15 @@ __host__ __device__ uint8_t run_cast_to_f8(T x, uint32_t rng)
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
}
if
(
is_half
&&
is_bf8_t
&&
negative_zero_nan
&&
exponent
==
0
)
{
exponent
+=
1
;
int
sh
=
1
+
__builtin_clz
(
mantissa
)
-
(
32
-
type_mant
);
mantissa
<<=
sh
;
exponent
-=
sh
;
mantissa
&=
~
(
1
<<
type_mant
);
}
// check if x is 0.0
if
(
x_bitwise
==
0
)
return
0
;
...
...
@@ -132,24 +143,25 @@ __host__ __device__ uint8_t run_cast_to_f8(T x, uint32_t rng)
return
(
sign
<<
(
f8_exp
+
f8_mant
))
|
(
exponent
<<
f8_mant
)
|
mantissa
;
}
template
<
typename
T
,
bool
negative_zero_nan
>
__host__
__device__
T
run_cast_from_f8
(
uint8_t
x
)
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
>
__host__
__device__
Y
run_cast_from_f8
(
X
x
)
{
// check data type
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
constexpr
bool
is_half
=
std
::
is_same
<
Y
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
Y
,
float
>::
value
;
constexpr
bool
is_f8_t
=
std
::
is_same
<
X
,
f8_t
>::
value
;
// fp8 exponent/mantissa layout
constexpr
int
f8_exp
=
4
;
constexpr
int
f8_mant
=
3
;
// fp8
/bf8
exponent/mantissa layout
constexpr
int
f8_exp
=
is_f8_t
?
4
:
5
;
constexpr
int
f8_mant
=
is_f8_t
?
3
:
2
;
// resulting type exponent/mantissa layout
constexpr
int
type_exp
=
is_half
?
5
:
8
;
constexpr
int
type_mant
=
is_half
?
10
:
23
;
// prepare the codes
constexpr
uint8_t
nan_code
=
0x80
;
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
constexpr
X
nan_code
=
0x80
;
Y
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
if
constexpr
(
is_half
)
{
constexpr
uint16_t
ihInf
=
0x7C00
;
...
...
@@ -180,7 +192,7 @@ __host__ __device__ T run_cast_from_f8(uint8_t x)
constexpr
int
exp_low_cutoff
=
(
1
<<
(
type_exp
-
1
))
-
(
1
<<
(
f8_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
typename
std
::
conditional
<
std
::
is_same
<
T
,
half_t
>::
value
,
uint16_t
,
uint32_t
>::
type
retval
;
typename
std
::
conditional
<
std
::
is_same
<
Y
,
half_t
>::
value
,
uint16_t
,
uint32_t
>::
type
retval
;
if
constexpr
(
negative_zero_nan
)
{
...
...
@@ -216,38 +228,41 @@ __host__ __device__ T run_cast_from_f8(uint8_t x)
}
retval
=
(
sign
<<
(
type_exp
+
type_mant
))
|
(
exponent
<<
type_mant
)
|
mantissa
;
return
*
(
reinterpret_cast
<
const
T
*>
(
&
retval
));
return
*
(
reinterpret_cast
<
const
Y
*>
(
&
retval
));
}
}
// namespace
template
<
typename
T
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
uint8_t
cast_to_f8
(
T
x
,
uint32_t
rng
)
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
Y
cast_to_f8
(
X
x
,
uint32_t
rng
)
{
// check datatype
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"Only half and float can be casted to f8."
);
return
run_cast_to_f8
<
T
,
negative_zero_nan
,
clip
,
stoch
>
(
x
,
rng
);
// check datatypes
constexpr
bool
is_half
=
std
::
is_same
<
X
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
X
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"Only half and float can be casted."
);
constexpr
bool
is_f8
=
std
::
is_same
<
Y
,
f8_t
>::
value
;
constexpr
bool
is_bf8
=
std
::
is_same
<
Y
,
bf8_t
>::
value
;
static_assert
(
is_f8
||
is_bf8
,
"Casting to f8 and bf8 only is supported."
);
return
run_cast_to_f8
<
X
,
Y
,
negative_zero_nan
,
clip
,
stoch
>
(
x
,
rng
);
}
template
<
typename
T
,
bool
negative_zero_nan
>
__host__
__device__
T
cast_from_f8
(
uint8_t
x
)
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
>
__host__
__device__
Y
cast_from_f8
(
X
x
)
{
// check datatype
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
constexpr
bool
is_half
=
std
::
is_same
<
Y
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
Y
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"only half and float are supported."
);
constexpr
bool
is_f8
=
std
::
is_same
<
X
,
f8_t
>::
value
;
constexpr
bool
is_bf8
=
std
::
is_same
<
X
,
bf8_t
>::
value
;
static_assert
(
is_f8
||
is_bf8
,
"Casting to f8 and bf8 only is supported."
);
// check if x is 0.0
if
(
x
==
0
)
return
static_cast
<
T
>
(
0
);
return
static_cast
<
Y
>
(
0
);
return
run_cast_from_f8
<
T
,
negative_zero_nan
>
(
x
);
return
run_cast_from_f8
<
X
,
Y
,
negative_zero_nan
>
(
x
);
}
}
// namespace ck::utils
// f8_t constuctor impl
inline
__host__
__device__
ck
::
f8_t
::
f8_t
(
uint8_t
init
)
{
data
=
init
;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment