Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9ba9ebec
Commit
9ba9ebec
authored
Aug 30, 2023
by
Rostyslav Geyyer
Browse files
Fix conversion
parent
846a6773
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
6 deletions
+23
-6
include/ck/utility/f8_utils.hpp
include/ck/utility/f8_utils.hpp
+23
-6
No files found.
include/ck/utility/f8_utils.hpp
View file @
9ba9ebec
...
@@ -86,9 +86,11 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
...
@@ -86,9 +86,11 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
if
(
is_half
&&
is_bf8_t
&&
negative_zero_nan
&&
exponent
==
0
)
if
(
is_half
&&
is_bf8_t
&&
negative_zero_nan
&&
exponent
==
0
)
{
{
exponent
+=
1
;
exponent
+=
1
;
int
sh
=
1
+
__builtin_clz
(
mantissa
)
-
(
32
-
type_mant
);
while
(
mantissa
<
(
1
<<
type_mant
))
mantissa
<<=
sh
;
{
exponent
-=
sh
;
mantissa
<<=
1
;
exponent
-=
1
;
}
mantissa
&=
~
(
1
<<
type_mant
);
mantissa
&=
~
(
1
<<
type_mant
);
}
}
...
@@ -150,6 +152,7 @@ __host__ __device__ Y run_cast_from_f8(X x)
...
@@ -150,6 +152,7 @@ __host__ __device__ Y run_cast_from_f8(X x)
constexpr
bool
is_half
=
std
::
is_same
<
Y
,
half_t
>::
value
;
constexpr
bool
is_half
=
std
::
is_same
<
Y
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
Y
,
float
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
Y
,
float
>::
value
;
constexpr
bool
is_f8_t
=
std
::
is_same
<
X
,
f8_t
>::
value
;
constexpr
bool
is_f8_t
=
std
::
is_same
<
X
,
f8_t
>::
value
;
constexpr
bool
is_bf8_t
=
std
::
is_same
<
X
,
bf8_t
>::
value
;
// fp8/bf8 exponent/mantissa layout
// fp8/bf8 exponent/mantissa layout
constexpr
int
f8_exp
=
is_f8_t
?
4
:
5
;
constexpr
int
f8_exp
=
is_f8_t
?
4
:
5
;
...
@@ -185,6 +188,10 @@ __host__ __device__ Y run_cast_from_f8(X x)
...
@@ -185,6 +188,10 @@ __host__ __device__ Y run_cast_from_f8(X x)
fNeg0
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNeg0
));
fNeg0
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNeg0
));
}
}
// check if x is 0.0
if
(
x
==
0
)
return
static_cast
<
Y
>
(
0
);
// unpack the input
// unpack the input
uint32_t
sign
=
x
>>
(
f8_exp
+
f8_mant
);
uint32_t
sign
=
x
>>
(
f8_exp
+
f8_mant
);
uint32_t
mantissa
=
x
&
((
1
<<
f8_mant
)
-
1
);
uint32_t
mantissa
=
x
&
((
1
<<
f8_mant
)
-
1
);
...
@@ -207,14 +214,24 @@ __host__ __device__ Y run_cast_from_f8(X x)
...
@@ -207,14 +214,24 @@ __host__ __device__ Y run_cast_from_f8(X x)
return
(
mantissa
==
0
)
?
(
sign
?
fNegInf
:
fInf
)
:
fNaN
;
return
(
mantissa
==
0
)
?
(
sign
?
fNegInf
:
fInf
)
:
fNaN
;
}
}
if
(
is_bf8_t
&&
is_half
&&
!
negative_zero_nan
)
{
retval
=
x
;
retval
<<=
8
;
return
*
(
reinterpret_cast
<
const
Y
*>
(
&
retval
));
}
// subnormal input
// subnormal input
if
(
exponent
==
0
)
if
(
exponent
==
0
)
{
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int
sh
=
1
+
__builtin_clz
(
mantissa
)
-
((
1
+
type_exp
+
type_mant
)
-
f8_mant
);
exponent
++
;
mantissa
<<=
sh
;
while
(
mantissa
<
(
1
<<
f8_mant
))
{
mantissa
<<=
1
;
exponent
--
;
}
mantissa
&=
((
1
<<
f8_mant
)
-
1
);
mantissa
&=
((
1
<<
f8_mant
)
-
1
);
exponent
+=
1
-
sh
;
}
}
exponent
+=
exp_low_cutoff
-
1
;
exponent
+=
exp_low_cutoff
-
1
;
mantissa
<<=
type_mant
-
f8_mant
;
mantissa
<<=
type_mant
-
f8_mant
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment