Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
21481b44
Commit
21481b44
authored
May 08, 2023
by
Rostyslav Geyyer
Browse files
Add fp8<->fp32 type_convert
parent
d3929cb0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
167 additions
and
0 deletions
+167
-0
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+167
-0
No files found.
include/ck/utility/data_type.hpp
View file @
21481b44
...
@@ -1049,6 +1049,173 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
...
@@ -1049,6 +1049,173 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return
type_convert
<
bhalf_t
>
(
x_fp32
);
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
}
/// Convert an fp32 value to the 8-bit float encoding used by f8_t.
///
/// The target layout is 1 sign bit, `we` = 4 exponent bits and `wm` = 3 mantissa bits.
/// With `negative_zero_nan` enabled (the fixed setting here):
///   - fp32 Inf and NaN inputs are encoded as 0x80;
///   - every value that rounds to zero (including -0.0f) is encoded as 0x00;
///   - values above the representable range are clipped to the largest finite
///     encoding of the same sign (because `clip` is true).
///
/// @param x  fp32 value to convert.
/// @return   The fp8 bit pattern, returned as f8_t.
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
{
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;

    // Rounding control. The rounding expression below refers to `stoch` and `rng`,
    // which were not declared anywhere in this function; pin them to the
    // deterministic mode (the dropped bits are added back onto the mantissa, so a
    // remainder of at least one half carries into the kept bits).
    constexpr bool stoch   = false;
    constexpr uint32_t rng = 0;

    // fp8 exponent/mantissa layout
    constexpr int we = 4;
    constexpr int wm = 3;

    // fp32 mantissa width
    constexpr int mfmt = 23;

    uint32_t _x;
    _x = *(reinterpret_cast<uint32_t*>(&x));

    uint32_t head, mantissa;
    int exponent;
    uint32_t sign;

    head     = _x & 0xFF800000;
    mantissa = _x & 0x7FFFFF;
    exponent = (head >> 23) & 0xFF;
    sign     = head >> 31;

    // Only used when negative_zero_nan is false: sign bit plus an all-ones exponent.
    uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);

    if(negative_zero_nan)
    {
        // fp32 Inf/NaN (all-ones exponent) collapse onto the single 0x80 encoding.
        if((_x & 0x7F800000) == 0x7F800000)
            return 0x80;
    }
    else
    {
        if((_x & 0x7F800000) == 0x7F800000)
            return signed_inf + (mantissa != 0 ? 1 : 0);
    }

    if(_x == 0)
        return 0;

    // Mask of the fp32 mantissa bits that do not fit into the fp8 mantissa.
    uint32_t drop_mask       = (uint32_t{1} << (mfmt - wm)) - 1;
    const int max_exp        = (1 << we) - (negative_zero_nan ? 1 : 2);
    const int exp_low_cutoff = 128 - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);

    // Rebias the exponent from fp32 to fp8.
    exponent -= exp_low_cutoff - 1;

    // A rebiased exponent of -(wm + 1) or lower means the magnitude is strictly below
    // half of the smallest fp8 subnormal, so with the deterministic rounding used here
    // it always becomes zero. Returning now also keeps the shift amounts below in range:
    // without this guard, tiny inputs (-0.0f, fp32 denormals, ...) would shift a 32-bit
    // value by 32 bits or more, which is undefined behaviour.
    if(exponent <= -(wm + 1))
        return negative_zero_nan ? 0 : (sign << 7);

    // Subnormal result: more low-order bits are dropped, so widen the mask.
    if(exponent <= 0)
        drop_mask = (uint32_t{1} << (mfmt - wm + 1 - exponent)) - 1;

    // Restore the implicit leading one, then round.
    mantissa += 1 << mfmt;
    mantissa += (stoch ? rng : mantissa) & drop_mask;
    if(mantissa >= (2 << mfmt))
    {
        // Rounding carried out of the mantissa: renormalize.
        mantissa >>= 1;
        exponent++;
    }
    mantissa >>= (mfmt - wm);

    if(exponent <= 0)
    {
        // subnormal range; represented by a subnormal float8 (exponent 0)
        // and involves loss of accuracy
        mantissa >>= 1 - exponent;
        exponent = 0;
    }
    // above range: quantize to maximum possible float of the same sign
    else if(exponent > max_exp)
    {
        if(clip)
        {
            mantissa = (1 << wm) - 1;
            exponent = max_exp;
        }
        else
        {
            return signed_inf;
        }
    }

    // Rounded to zero: with negative_zero_nan there is no negative zero (0x80 is NaN).
    if(exponent == 0 && mantissa == 0)
        return negative_zero_nan ? 0 : (sign << 7);

    // Drop the implicit leading one and assemble sign | exponent | mantissa.
    mantissa &= (1 << wm) - 1;
    return (sign << 7) | (exponent << wm) | mantissa;
}
/// Convert an 8-bit float (1 sign bit, 4 exponent bits, 3 mantissa bits) to fp32.
///
/// With `negative_zero_nan` enabled (the fixed setting here) the encoding 0x80
/// decodes to NaN and 0x00 decodes to +0.0f; there are no infinities.
///
/// @param x  fp8 bit pattern.
/// @return   The fp32 value with the same sign, exponent and mantissa.
template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
{
    constexpr bool negative_zero_nan = true;

    // fp8 exponent/mantissa widths
    constexpr int we = 4;
    constexpr int wm = 3;

    // fp32 exponent/mantissa widths
    constexpr int weo = 8;
    constexpr int wmo = 23;

    // Bit patterns of the fp32 special values, reinterpreted as floats.
    const uint32_t pos_inf_bits  = 0x7F800000;
    const uint32_t neg_inf_bits  = 0xFF800000;
    const uint32_t nan_bits      = 0x7F800001;
    const uint32_t neg_zero_bits = 0x80000000;

    const float pos_inf  = *(reinterpret_cast<const float*>(&pos_inf_bits));
    const float neg_inf  = *(reinterpret_cast<const float*>(&neg_inf_bits));
    const float quiet    = *(reinterpret_cast<const float*>(&nan_bits));
    const float neg_zero = *(reinterpret_cast<const float*>(&neg_zero_bits));

    if(x == 0)
    {
        return 0.0f;
    }

    // Split the fp8 pattern into its three fields.
    const uint32_t sign_bit = x >> 7;
    uint32_t frac           = x & ((1 << wm) - 1);
    int biased_exp          = (x & 0x7F) >> wm;

    if(!negative_zero_nan)
    {
        if(x == 0x80)
        {
            return neg_zero;
        }
        if(biased_exp == ((1 << we) - 1))
        {
            if(frac != 0)
            {
                return quiet;
            }
            return sign_bit ? neg_inf : pos_inf;
        }
    }
    else
    {
        // The pattern that would otherwise be negative zero encodes NaN.
        if(x == 0x80)
        {
            return quiet;
        }
    }

    const int exp_low_cutoff =
        (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);

    if(biased_exp == 0)
    {
        // Subnormal input. frac is non-zero here because 0x00 and 0x80 were handled
        // above. Shift left until the leading one reaches the implicit-bit position
        // (bit wm), lowering the exponent by one per extra shift, then drop that bit.
        const int shift = 1 + __builtin_clz(frac) - (32 - wm);
        frac <<= shift;
        biased_exp += 1 - shift;
        frac &= ((1 << wm) - 1);
    }

    // Rebias the exponent to fp32 and left-align the mantissa in the 23-bit field.
    biased_exp += exp_low_cutoff - 1;
    frac <<= wmo - wm;

    // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
    if(biased_exp <= 0)
    {
        frac |= 1 << wmo;
        frac >>= 1 - biased_exp;
        biased_exp = 0;
    }

    const uint32_t packed = (sign_bit << 31) | (biased_exp << 23) | frac;
    return *(reinterpret_cast<const float*>(&packed));
}
// Declare a template function for bf16 conversion using RTN
// (NOTE(review): RTN presumably means round-to-nearest; the definition is not visible here).
// Y is the destination type and X the source type of the conversion.
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment