Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
dc9c9784
Commit
dc9c9784
authored
Nov 08, 2023
by
Umang Yadav
Browse files
make bit_cast a function
parent
770b632d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
25 deletions
+19
-25
src/include/migraphx/migraphx_hip_f8_impl.hpp
src/include/migraphx/migraphx_hip_f8_impl.hpp
+19
-25
No files found.
src/include/migraphx/migraphx_hip_f8_impl.hpp
View file @
dc9c9784
...
...
@@ -41,6 +41,18 @@ struct conditional<false, T, F>
{
using
type
=
F
;
};
template
<
typename
To
,
typename
From
>
inline
constexpr
To
bit_cast
(
From
fr
)
noexcept
{
static_assert
(
sizeof
(
To
)
==
sizeof
(
From
));
#if defined(__GNUC__) and !defined(__clang__)
To
x
=
CONST_FOLD
(
*
reinterpret_cast
<
To
*>
(
&
fr
));
#else
To
x
=
__builtin_bit_cast
(
To
,
fr
);
#endif
return
x
;
}
}
// namespace detail
// #ifdef __HIP_PLATFORM_HCC__
...
...
@@ -58,17 +70,10 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t
const
int
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
typename
detail
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
x
;
#if defined(__GNUC__) and !defined(__clang__)
if
constexpr
(
sizeof
(
T
)
==
4
)
x
=
CONST_FOLD
(
*
reinterpre
t_cast
<
uint32_t
*
>
(
&
_x
)
)
;
x
=
detail
::
bi
t_cast
<
uint32_t
>
(
_x
);
else
x
=
CONST_FOLD
(
*
reinterpret_cast
<
uint16_t
*>
(
&
_x
));
#else
if
constexpr
(
sizeof
(
T
)
==
4
)
x
=
__builtin_bit_cast
(
uint32_t
,
_x
);
else
x
=
__builtin_bit_cast
(
uint16_t
,
_x
);
#endif
x
=
detail
::
bit_cast
<
uint16_t
>
(
_x
);
uint32_t
head
,
mantissa
;
int
exponent
,
bias
;
...
...
@@ -246,18 +251,11 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
uint32_t
ifNegInf
=
0xFF800000
;
uint32_t
ifNaN
=
0x7F800001
;
uint32_t
ifNeg0
=
0x80000000
;
#if defined(__GNUC__) and !defined(__clang__)
fInf
=
CONST_FOLD
(
*
(
reinterpret_cast
<
float
*>
(
&
ifInf
)));
fNegInf
=
CONST_FOLD
(
*
(
reinterpret_cast
<
float
*>
(
&
ifNegInf
)));
fNaN
=
CONST_FOLD
(
*
(
reinterpret_cast
<
float
*>
(
&
ifNaN
)));
fNeg0
=
CONST_FOLD
(
*
(
reinterpret_cast
<
float
*>
(
&
ifNeg0
)));
#else
// TODO: need to change T for half but right now it would never be called with half
fInf
=
__builtin_bit_cast
(
float
,
ifInf
);
fNegInf
=
__builtin_bit_cast
(
float
,
ifNegInf
);
fNaN
=
__builtin_bit_cast
(
float
,
ifNaN
);
fNeg0
=
__builtin_bit_cast
(
float
,
ifNeg0
);
#endif
fInf
=
detail
::
bit_cast
<
float
>
(
ifInf
);
fNegInf
=
detail
::
bit_cast
<
float
>
(
ifNegInf
);
fNaN
=
detail
::
bit_cast
<
float
>
(
ifNaN
);
fNeg0
=
detail
::
bit_cast
<
float
>
(
ifNeg0
);
if
(
x
==
0
)
return
0
;
...
...
@@ -305,11 +303,7 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
retval
=
(
sign
<<
15
)
|
(
exponent
<<
10
)
|
mantissa
;
else
retval
=
(
sign
<<
31
)
|
(
exponent
<<
23
)
|
mantissa
;
#if defined(__GNUC__) and !defined(__clang__)
return
CONST_FOLD
(
*
reinterpret_cast
<
T
*>
(
&
retval
));
#else
return
__builtin_bit_cast
(
T
,
retval
);
#endif
return
detail
::
bit_cast
<
T
>
(
retval
);
}
}
// namespace migraphx_hip_f8_impl
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment