Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3b0c7dc8
Commit
3b0c7dc8
authored
Nov 08, 2023
by
Umang Yadav
Browse files
Works with GCC
parent
a298c926
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
17 deletions
+44
-17
src/include/migraphx/migraphx_float8.hpp
src/include/migraphx/migraphx_float8.hpp
+5
-1
src/include/migraphx/migraphx_hip_f8_impl.hpp
src/include/migraphx/migraphx_hip_f8_impl.hpp
+39
-16
No files found.
src/include/migraphx/migraphx_float8.hpp
View file @
3b0c7dc8
...
@@ -22,12 +22,13 @@
...
@@ -22,12 +22,13 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
#if defined(__clang__) and !defined(__GNUC__)
#pragma clang diagnostic push
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wmacro-redefined"
#pragma clang diagnostic ignored "-Wmacro-redefined"
#pragma clang diagnostic ignored "-Wc++20-extensions"
#pragma clang diagnostic ignored "-Wc++20-extensions"
#endif
#if(defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__))
#if(defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__))
// need to include hip_runtime.h otherwise it complains about __host__ and __device__
// need to include hip_runtime.h otherwise it complains about __host__ and __device__
...
@@ -428,6 +429,7 @@ inline std::ostream& operator<<(std::ostream& os, const migraphx_fp8::hip_f8<T>&
...
@@ -428,6 +429,7 @@ inline std::ostream& operator<<(std::ostream& os, const migraphx_fp8::hip_f8<T>&
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
}
// TODO: these should return floats
MIGRAPHX_FP8_BINARY_OP
(
*
,
migraphx_fp8
::
hip_f8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
*
,
migraphx_fp8
::
hip_f8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
-
,
migraphx_fp8
::
hip_f8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
-
,
migraphx_fp8
::
hip_f8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
/
,
migraphx_fp8
::
hip_f8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
/
,
migraphx_fp8
::
hip_f8
<
T
>
)
...
@@ -602,5 +604,7 @@ struct common_type<migraphx_fp8::fp8e4m3fnuz, migraphx_fp8::fp8e4m3fnuz>
...
@@ -602,5 +604,7 @@ struct common_type<migraphx_fp8::fp8e4m3fnuz, migraphx_fp8::fp8e4m3fnuz>
}
// namespace std
}
// namespace std
// =================================================================================================
// =================================================================================================
#if defined(__clang__) and !defined(__GNUC__)
#pragma clang diagnostic pop
#pragma clang diagnostic pop
#endif
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
src/include/migraphx/migraphx_hip_f8_impl.hpp
View file @
3b0c7dc8
...
@@ -22,11 +22,12 @@
...
@@ -22,11 +22,12 @@
#ifndef MIGRAPHX_HIP_FP8_IMPL_HPP
#ifndef MIGRAPHX_HIP_FP8_IMPL_HPP
#define MIGRAPHX_HIP_FP8_IMPL_HPP
#define MIGRAPHX_HIP_FP8_IMPL_HPP
#if !defined(__GNUC__)
#pragma clang diagnostic push
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
#pragma clang diagnostic ignored "-Wreserved-identifier"
#pragma clang diagnostic ignored "-Wreserved-identifier"
#endif
#define CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x))
namespace
migraphx_hip_f8_impl
{
namespace
migraphx_hip_f8_impl
{
namespace
detail
{
namespace
detail
{
template
<
bool
B
,
class
T
,
class
F
>
template
<
bool
B
,
class
T
,
class
F
>
...
@@ -55,17 +56,25 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t
...
@@ -55,17 +56,25 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t
static_assert
(
wm
+
we
==
7
,
"wm+we==7"
);
static_assert
(
wm
+
we
==
7
,
"wm+we==7"
);
const
int
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
const
int
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
uint32_t
x
;
typename
detail
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
x
;
if
(
sizeof
(
T
)
==
4
)
x
=
reinterpret_cast
<
uint32_t
&>
(
_x
);
#if defined(__GNUC__) and !defined(__clang__)
if
constexpr
(
sizeof
(
T
)
==
4
)
x
=
CONST_FOLD
(
*
reinterpret_cast
<
uint32_t
*>
(
&
_x
));
else
else
x
=
reinterpret_cast
<
uint16_t
&>
(
_x
);
x
=
CONST_FOLD
(
*
reinterpret_cast
<
uint16_t
*>
(
&
_x
));
#else
if
constexpr
(
sizeof
(
T
)
==
4
)
x
=
__builtin_bit_cast
(
uint32_t
,
_x
);
else
x
=
__builtin_bit_cast
(
uint16_t
,
_x
);
#endif
uint32_t
head
,
mantissa
;
uint32_t
head
,
mantissa
;
int
exponent
,
bias
;
int
exponent
,
bias
;
uint32_t
sign
;
uint32_t
sign
;
if
(
sizeof
(
T
)
==
4
)
if
constexpr
(
sizeof
(
T
)
==
4
)
{
{
head
=
x
&
0xFF800000
;
head
=
x
&
0xFF800000
;
mantissa
=
x
&
0x7FFFFF
;
mantissa
=
x
&
0x7FFFFF
;
...
@@ -233,14 +242,22 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
...
@@ -233,14 +242,22 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
constexpr
int
wmo
=
23
;
constexpr
int
wmo
=
23
;
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
const
uint32_t
ifInf
=
0x7F800000
;
uint32_t
ifInf
=
0x7F800000
;
const
uint32_t
ifNegInf
=
0xFF800000
;
uint32_t
ifNegInf
=
0xFF800000
;
const
uint32_t
ifNaN
=
0x7F800001
;
uint32_t
ifNaN
=
0x7F800001
;
const
uint32_t
ifNeg0
=
0x80000000
;
uint32_t
ifNeg0
=
0x80000000
;
fInf
=
reinterpret_cast
<
const
float
&>
(
ifInf
);
#if defined(__GNUC__) and !defined(__clang__)
fNegInf
=
reinterpret_cast
<
const
float
&>
(
ifNegInf
);
fInf
=
CONST_FOLD
(
*
(
reinterpret_cast
<
float
*>
(
&
ifInf
)));
fNaN
=
reinterpret_cast
<
const
float
&>
(
ifNaN
);
fNegInf
=
CONST_FOLD
(
*
(
reinterpret_cast
<
float
*>
(
&
ifNegInf
)));
fNeg0
=
reinterpret_cast
<
const
float
&>
(
ifNeg0
);
fNaN
=
CONST_FOLD
(
*
(
reinterpret_cast
<
float
*>
(
&
ifNaN
)));
fNeg0
=
CONST_FOLD
(
*
(
reinterpret_cast
<
float
*>
(
&
ifNeg0
)));
#else
// TODO: need to change T for half but right now it would never be called with half
fInf
=
__builtin_bit_cast
(
float
,
ifInf
);
fNegInf
=
__builtin_bit_cast
(
float
,
ifNegInf
);
fNaN
=
__builtin_bit_cast
(
float
,
ifNaN
);
fNeg0
=
__builtin_bit_cast
(
float
,
ifNeg0
);
#endif
if
(
x
==
0
)
if
(
x
==
0
)
return
0
;
return
0
;
...
@@ -288,9 +305,15 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
...
@@ -288,9 +305,15 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x)
retval
=
(
sign
<<
15
)
|
(
exponent
<<
10
)
|
mantissa
;
retval
=
(
sign
<<
15
)
|
(
exponent
<<
10
)
|
mantissa
;
else
else
retval
=
(
sign
<<
31
)
|
(
exponent
<<
23
)
|
mantissa
;
retval
=
(
sign
<<
31
)
|
(
exponent
<<
23
)
|
mantissa
;
return
reinterpret_cast
<
const
T
&>
(
retval
);
#if defined(__GNUC__) and !defined(__clang__)
return
CONST_FOLD
(
*
reinterpret_cast
<
T
*>
(
&
retval
));
#else
return
__builtin_bit_cast
(
T
,
retval
);
#endif
}
}
}
// namespace migraphx_hip_f8_impl
}
// namespace migraphx_hip_f8_impl
#if !defined(__GNUC__)
#pragma clang diagnostic pop
#pragma clang diagnostic pop
#endif
#endif // MIGRAPHX_HIP_FP8_IMPL_HPP
#endif // MIGRAPHX_HIP_FP8_IMPL_HPP
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment