Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8319e01f
Commit
8319e01f
authored
Nov 14, 2023
by
Umang Yadav
Browse files
Fix tidy
parent
ab653aff
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
141 additions
and
134 deletions
+141
-134
src/include/migraphx/bit_cast.hpp
src/include/migraphx/bit_cast.hpp
+8
-0
src/include/migraphx/float8.hpp
src/include/migraphx/float8.hpp
+16
-13
src/include/migraphx/float8_impl.hpp
src/include/migraphx/float8_impl.hpp
+98
-121
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+1
-0
test/fp8e4m3fn.cpp
test/fp8e4m3fn.cpp
+9
-0
test/fp8e5m2.cpp
test/fp8e5m2.cpp
+9
-0
No files found.
src/include/migraphx/bit_cast.hpp
View file @
8319e01f
...
@@ -21,8 +21,13 @@
...
@@ -21,8 +21,13 @@
* ************************************************************************ */
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#define MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#define MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x))
#define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x))
namespace
migraphx
{
namespace
migraphx
{
...
@@ -39,4 +44,7 @@ inline constexpr To bit_cast(From fr) noexcept
...
@@ -39,4 +44,7 @@ inline constexpr To bit_cast(From fr) noexcept
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
#endif // MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#endif // MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
src/include/migraphx/float8.hpp
View file @
8319e01f
...
@@ -32,6 +32,7 @@
...
@@ -32,6 +32,7 @@
// We are clipping/saturation in down conversion by default. Unclipped version is not tested and
// We are clipping/saturation in down conversion by default. Unclipped version is not tested and
// shouldn't be used without having enough tests.
// shouldn't be used without having enough tests.
// logic is based on clipping table from here : https://onnx.ai/onnx/technical/float8.html#cast
// logic is based on clipping table from here : https://onnx.ai/onnx/technical/float8.html#cast
// NOLINTNEXTLINE
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
#include <cmath>
#include <cmath>
...
@@ -173,6 +174,7 @@ struct float8
...
@@ -173,6 +174,7 @@ struct float8
}
}
}
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op) \
#define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op) \
constexpr float8& operator unary_op(const float8& rhs) \
constexpr float8& operator unary_op(const float8& rhs) \
{ \
{ \
...
@@ -192,8 +194,8 @@ struct float8
...
@@ -192,8 +194,8 @@ struct float8
MIGRAPHX_FP8_UNARY_OP
(
+=
,
+
)
MIGRAPHX_FP8_UNARY_OP
(
+=
,
+
)
MIGRAPHX_FP8_UNARY_OP
(
/=
,
/
)
MIGRAPHX_FP8_UNARY_OP
(
/=
,
/
)
inline
constexpr
float8
&
operator
=
(
const
float8
&
rhs
)
=
default
;
inline
constexpr
float8
&
operator
=
(
const
float8
&
rhs
)
=
default
;
inline
constexpr
float8
&
operator
=
(
float8
&&
rhs
)
=
default
;
inline
constexpr
float8
&
operator
=
(
float8
&&
rhs
)
noexcept
=
default
;
inline
constexpr
float8
&
operator
=
(
float
rhs
)
inline
constexpr
float8
&
operator
=
(
float
rhs
)
{
{
...
@@ -203,11 +205,9 @@ struct float8
...
@@ -203,11 +205,9 @@ struct float8
inline
constexpr
bool
operator
==
(
const
float8
&
rhs
)
const
inline
constexpr
bool
operator
==
(
const
float8
&
rhs
)
const
{
{
if
(
rhs
.
is_zero
()
and
this
->
is_zero
())
if
(
rhs
.
is_nan
()
or
rhs
.
is_inf
()
or
this
->
is_nan
()
or
this
->
is_inf
())
return
true
;
else
if
(
rhs
.
is_nan
()
or
rhs
.
is_inf
()
or
this
->
is_nan
()
or
this
->
is_inf
())
return
false
;
return
false
;
else
if
(
this
->
data
==
rhs
.
data
)
else
if
((
rhs
.
is_zero
()
and
this
->
is_zero
())
or
(
this
->
data
==
rhs
.
data
)
)
return
true
;
return
true
;
return
false
;
return
false
;
}
}
...
@@ -260,7 +260,7 @@ MIGRAPHX_FP8_BINARY_OP(!=, bool)
...
@@ -260,7 +260,7 @@ MIGRAPHX_FP8_BINARY_OP(!=, bool)
template
<
migraphx
::
fp8
::
f8_type
T
>
template
<
migraphx
::
fp8
::
f8_type
T
>
inline
migraphx
::
fp8
::
float8
<
T
>
fabs
(
migraphx
::
fp8
::
float8
<
T
>
v
)
inline
migraphx
::
fp8
::
float8
<
T
>
fabs
(
migraphx
::
fp8
::
float8
<
T
>
v
)
{
{
v
.
data
=
v
.
data
&
0x7f
;
v
.
data
=
v
.
data
&
0x7f
;
// NOLINT
return
v
;
return
v
;
}
}
...
@@ -277,7 +277,7 @@ class numeric_limits<fp8e4m3fnuz>
...
@@ -277,7 +277,7 @@ class numeric_limits<fp8e4m3fnuz>
public:
public:
static
constexpr
fp8e4m3fnuz
epsilon
()
{
return
fp8e4m3fnuz
(
0x28
,
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
fp8e4m3fnuz
epsilon
()
{
return
fp8e4m3fnuz
(
0x28
,
fp8e4m3fnuz
::
from_bits
());
}
// NOLINTNEXTLINE
static
constexpr
fp8e4m3fnuz
quiet_NaN
()
{
return
fp8e4m3fnuz
(
0x80
,
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
fp8e4m3fnuz
quiet_NaN
()
{
return
fp8e4m3fnuz
(
0x80
,
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
fp8e4m3fnuz
max
()
{
return
fp8e4m3fnuz
(
0x7F
,
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
fp8e4m3fnuz
max
()
{
return
fp8e4m3fnuz
(
0x7F
,
fp8e4m3fnuz
::
from_bits
());
}
...
@@ -294,7 +294,7 @@ class numeric_limits<fp8e4m3fn>
...
@@ -294,7 +294,7 @@ class numeric_limits<fp8e4m3fn>
public:
public:
static
constexpr
fp8e4m3fn
epsilon
()
{
return
fp8e4m3fn
(
0x20
,
fp8e4m3fn
::
from_bits
());
}
static
constexpr
fp8e4m3fn
epsilon
()
{
return
fp8e4m3fn
(
0x20
,
fp8e4m3fn
::
from_bits
());
}
// NOLINTNEXTLINE
static
constexpr
fp8e4m3fn
quiet_NaN
()
{
return
fp8e4m3fn
(
0x7F
,
fp8e4m3fn
::
from_bits
());
}
static
constexpr
fp8e4m3fn
quiet_NaN
()
{
return
fp8e4m3fn
(
0x7F
,
fp8e4m3fn
::
from_bits
());
}
static
constexpr
fp8e4m3fn
max
()
{
return
fp8e4m3fn
(
0x7E
,
fp8e4m3fn
::
from_bits
());
}
static
constexpr
fp8e4m3fn
max
()
{
return
fp8e4m3fn
(
0x7E
,
fp8e4m3fn
::
from_bits
());
}
...
@@ -312,7 +312,10 @@ class numeric_limits<fp8e5m2fnuz>
...
@@ -312,7 +312,10 @@ class numeric_limits<fp8e5m2fnuz>
public:
public:
static
constexpr
fp8e5m2fnuz
epsilon
()
{
return
fp8e5m2fnuz
(
0x34
,
fp8e5m2fnuz
::
from_bits
());
}
static
constexpr
fp8e5m2fnuz
epsilon
()
{
return
fp8e5m2fnuz
(
0x34
,
fp8e5m2fnuz
::
from_bits
());
}
static
constexpr
fp8e5m2fnuz
quiet_NaN
()
{
return
fp8e5m2fnuz
(
0x80
,
fp8e5m2fnuz
::
from_bits
());
}
static
constexpr
fp8e5m2fnuz
quiet_NaN
()
// NOLINT
{
return
fp8e5m2fnuz
(
0x80
,
fp8e5m2fnuz
::
from_bits
());
}
static
constexpr
fp8e5m2fnuz
max
()
{
return
fp8e5m2fnuz
(
0x7F
,
fp8e5m2fnuz
::
from_bits
());
}
static
constexpr
fp8e5m2fnuz
max
()
{
return
fp8e5m2fnuz
(
0x7F
,
fp8e5m2fnuz
::
from_bits
());
}
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
...
@@ -328,7 +331,7 @@ class numeric_limits<fp8e5m2>
...
@@ -328,7 +331,7 @@ class numeric_limits<fp8e5m2>
public:
public:
static
constexpr
fp8e5m2
epsilon
()
{
return
fp8e5m2
(
0x34
,
fp8e5m2
::
from_bits
());
}
static
constexpr
fp8e5m2
epsilon
()
{
return
fp8e5m2
(
0x34
,
fp8e5m2
::
from_bits
());
}
// 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs
// 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs
static
constexpr
fp8e5m2
quiet_NaN
()
{
return
fp8e5m2
(
0xFF
,
fp8e5m2
::
from_bits
());
}
static
constexpr
fp8e5m2
quiet_NaN
()
{
return
fp8e5m2
(
0xFF
,
fp8e5m2
::
from_bits
());
}
// NOLINT
static
constexpr
fp8e5m2
max
()
{
return
fp8e5m2
(
0x7B
,
fp8e5m2
::
from_bits
());
}
static
constexpr
fp8e5m2
max
()
{
return
fp8e5m2
(
0x7B
,
fp8e5m2
::
from_bits
());
}
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
...
@@ -345,8 +348,8 @@ class numeric_limits<fp8e5m2>
...
@@ -345,8 +348,8 @@ class numeric_limits<fp8e5m2>
// =================================================================================================
// =================================================================================================
// define numeric limits for the new data type
// define numeric limits for the new data type
// NOLINTBEGIN
namespace
std
{
namespace
std
{
#define MIGRAPHX_FP8_STD_OVERLOADS(T) \
#define MIGRAPHX_FP8_STD_OVERLOADS(T) \
inline bool isfinite(T x) { return x.is_inf(); } \
inline bool isfinite(T x) { return x.is_inf(); } \
inline bool isnan(T x) { return x.is_nan(); } \
inline bool isnan(T x) { return x.is_nan(); } \
...
@@ -372,8 +375,8 @@ MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fn)
...
@@ -372,8 +375,8 @@ MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fn)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
::
fp8
::
fp8e5m2
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
::
fp8
::
fp8e5m2
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
::
fp8
::
fp8e4m3fnuz
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
::
fp8
::
fp8e4m3fnuz
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
::
fp8
::
fp8e5m2fnuz
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
::
fp8
::
fp8e5m2fnuz
)
}
// namespace std
}
// namespace std
// NOLINTEND
// =================================================================================================
// =================================================================================================
#if defined(__clang__)
#if defined(__clang__)
#pragma clang diagnostic pop
#pragma clang diagnostic pop
...
...
src/include/migraphx/float8_impl.hpp
View file @
8319e01f
...
@@ -30,111 +30,91 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -30,111 +30,91 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
fp8
{
namespace
fp8
{
namespace
impl
{
namespace
impl
{
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
,
bool
clip
>
// NOLINTBEGIN
constexpr
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
=
false
,
uint32_t
rng
=
0
)
template
<
uint32_t
Wm
,
uint32_t
We
,
typename
T
,
bool
NegativeZeroNan
,
bool
Clip
>
constexpr
uint8_t
cast_to_f8
(
T
f_x
,
bool
stoch
=
false
,
uint32_t
rng
=
0
)
{
{
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
// half is not supported for now
// half is not supported for now
constexpr
bool
is_half
=
false
;
constexpr
bool
is_half
=
false
;
static_assert
(
w
m
+
w
e
==
7
,
"
w
m+
w
e==7"
);
static_assert
(
W
m
+
W
e
==
7
,
"
W
m+
W
e==7"
);
static_assert
(
is_float
or
is_half
,
"Only float can be cast to f8"
);
static_assert
(
is_float
or
is_half
,
"Only float can be cast to f8"
);
const
int
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
const
u
int
32_t
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
typename
std
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
x
;
typename
std
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
x
;
if
constexpr
(
sizeof
(
T
)
==
4
)
if
constexpr
(
sizeof
(
T
)
==
4
)
x
=
migraphx
::
bit_cast
<
uint32_t
>
(
_x
);
x
=
migraphx
::
bit_cast
<
uint32_t
>
(
f
_x
);
else
else
x
=
migraphx
::
bit_cast
<
uint16_t
>
(
_x
);
x
=
migraphx
::
bit_cast
<
uint16_t
>
(
f_x
);
uint32_t
head
,
mantissa
;
int
exponent
,
bias
;
uint32_t
sign
;
uint32_t
head
=
0
;
uint32_t
mantissa
=
0
;
int
exponent
=
0
;
uint32_t
bias
=
0
;
uint32_t
sign
=
0
;
if
constexpr
(
sizeof
(
T
)
==
4
)
if
constexpr
(
sizeof
(
T
)
==
4
)
{
{
head
=
x
&
0xFF800000
;
head
=
x
&
0xFF800000
;
// NOLINT
mantissa
=
x
&
0x7FFFFF
;
mantissa
=
x
&
0x7FFFFF
;
// NOLINT
exponent
=
(
head
>>
23
)
&
0xFF
;
exponent
=
(
head
>>
23
)
&
0xFF
;
// NOLINT
sign
=
head
>>
31
;
sign
=
head
>>
31
;
// NOLINT
bias
=
127
;
bias
=
127
;
}
}
else
else
{
{
head
=
x
&
0xFC00
;
head
=
x
&
0xFC00
;
// NOLINT
mantissa
=
x
&
0x3FF
;
mantissa
=
x
&
0x3FF
;
// NOLINT
exponent
=
(
head
>>
10
)
&
0x1F
;
exponent
=
(
head
>>
10
)
&
0x1F
;
// NOLINT
sign
=
head
>>
15
;
sign
=
head
>>
15
;
// NOLINT
bias
=
15
;
bias
=
15
;
}
}
uint32_t
signed_inf
=
(
sign
<<
7
)
+
(((
1
<<
w
e
)
-
1
)
<<
w
m
);
uint32_t
signed_inf
=
(
sign
<<
7
)
+
(((
1
<<
W
e
)
-
1
)
<<
W
m
);
// NOLINT
uint32_t
signed_all_ones
=
(
sign
<<
7
)
+
((((
1
<<
w
e
)
-
1
)
<<
w
m
)
+
((
1
<<
w
m
)
-
1
));
uint32_t
signed_all_ones
=
(
sign
<<
7
)
+
((((
1
<<
W
e
)
-
1
)
<<
W
m
)
+
((
1
<<
W
m
)
-
1
));
// NOLINT
// Calcualte maximum singed value FLT_MAX, FLT_MIN
// Calcualte maximum singed value FLT_MAX, FLT_MIN
uint32_t
signed_max
=
signed_all_ones
;
uint32_t
signed_max
=
signed_all_ones
;
if
(
not
negative_zero_nan
)
if
(
not
NegativeZeroNan
)
{
signed_max
=
(
Wm
==
2
)
?
(
signed_max
-
4
)
:
(
signed_max
-
1
);
signed_max
=
(
wm
==
2
)
?
(
signed_max
-
4
)
:
(
signed_max
-
1
);
}
// Deal with inf and NaNs
// Deal with inf and NaNs
if
(
n
egative
_z
ero
_n
an
)
// For the FNUZ cases, it is simple just return NaNs
if
(
N
egative
Z
ero
N
an
)
// For the FNUZ cases, it is simple just return NaNs
{
{
if
(
sizeof
(
T
)
==
4
)
if
((
sizeof
(
T
)
==
4
and
((
x
&
0x7F800000
)
==
0x7F800000
))
or
// NOLINT
{
(
sizeof
(
T
)
==
2
and
((
x
&
0x7C00
)
==
0x7C00
)))
// NOLINT
if
((
x
&
0x7F800000
)
==
0x7F800000
)
return
0x80
;
return
0x80
;
}
else
{
if
((
x
&
0x7C00
)
==
0x7C00
)
return
0x80
;
}
}
}
else
else
{
{
// calculate most common NaN mantissa for FP8, which is all Ones in binary
// calculate most common NaN mantissa for FP8, which is all Ones in binary
uint32_t
nan_mantissa
=
1
;
uint32_t
nan_mantissa
=
1
;
for
(
auto
i
=
1
;
i
<
w
m
;
++
i
)
for
(
auto
i
=
1
;
i
<
W
m
;
++
i
)
{
{
nan_mantissa
|=
(
nan_mantissa
<<
1
);
nan_mantissa
|=
(
nan_mantissa
<<
1
);
// NOLINT
}
}
if
((
sizeof
(
T
)
==
4
and
((
x
&
0x7F800000
)
==
0x7F800000
))
or
if
((
sizeof
(
T
)
==
4
and
((
x
&
0x7F800000
)
==
0x7F800000
))
or
// NOLINT
(
sizeof
(
T
)
==
2
and
((
x
&
0x7C00
)
==
0x7C00
)))
(
sizeof
(
T
)
==
2
and
((
x
&
0x7C00
)
==
0x7C00
)))
// NOLINT
{
{
// infinity
// infinity
if
(
mantissa
==
0
)
if
(
mantissa
==
0
)
{
{
if
(
sign
==
0
)
if
(
sign
==
0
)
{
return
(
Wm
==
2
)
?
0x7B
:
0x7E
;
return
(
wm
==
2
)
?
0x7B
:
0x7E
;
}
else
else
{
return
(
Wm
==
2
)
?
0xFB
:
0xFE
;
return
(
wm
==
2
)
?
0xFB
:
0xFE
;
}
}
}
else
else
// NaNs
{
// NaNs
return
signed_inf
+
nan_mantissa
;
return
signed_inf
+
nan_mantissa
;
}
}
}
}
}
// handle positive zero
// handle positive zero
if
(
x
==
0
)
if
(
x
==
0
)
return
0
;
return
0
;
// handle negative zero
// handle negative zero
if
((
sizeof
(
T
)
==
4
and
x
==
0x80000000
)
or
(
sizeof
(
T
)
==
2
and
x
==
0x8000
))
else
if
((
sizeof
(
T
)
==
4
and
x
==
0x80000000
)
or
(
sizeof
(
T
)
==
2
and
x
==
0x8000
))
{
{
if
(
negative_zero_nan
)
// For FNUZ types neg zero is just positive zero
return
NegativeZeroNan
?
0
:
0x80
;
// For FNUZ types neg zero is just positive zero
{
return
0
;
}
else
{
return
0x80
;
}
}
}
/* First need to check if it is normal or denorm as there is a difference of implict 1
/* First need to check if it is normal or denorm as there is a difference of implict 1
...
@@ -144,13 +124,15 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0)
...
@@ -144,13 +124,15 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0)
exponent and mantissa again*/
exponent and mantissa again*/
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits
const
int
f8_bias
=
(
1
<<
(
w
e
-
1
))
-
1
+
(
n
egative
_z
ero
_n
an
?
1
:
0
);
const
int
f8_bias
=
(
1
<<
(
W
e
-
1
u
))
-
1
+
(
N
egative
Z
ero
N
an
?
1
:
0
);
// NOLINT
const
int
f8_denormal_act_exponent
=
1
-
f8_bias
;
// actual exponent of f8 denormal
const
int
f8_denormal_act_exponent
=
1
-
f8_bias
;
// actual exponent of f8 denormal
/* act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
/* act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
f8_exponent is the converted f8 exponent with bias encoding
f8_exponent is the converted f8 exponent with bias encoding
exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
the difference needs to be adjusted and mantissa shifted*/
the difference needs to be adjusted and mantissa shifted*/
int
act_exponent
,
f8_exponent
,
exponent_diff
;
int
act_exponent
=
0
;
int
f8_exponent
=
0
;
int
exponent_diff
=
0
;
if
(
exponent
==
0
)
if
(
exponent
==
0
)
{
// fp32/fp16 is in denormal.
{
// fp32/fp16 is in denormal.
...
@@ -182,11 +164,11 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0)
...
@@ -182,11 +164,11 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0)
0
;
// exponent_diff=0 does not mean there is no difference for this case,
0
;
// exponent_diff=0 does not mean there is no difference for this case,
// act_exponent could be larger. Just that it does not need shift mantissa
// act_exponent could be larger. Just that it does not need shift mantissa
}
}
mantissa
+=
(
1
<<
mfmt
);
// Add the implicit 1 into mantissa
mantissa
+=
(
1
u
<<
mfmt
);
// Add the implicit 1 into mantissa
}
}
// NOLINTNEXTLINE
bool
midpoint
=
(
mantissa
&
((
1
<<
(
mfmt
-
w
m
+
exponent_diff
))
-
1
))
==
bool
midpoint
=
(
mantissa
&
((
1
<<
(
mfmt
-
W
m
+
exponent_diff
))
-
1
))
==
(
1
<<
(
mfmt
-
w
m
+
exponent_diff
-
1
));
(
1
<<
(
mfmt
-
W
m
+
exponent_diff
-
1
));
// NOLINT
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
shift right as shift right could rip off some residual part and make something not midpoint look
shift right as shift right could rip off some residual part and make something not midpoint look
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
...
@@ -194,64 +176,58 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0)
...
@@ -194,64 +176,58 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0)
*/
*/
if
(
exponent_diff
>
0
)
if
(
exponent_diff
>
0
)
mantissa
>>=
exponent_diff
;
mantissa
>>=
exponent_diff
;
// NOLINT
else
if
(
exponent_diff
==
-
1
)
else
if
(
exponent_diff
==
-
1
)
mantissa
<<=
-
exponent_diff
;
mantissa
<<=
-
exponent_diff
;
// NOLINT
bool
implicit_one
=
mantissa
&
(
1
<<
mfmt
);
bool
implicit_one
=
mantissa
&
(
1
<<
mfmt
);
// NOLINT
// if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
// if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
f8_exponent
=
f8_exponent
=
(
act_exponent
+
exponent_diff
)
/*actual f8 exponent*/
+
f8_bias
-
(
implicit_one
?
0
:
1
);
(
act_exponent
+
exponent_diff
)
/*actual f8 exponent*/
+
f8_bias
-
(
implicit_one
?
0
:
1
);
// Now we have the exponent and mantissa adjusted
// Now we have the exponent and mantissa adjusted
uint32_t
drop_mask
=
(
1
<<
(
mfmt
-
w
m
))
-
1
;
uint32_t
drop_mask
=
(
1
u
<<
(
mfmt
-
W
m
))
-
1
;
// NOLINT
bool
odd
=
bool
odd
=
mantissa
&
(
1
<<
(
mfmt
-
wm
));
// if the least significant bit that is not truncated is 1
mantissa
&
(
1u
<<
(
mfmt
-
Wm
));
// if the least significant bit that is not truncated is 1
mantissa
+=
(
stoch
?
rng
:
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1
)
:
mantissa
))
&
drop_mask
;
mantissa
+=
(
stoch
?
rng
:
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1
)
:
mantissa
))
&
// NOLINT
drop_mask
;
// NOLINT
// Now we deal with overflow
// Now we deal with overflow
if
(
f8_exponent
==
0
)
if
(
f8_exponent
==
0
and
((
1
<<
mfmt
)
&
mantissa
))
// NOLINT
{
{
if
((
1
<<
mfmt
)
&
mantissa
)
f8_exponent
=
1
;
// denormal overflow to become normal, promote exponent
{
f8_exponent
=
1
;
// denormal overflow to become normal, promote exponent
}
}
}
else
else
if
((
1
<<
(
mfmt
+
1
))
&
mantissa
)
// NOLINT
{
{
if
((
1
<<
(
mfmt
+
1
))
&
mantissa
)
mantissa
>>=
1
;
// NOLINT
{
f8_exponent
++
;
mantissa
>>=
1
;
f8_exponent
++
;
}
}
}
mantissa
>>=
(
mfmt
-
w
m
);
mantissa
>>=
(
mfmt
-
W
m
);
// NOLINT
// above range: quantize to maximum possible float of the same sign
// above range: quantize to maximum possible float of the same sign
const
int
max_exp
=
(
1
<<
w
e
)
-
(
n
egative
_z
ero
_n
an
?
1
:
2
);
const
int
max_exp
=
(
1
<<
W
e
)
-
(
N
egative
Z
ero
N
an
?
1
:
2
);
// NOLINT
if
(
f8_exponent
>
max_exp
)
if
(
f8_exponent
>
max_exp
)
{
{
if
(
clip
)
if
(
Clip
)
{
return
signed_max
;
return
signed_max
;
}
else
else
{
{
// https://onnx.ai/onnx/technical/float8.html#cast
// https://onnx.ai/onnx/technical/float8.html#cast
if
(
n
egative
_z
ero
_n
an
)
if
(
N
egative
Z
ero
N
an
)
return
0x80
;
return
0x80
;
else
else
return
(
w
m
==
2
)
?
signed_inf
:
signed_all_ones
;
return
(
W
m
==
2
)
?
signed_inf
:
signed_all_ones
;
}
}
}
}
if
(
f8_exponent
==
0
and
mantissa
==
0
)
if
(
f8_exponent
==
0
and
mantissa
==
0
)
return
n
egative
_z
ero
_n
an
?
0
:
(
sign
<<
7
);
return
N
egative
Z
ero
N
an
?
0
:
(
sign
<<
7
);
// NOLINT
mantissa
&=
(
1
<<
w
m
)
-
1
;
mantissa
&=
(
1
<<
W
m
)
-
1
;
// NOLINT
return
(
sign
<<
7
)
|
(
f8_exponent
<<
w
m
)
|
mantissa
;
return
(
sign
<<
7
)
|
(
f8_exponent
<<
W
m
)
|
mantissa
;
// NOLINT
}
}
// NOLINTEND
template
<
int
w
m
,
int
w
e
,
typename
T
,
bool
n
egative
_z
ero
_n
an
>
template
<
u
int
32_t
W
m
,
u
int
32_t
W
e
,
typename
T
,
bool
N
egative
Z
ero
N
an
>
constexpr
T
cast_from_f8
(
uint8_t
x
)
constexpr
T
cast_from_f8
(
uint8_t
x
)
{
{
// half is not supported for now
// half is not supported for now
...
@@ -261,69 +237,70 @@ constexpr T cast_from_f8(uint8_t x)
...
@@ -261,69 +237,70 @@ constexpr T cast_from_f8(uint8_t x)
constexpr
int
weo
=
is_half
?
5
:
8
;
constexpr
int
weo
=
is_half
?
5
:
8
;
constexpr
int
wmo
=
is_half
?
10
:
(
is_float
?
23
:
7
);
constexpr
int
wmo
=
is_half
?
10
:
(
is_float
?
23
:
7
);
// NOLINTNEXTLINE
T
f
I
nf
,
f
NegI
nf
,
f
NaN
,
f
N
eg0
;
T
f
_i
nf
,
f
_neg_i
nf
,
f
_nan
,
f
_n
eg0
;
if
constexpr
(
is_float
)
if
constexpr
(
is_float
)
{
{
const
uint32_t
if
I
nf
=
0x7F800000
;
const
uint32_t
if
_i
nf
=
0x7F800000
;
const
uint32_t
if
NegI
nf
=
0xFF800000
;
const
uint32_t
if
_neg_i
nf
=
0xFF800000
;
const
uint32_t
if
NaN
=
0x7F800001
;
const
uint32_t
if
_nan
=
0x7F800001
;
const
uint32_t
if
N
eg0
=
0x80000000
;
const
uint32_t
if
_n
eg0
=
0x80000000
;
f
I
nf
=
migraphx
::
bit_cast
<
float
>
(
if
I
nf
);
f
_i
nf
=
migraphx
::
bit_cast
<
float
>
(
if
_i
nf
);
f
NegI
nf
=
migraphx
::
bit_cast
<
float
>
(
if
NegI
nf
);
f
_neg_i
nf
=
migraphx
::
bit_cast
<
float
>
(
if
_neg_i
nf
);
f
NaN
=
migraphx
::
bit_cast
<
float
>
(
if
NaN
);
f
_nan
=
migraphx
::
bit_cast
<
float
>
(
if
_nan
);
f
N
eg0
=
migraphx
::
bit_cast
<
float
>
(
if
N
eg0
);
f
_n
eg0
=
migraphx
::
bit_cast
<
float
>
(
if
_n
eg0
);
}
}
if
(
x
==
0
)
if
(
x
==
0
)
return
0
;
return
0
;
uint32_t
sign
=
x
>>
7
;
uint32_t
sign
=
x
>>
7
;
// NOLINT
uint32_t
mantissa
=
x
&
((
1
<<
w
m
)
-
1
);
uint32_t
mantissa
=
x
&
((
1
<<
W
m
)
-
1
);
// NOLINT
int
exponent
=
(
x
&
0x7F
)
>>
w
m
;
int
exponent
=
(
x
&
0x7F
)
>>
W
m
;
// NOLINT
if
(
n
egative
_z
ero
_n
an
)
if
(
N
egative
Z
ero
N
an
)
{
{
if
(
x
==
0x80
)
if
(
x
==
0x80
)
return
f
NaN
;
return
f
_nan
;
}
}
else
else
{
{
if
(
x
==
0x80
)
if
(
x
==
0x80
)
return
f
N
eg0
;
return
f
_n
eg0
;
if
(
exponent
==
((
1
<<
w
e
)
-
1
)
and
w
m
==
2
)
if
(
exponent
==
((
1
<<
W
e
)
-
1
)
and
W
m
==
2
)
// NOLINT
return
(
mantissa
==
0
)
?
(
sign
?
f
NegI
nf
:
f
I
nf
)
:
f
NaN
;
return
(
mantissa
==
0
)
?
(
sign
?
f
_neg_i
nf
:
f
_i
nf
)
:
f
_nan
;
else
if
(
w
m
==
3
and
(
x
==
0x7F
or
x
==
0xFF
))
else
if
(
W
m
==
3
and
(
x
==
0x7F
or
x
==
0xFF
))
return
f
NaN
;
return
f
_nan
;
}
}
typename
std
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
retval
;
typename
std
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
retval
;
const
int
exp_low_cutoff
=
(
1
<<
(
weo
-
1
))
-
(
1
<<
(
we
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
const
int
exp_low_cutoff
=
(
1
<<
(
weo
-
1
))
-
(
1
<<
(
We
-
1
))
+
1
-
(
NegativeZeroNan
?
1
:
0
);
// NOLINT
// subnormal input
// subnormal input
if
(
exponent
==
0
)
if
(
exponent
==
0
)
{
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int
sh
=
1
+
__builtin_clz
(
mantissa
)
-
(
32
-
w
m
);
int
sh
=
1
+
__builtin_clz
(
mantissa
)
-
(
32
-
W
m
);
mantissa
<<=
sh
;
mantissa
<<=
sh
;
// NOLINT
exponent
+=
1
-
sh
;
exponent
+=
1
-
sh
;
mantissa
&=
((
1
<<
w
m
)
-
1
);
mantissa
&=
((
1
<<
W
m
)
-
1
);
// NOLINT
}
}
exponent
+=
exp_low_cutoff
-
1
;
exponent
+=
exp_low_cutoff
-
1
;
mantissa
<<=
wmo
-
w
m
;
mantissa
<<=
wmo
-
W
m
;
// NOLINT
// subnormal output (occurs when T=half,
w
e=5, negative_zero_nan=true)
// subnormal output (occurs when T=half,
W
e=5, negative_zero_nan=true)
if
(
exponent
<=
0
)
if
(
exponent
<=
0
)
{
{
mantissa
|=
1
<<
wmo
;
mantissa
|=
1
<<
wmo
;
// NOLINT
mantissa
>>=
1
-
exponent
;
mantissa
>>=
1
-
exponent
;
// NOLINT
exponent
=
0
;
exponent
=
0
;
}
}
if
(
sizeof
(
T
)
==
2
)
if
(
sizeof
(
T
)
==
2
)
retval
=
(
sign
<<
15
)
|
(
exponent
<<
10
)
|
mantissa
;
retval
=
(
sign
<<
15
)
|
(
exponent
<<
10
)
|
mantissa
;
// NOLINT
else
else
retval
=
(
sign
<<
31
)
|
(
exponent
<<
23
)
|
mantissa
;
retval
=
(
sign
<<
31
)
|
(
exponent
<<
23
)
|
mantissa
;
// NOLINT
return
migraphx
::
bit_cast
<
T
>
(
retval
);
return
migraphx
::
bit_cast
<
T
>
(
retval
);
}
}
...
...
src/targets/gpu/gemm_impl.cpp
View file @
8319e01f
...
@@ -46,6 +46,7 @@ rocblas_datatype get_type(shape::type_t type)
...
@@ -46,6 +46,7 @@ rocblas_datatype get_type(shape::type_t type)
case
shape
::
uint8_type
:
return
rocblas_datatype_u8_r
;
case
shape
::
uint8_type
:
return
rocblas_datatype_u8_r
;
case
shape
::
int32_type
:
return
rocblas_datatype_i32_r
;
case
shape
::
int32_type
:
return
rocblas_datatype_i32_r
;
case
shape
::
uint32_type
:
return
rocblas_datatype_u32_r
;
case
shape
::
uint32_type
:
return
rocblas_datatype_u32_r
;
case
shape
::
fp8e4m3fnuz_type
:
case
shape
::
tuple_type
:
case
shape
::
tuple_type
:
case
shape
::
bool_type
:
case
shape
::
bool_type
:
case
shape
::
uint16_type
:
case
shape
::
uint16_type
:
...
...
test/fp8e4m3fn.cpp
View file @
8319e01f
...
@@ -134,6 +134,15 @@ TEST_CASE(test_negative_zero)
...
@@ -134,6 +134,15 @@ TEST_CASE(test_negative_zero)
EXPECT
(
migraphx
::
float_equal
(
nzero
,
float
(
fp8_nzero
)));
EXPECT
(
migraphx
::
float_equal
(
nzero
,
float
(
fp8_nzero
)));
}
}
TEST_CASE
(
test_pos_zero_eq_neg_zero
)
{
float
nzero
=
-
0.0
;
float
pzero
=
0.0
;
migraphx
::
fp8
::
fp8e5m2
fp8_nzero
(
nzero
);
migraphx
::
fp8
::
fp8e5m2
fp8_pzero
(
pzero
);
EXPECT
(
fp8_nzero
==
fp8_pzero
);
}
TEST_CASE
(
test_nan_1
)
TEST_CASE
(
test_nan_1
)
{
{
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
...
...
test/fp8e5m2.cpp
View file @
8319e01f
...
@@ -331,6 +331,15 @@ TEST_CASE(test_negative_zero)
...
@@ -331,6 +331,15 @@ TEST_CASE(test_negative_zero)
EXPECT
(
migraphx
::
float_equal
(
nzero
,
float
(
fp8_nzero
)));
EXPECT
(
migraphx
::
float_equal
(
nzero
,
float
(
fp8_nzero
)));
}
}
TEST_CASE
(
test_pos_zero_eq_neg_zero
)
{
float
nzero
=
-
0.0
;
float
pzero
=
0.0
;
migraphx
::
fp8
::
fp8e5m2
fp8_nzero
(
nzero
);
migraphx
::
fp8
::
fp8e5m2
fp8_pzero
(
pzero
);
EXPECT
(
fp8_nzero
==
fp8_pzero
);
}
TEST_CASE
(
test_nan_1
)
TEST_CASE
(
test_nan_1
)
{
{
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment