Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4e9d51f0
"benchmarks/git@developer.sourcefind.cn:SIYIXNI/vllm.git" did not exist on "e0c6f556e85053059c74ab6b5cee396baf3b4316"
Commit
4e9d51f0
authored
Nov 10, 2023
by
Umang Yadav
Browse files
Working FNUZ and FN
parent
d9f11e31
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
335 additions
and
82 deletions
+335
-82
src/include/migraphx/migraphx_f8_impl.hpp
src/include/migraphx/migraphx_f8_impl.hpp
+60
-10
src/include/migraphx/migraphx_float8.hpp
src/include/migraphx/migraphx_float8.hpp
+71
-62
test/fp8e4m3fn.cpp
test/fp8e4m3fn.cpp
+202
-0
test/fp8e4m3fnuz.cpp
test/fp8e4m3fnuz.cpp
+2
-10
No files found.
src/include/migraphx/migraphx_f8_impl.hpp
View file @
4e9d51f0
...
@@ -86,6 +86,11 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
...
@@ -86,6 +86,11 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
}
}
uint32_t
signed_inf
=
(
sign
<<
7
)
+
(((
1
<<
we
)
-
1
)
<<
wm
);
uint32_t
signed_inf
=
(
sign
<<
7
)
+
(((
1
<<
we
)
-
1
)
<<
wm
);
uint32_t
signed_max
=
(
sign
<<
7
)
+
((((
1
<<
we
)
-
1
)
<<
wm
)
+
((
1
<<
wm
)
-
1
));
if
(
not
negative_zero_nan
)
{
signed_max
=
(
wm
==
2
)
?
(
signed_max
-
4
)
:
(
signed_max
-
1
);
}
// Deal with inf and NaNs
// Deal with inf and NaNs
if
(
negative_zero_nan
)
if
(
negative_zero_nan
)
...
@@ -103,15 +108,50 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
...
@@ -103,15 +108,50 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
}
}
else
else
{
{
if
(
sizeof
(
T
)
==
4
)
// calculate most common NaN mantissa for FP8, which is all Ones in binary
uint32_t
nan_mantissa
=
1
;
for
(
auto
i
=
1
;
i
<
wm
;
++
i
)
{
{
if
((
x
&
0x7F800000
)
==
0x7F800000
)
nan_mantissa
|=
(
nan_mantissa
<<
1
);
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
// cppcheck-suppress InvertedLogic
}
// TODO: abstract duplicate branches
if
(
sizeof
(
T
)
==
4
and
((
x
&
0x7F800000
)
==
0x7F800000
))
{
// infinity
if
(
mantissa
==
0
)
{
if
(
sign
==
0
)
{
return
(
wm
==
2
)
?
0x7B
:
0x7E
;
}
}
else
else
{
{
if
((
x
&
0x7C00
)
==
0x7C00
)
return
(
wm
==
2
)
?
0xFB
:
0xFE
;
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
// cppcheck-suppress InvertedLogic
}
}
else
{
// NaNs
return
signed_inf
+
nan_mantissa
;
}
}
else
if
(
sizeof
(
T
)
==
2
and
((
x
&
0x7C00
)
==
0x7C00
))
{
// infinity
if
(
mantissa
==
0
)
{
if
(
sign
==
0
)
{
return
(
wm
==
2
)
?
0x7B
:
0x7E
;
}
else
{
return
(
wm
==
2
)
?
0xFB
:
0xFE
;
}
}
else
{
// NaNs
return
signed_inf
+
nan_mantissa
;
}
}
}
}
}
// handle positive zero
// handle positive zero
...
@@ -222,16 +262,24 @@ this case, the fp16 mantissa should be shift left by 1 */
...
@@ -222,16 +262,24 @@ this case, the fp16 mantissa should be shift left by 1 */
// above range: quantize to maximum possible float of the same sign
// above range: quantize to maximum possible float of the same sign
const
int
max_exp
=
(
1
<<
we
)
-
(
negative_zero_nan
?
1
:
2
);
const
int
max_exp
=
(
1
<<
we
)
-
(
negative_zero_nan
?
1
:
2
);
// TODO: this is ugly, need better way to handle out of range values
if
(
f8_exponent
>
max_exp
)
if
(
f8_exponent
>
max_exp
)
{
{
if
(
clip
)
if
(
clip
)
{
{
mantissa
=
(
1
<<
wm
)
-
1
;
return
signed_max
;
f8_exponent
=
max_exp
;
}
}
else
else
{
{
return
signed_inf
;
if
(
negative_zero_nan
)
{
return
0x80
;
}
else
{
uint32_t
tmp_signed_max
=
(
sign
<<
7
)
+
((((
1
<<
we
)
-
1
)
<<
wm
)
+
((
1
<<
wm
)
-
1
));
return
(
wm
==
2
)
?
signed_inf
:
tmp_signed_max
;
}
}
}
}
}
...
@@ -273,8 +321,10 @@ constexpr T cast_from_f8(uint8_t x)
...
@@ -273,8 +321,10 @@ constexpr T cast_from_f8(uint8_t x)
{
{
if
(
x
==
0x80
)
if
(
x
==
0x80
)
return
fNeg0
;
return
fNeg0
;
if
(
exponent
==
((
1
<<
we
)
-
1
))
if
(
exponent
==
((
1
<<
we
)
-
1
)
and
wm
==
2
)
return
(
mantissa
==
0
)
?
(
sign
?
fNegInf
:
fInf
)
:
fNaN
;
return
(
mantissa
==
0
)
?
(
sign
?
fNegInf
:
fInf
)
:
fNaN
;
else
if
(
wm
==
3
and
(
x
==
0x7F
or
x
==
0xFF
))
return
fNaN
;
}
}
typename
detail
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
retval
;
typename
detail
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
retval
;
...
...
src/include/migraphx/migraphx_float8.hpp
View file @
4e9d51f0
...
@@ -79,7 +79,7 @@ struct float8
...
@@ -79,7 +79,7 @@ struct float8
// default constructor
// default constructor
constexpr
float8
()
=
default
;
constexpr
float8
()
=
default
;
// default copy constructor
// default copy constructor
constexpr
float8
(
const
float8
<
T
>
&
y
)
=
default
;
constexpr
float8
(
const
float8
&
y
)
=
default
;
struct
from_bits_t
struct
from_bits_t
{
{
};
};
...
@@ -149,15 +149,12 @@ struct float8
...
@@ -149,15 +149,12 @@ struct float8
{
{
if
(
T
==
migraphx_fp8
::
f8_type
::
bf8
)
if
(
T
==
migraphx_fp8
::
f8_type
::
bf8
)
{
{
return
(
data
==
0x7
d
)
or
(
data
==
0x7
e
)
or
(
data
==
0x7
f
)
or
(
data
==
0x
fd
)
or
return
(
data
==
0x7
D
)
or
(
data
==
0x7
E
)
or
(
data
==
0x7
F
)
or
(
data
==
0x
FD
)
or
(
data
==
0x
fe
)
or
(
data
==
0x
ff
);
(
data
==
0x
FE
)
or
(
data
==
0x
FF
);
}
}
else
else
{
{
return
(
data
==
0x79
)
or
(
data
==
0x7a
)
or
(
data
==
0x7b
)
or
(
data
==
0x7c
)
or
return
(
data
==
0x7F
)
or
(
data
==
0xFF
);
(
data
==
0x7d
)
or
(
data
==
0x7e
)
or
(
data
==
0x7f
)
or
(
data
==
0xf9
)
or
(
data
==
0xfa
)
or
(
data
==
0xfb
)
or
(
data
==
0xfc
)
or
(
data
==
0xfd
)
or
(
data
==
0xfe
)
or
(
data
==
0xff
);
}
}
}
}
}
}
...
@@ -172,11 +169,12 @@ struct float8
...
@@ -172,11 +169,12 @@ struct float8
{
{
if
(
T
==
migraphx_fp8
::
f8_type
::
bf8
)
if
(
T
==
migraphx_fp8
::
f8_type
::
bf8
)
{
{
return
(
data
==
0x7
c
)
or
(
data
==
0x
fc
);
return
(
data
==
0x7
C
)
or
(
data
==
0x
FC
);
}
}
else
else
{
{
return
(
data
==
0x78
)
or
(
data
==
0xf8
);
// no infinities in e4m3fn, represent them as NaNs
return
(
data
==
0x7F
)
or
(
data
==
0xFF
);
}
}
}
}
}
}
...
@@ -211,12 +209,12 @@ struct float8
...
@@ -211,12 +209,12 @@ struct float8
inline
constexpr
bool
operator
==
(
const
float8
&
rhs
)
const
inline
constexpr
bool
operator
==
(
const
float8
&
rhs
)
const
{
{
if
((
rhs
.
is_zero
()
and
this
->
is_zero
())
or
if
(
rhs
.
is_zero
()
and
this
->
is_zero
())
(
fabs
(
rhs
-
*
this
)
<
migraphx_fp8
::
numeric_limits
<
float8
<
T
,
FNUZ
>>::
epsilon
()))
return
true
;
return
true
;
else
if
(
rhs
.
is_nan
()
or
rhs
.
is_inf
()
or
this
->
is_nan
()
or
this
->
is_inf
())
else
if
(
rhs
.
is_nan
()
or
rhs
.
is_inf
()
or
this
->
is_nan
()
or
this
->
is_inf
())
return
false
;
return
false
;
else
if
(
this
->
data
==
rhs
.
data
)
return
true
;
return
false
;
return
false
;
}
}
...
@@ -272,8 +270,6 @@ inline migraphx_fp8::float8<T> fabs(migraphx_fp8::float8<T> v)
...
@@ -272,8 +270,6 @@ inline migraphx_fp8::float8<T> fabs(migraphx_fp8::float8<T> v)
}
}
// https://onnx.ai/onnx/technical/float8.html
// https://onnx.ai/onnx/technical/float8.html
// these types are not exactly same as GraphCore's FNUZ types. GraphCore's FNUZ types assumes
// exponent bias of 8 and 16 for the FNUZ types, ONNX spec
using
fp8e4m3fn
=
float8
<
migraphx_fp8
::
f8_type
::
fp8
,
false
>
;
using
fp8e4m3fn
=
float8
<
migraphx_fp8
::
f8_type
::
fp8
,
false
>
;
using
fp8e5m2
=
float8
<
migraphx_fp8
::
f8_type
::
bf8
,
false
>
;
using
fp8e5m2
=
float8
<
migraphx_fp8
::
f8_type
::
bf8
,
false
>
;
using
fp8e4m3fnuz
=
float8
<
migraphx_fp8
::
f8_type
::
fp8
,
true
>
;
using
fp8e4m3fnuz
=
float8
<
migraphx_fp8
::
f8_type
::
fp8
,
true
>
;
...
@@ -282,6 +278,8 @@ using fp8e5m2fnuz = float8<migraphx_fp8::f8_type::bf8, true>;
...
@@ -282,6 +278,8 @@ using fp8e5m2fnuz = float8<migraphx_fp8::f8_type::bf8, true>;
template
<
>
template
<
>
class
numeric_limits
<
fp8e4m3fnuz
>
class
numeric_limits
<
fp8e4m3fnuz
>
{
{
static
constexpr
bool
has_infinity
=
false
;
public:
public:
static
constexpr
fp8e4m3fnuz
epsilon
()
{
return
fp8e4m3fnuz
(
0x28
,
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
fp8e4m3fnuz
epsilon
()
{
return
fp8e4m3fnuz
(
0x28
,
fp8e4m3fnuz
::
from_bits
());
}
...
@@ -292,13 +290,30 @@ class numeric_limits<fp8e4m3fnuz>
...
@@ -292,13 +290,30 @@ class numeric_limits<fp8e4m3fnuz>
static
constexpr
fp8e4m3fnuz
min
()
{
return
fp8e4m3fnuz
(
0x08
,
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
fp8e4m3fnuz
min
()
{
return
fp8e4m3fnuz
(
0x08
,
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
fp8e4m3fnuz
lowest
()
{
return
fp8e4m3fnuz
(
0xFF
,
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
fp8e4m3fnuz
lowest
()
{
return
fp8e4m3fnuz
(
0xFF
,
fp8e4m3fnuz
::
from_bits
());
}
};
template
<
>
class
numeric_limits
<
fp8e4m3fn
>
{
static
constexpr
bool
has_infinity
=
false
;
static
constexpr
fp8e4m3fnuz
infinity
()
{
return
fp8e4m3fnuz
(
0x80
,
fp8e4m3fnuz
::
from_bits
());
}
public:
static
constexpr
fp8e4m3fn
epsilon
()
{
return
fp8e4m3fn
(
0x20
,
fp8e4m3fn
::
from_bits
());
}
static
constexpr
fp8e4m3fn
quiet_NaN
()
{
return
fp8e4m3fn
(
0x7F
,
fp8e4m3fn
::
from_bits
());
}
static
constexpr
fp8e4m3fn
max
()
{
return
fp8e4m3fn
(
0x7E
,
fp8e4m3fn
::
from_bits
());
}
// this is min value that is not DeNorm. DeNorm min is 0x01
static
constexpr
fp8e4m3fn
min
()
{
return
fp8e4m3fn
(
0x08
,
fp8e4m3fn
::
from_bits
());
}
static
constexpr
fp8e4m3fn
lowest
()
{
return
fp8e4m3fn
(
0xFE
,
fp8e4m3fn
::
from_bits
());
}
};
};
template
<
>
template
<
>
class
numeric_limits
<
fp8e5m2fnuz
>
class
numeric_limits
<
fp8e5m2fnuz
>
{
{
static
constexpr
bool
has_infinity
=
false
;
public:
public:
static
constexpr
fp8e5m2fnuz
epsilon
()
{
return
fp8e5m2fnuz
(
0x34
,
fp8e5m2fnuz
::
from_bits
());
}
static
constexpr
fp8e5m2fnuz
epsilon
()
{
return
fp8e5m2fnuz
(
0x34
,
fp8e5m2fnuz
::
from_bits
());
}
...
@@ -310,62 +325,56 @@ class numeric_limits<fp8e5m2fnuz>
...
@@ -310,62 +325,56 @@ class numeric_limits<fp8e5m2fnuz>
static
constexpr
fp8e5m2fnuz
min
()
{
return
fp8e5m2fnuz
(
0x4
,
fp8e5m2fnuz
::
from_bits
());
}
static
constexpr
fp8e5m2fnuz
min
()
{
return
fp8e5m2fnuz
(
0x4
,
fp8e5m2fnuz
::
from_bits
());
}
static
constexpr
fp8e5m2fnuz
lowest
()
{
return
fp8e5m2fnuz
(
0xFF
,
fp8e5m2fnuz
::
from_bits
());
}
static
constexpr
fp8e5m2fnuz
lowest
()
{
return
fp8e5m2fnuz
(
0xFF
,
fp8e5m2fnuz
::
from_bits
());
}
static
constexpr
fp8e5m2fnuz
infinity
()
{
return
fp8e5m2fnuz
(
0x80
,
fp8e5m2fnuz
::
from_bits
());
}
};
};
}
// namespace migraphx_fp8
// =================================================================================================
// define numeric limits for the new data type
namespace
std
{
inline
bool
isfinite
(
migraphx_fp8
::
fp8e4m3fnuz
x
)
// NOLINT
{
return
x
.
is_inf
();
}
inline
bool
isfinite
(
migraphx_fp8
::
fp8e5m2fnuz
x
)
// NOLINT
{
return
x
.
is_inf
();
}
inline
bool
isnan
(
migraphx_fp8
::
fp8e4m3fnuz
x
)
// NOLINT
{
return
x
.
is_nan
();
}
inline
bool
isnan
(
migraphx_fp8
::
fp8e5m2fnuz
x
)
// NOLINT
{
return
x
.
is_nan
();
}
template
<
>
template
<
>
class
numeric_limits
<
migraphx_fp8
::
fp8e4m3fnuz
>
class
numeric_limits
<
fp8e5m2
>
:
public
migraphx_fp8
::
numeric_limits
<
migraphx_fp8
::
fp8e4m3fnuz
>
{
{
};
public:
static
constexpr
fp8e5m2
epsilon
()
{
return
fp8e5m2
(
0x34
,
fp8e5m2
::
from_bits
());
}
// 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs
static
constexpr
fp8e5m2
quiet_NaN
()
{
return
fp8e5m2
(
0xFF
,
fp8e5m2
::
from_bits
());
}
template
<
>
static
constexpr
fp8e5m2
max
()
{
return
fp8e5m2
(
0x7B
,
fp8e5m2
::
from_bits
());
}
class
numeric_limits
<
migraphx_fp8
::
fp8e5m2fnuz
>
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
:
public
migraphx_fp8
::
numeric_limits
<
migraphx_fp8
::
fp8e5m2fnuz
>
// this distinction. For the floating points we would end up using lowest most of the times.
{
static
constexpr
fp8e5m2
min
()
{
return
fp8e5m2
(
0x4
,
fp8e5m2
::
from_bits
());
}
};
template
<
class
T
>
static
constexpr
fp8e5m2
lowest
()
{
return
fp8e5m2
(
0xFB
,
fp8e5m2
::
from_bits
());
}
struct
common_type
<
migraphx_fp8
::
fp8e4m3fnuz
,
T
>
:
std
::
common_type
<
float
,
T
>
// NOLINT
// 7C and FC both are infinity
{
static
constexpr
fp8e5m2
infinity
()
{
return
fp8e5m2
(
0x7C
,
fp8e5m2
::
from_bits
());
}
};
};
}
// namespace migraphx_fp8
template
<
class
T
>
// =================================================================================================
struct
common_type
<
T
,
migraphx_fp8
::
fp8e4m3fnuz
>
:
std
::
common_type
<
float
,
T
>
// NOLINT
// define numeric limits for the new data type
{
namespace
std
{
};
template
<
>
#define MIGRAPHX_FP8_STD_OVERLOADS(T) \
struct
common_type
<
migraphx_fp8
::
fp8e4m3fnuz
,
migraphx_fp8
::
fp8e4m3fnuz
>
inline bool isfinite(T x) { return x.is_inf(); } \
{
inline bool isnan(T x) { return x.is_nan(); } \
using
type
=
float
;
template <> \
};
class numeric_limits<T> : public migraphx_fp8::numeric_limits<T> \
{ \
}; \
template <class U> \
struct common_type<T, U> : std::common_type<float, U> \
{ \
}; \
template <class U> \
struct common_type<U, T> : std::common_type<float, U> \
{ \
}; \
template <> \
struct common_type<T, T> \
{ \
using type = T; \
};
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx_fp8
::
fp8e4m3fn
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx_fp8
::
fp8e5m2
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx_fp8
::
fp8e4m3fnuz
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx_fp8
::
fp8e5m2fnuz
)
}
// namespace std
}
// namespace std
// =================================================================================================
// =================================================================================================
...
...
test/fp8e4m3fn.cpp
0 → 100644
View file @
4e9d51f0
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/migraphx_float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include <limits>
// Reference decoder used as the test oracle: returns the float value that the
// 8-bit pattern `input` stands for in the fp8e4m3fn lookup table below.
//
// Bit 7 of `input` is the sign; the low seven bits select the magnitude.
// Codes 0x7F and 0xFF both decode to quiet NaN, and 0x80 decodes to -0.0.
float fp8e4m3fn_to_fp32_value(uint8_t input)
{
    // Magnitudes for the codes 0x00..0x7F, eight entries per row.
    // The last entry (code 0x7F) is the NaN encoding.
    constexpr std::array<float, 128> e4m3fn_magnitudes = {
        0.0,      0.001953125, 0.00390625, 0.005859375,
        0.0078125, 0.009765625, 0.01171875, 0.013671875,
        0.015625, 0.017578125, 0.01953125, 0.021484375,
        0.0234375, 0.025390625, 0.02734375, 0.029296875,
        0.03125,  0.03515625,  0.0390625,  0.04296875,
        0.046875, 0.05078125,  0.0546875,  0.05859375,
        0.0625,   0.0703125,   0.078125,   0.0859375,
        0.09375,  0.1015625,   0.109375,   0.1171875,
        0.125,    0.140625,    0.15625,    0.171875,
        0.1875,   0.203125,    0.21875,    0.234375,
        0.25,     0.28125,     0.3125,     0.34375,
        0.375,    0.40625,     0.4375,     0.46875,
        0.5,      0.5625,      0.625,      0.6875,
        0.75,     0.8125,      0.875,      0.9375,
        1.0,      1.125,       1.25,       1.375,
        1.5,      1.625,       1.75,       1.875,
        2.0,      2.25,        2.5,        2.75,
        3.0,      3.25,        3.5,        3.75,
        4.0,      4.5,         5.0,        5.5,
        6.0,      6.5,         7.0,        7.5,
        8.0,      9.0,         10.0,       11.0,
        12.0,     13.0,        14.0,       15.0,
        16.0,     18.0,        20.0,       22.0,
        24.0,     26.0,        28.0,       30.0,
        32.0,     36.0,        40.0,       44.0,
        48.0,     52.0,        56.0,       60.0,
        64.0,     72.0,        80.0,       88.0,
        96.0,     104.0,       112.0,      120.0,
        128.0,    144.0,       160.0,      176.0,
        192.0,    208.0,       224.0,      240.0,
        256.0,    288.0,       320.0,      352.0,
        384.0,    416.0,       448.0,      std::numeric_limits<float>::quiet_NaN(),
    };

    const unsigned magnitude_code = input & 0x7Fu;
    const bool sign_bit_set       = (input & 0x80u) != 0;
    const bool is_nan_code        = (magnitude_code == 0x7Fu);
    const float magnitude         = e4m3fn_magnitudes[magnitude_code];

    // NaN is returned as stored, without negation, for both 0x7F and 0xFF.
    return (sign_bit_set and not is_nan_code) ? -magnitude : magnitude;
}
// Decodes every 8-bit pattern through fp8e4m3fn and compares the float result
// with the reference lookup table.
TEST_CASE(test_fp8_cast_to_float)
{
    // All 256 possible bit patterns, 0x00 through 0xFF.
    std::vector<uint8_t> all_codes(256);
    std::iota(all_codes.begin(), all_codes.end(), 0);

    auto decodes_like_reference = [](uint8_t code) {
        migraphx_fp8::fp8e4m3fn decoded(code, migraphx_fp8::fp8e4m3fn::from_bits());
        float actual   = float(decoded);
        float expected = fp8e4m3fn_to_fp32_value(code);
        // NaN never compares equal, so a NaN/NaN pair is accepted explicitly.
        bool both_nan = std::isnan(actual) and std::isnan(expected);
        return both_nan or migraphx::float_equal(actual, expected);
    };

    EXPECT(bool{std::all_of(all_codes.begin(), all_codes.end(), decodes_like_reference)});
}
// +0.0f converts to an fp8e4m3fn zero and converts back to an equal float.
TEST_CASE(test_positive_zero)
{
    float pos_zero = 0.0f;
    migraphx_fp8::fp8e4m3fn converted(pos_zero);
    EXPECT(converted.is_zero());

    float round_trip = float(converted);
    EXPECT(migraphx::float_equal(pos_zero, round_trip));
}
// -0.0f converts to an fp8e4m3fn zero and converts back to an equal float.
TEST_CASE(test_negative_zero)
{
    float neg_zero = -0.0f;
    migraphx_fp8::fp8e4m3fn converted(neg_zero);
    EXPECT(converted.is_zero());

    // negative zero is preserved for fp8e4m3fn
    float round_trip = float(converted);
    EXPECT(migraphx::float_equal(neg_zero, round_trip));
}
// A float quiet NaN converted to fp8e4m3fn must still report as NaN.
TEST_CASE(test_nan_1)
{
    float float_nan = std::numeric_limits<float>::quiet_NaN();
    migraphx_fp8::fp8e4m3fn converted(float_nan);

    EXPECT(converted.is_nan());
    EXPECT(std::isnan(converted));
}
// The fp8e4m3fn quiet_NaN() bit pattern, rebuilt from raw bits, must report as
// NaN both as fp8 and after conversion to float.
TEST_CASE(test_nan_2)
{
    auto limits_nan = std::numeric_limits<migraphx_fp8::fp8e4m3fn>::quiet_NaN();
    migraphx_fp8::fp8e4m3fn rebuilt(limits_nan.data, migraphx_fp8::fp8e4m3fn::from_bits());

    EXPECT(rebuilt.is_nan());
    EXPECT(std::isnan(rebuilt));
    EXPECT(std::isnan(float(rebuilt)));
}
// Positive float infinity is expected to convert to fp8e4m3fn max().
TEST_CASE(test_infinity_1)
{
    float pos_inf = std::numeric_limits<float>::infinity();

    // fp8e4m3fn has no infinity, so the value gets clipped to max()
    migraphx_fp8::fp8e4m3fn clipped(pos_inf);
    EXPECT(clipped == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::max());
}
// Negative float infinity is expected to convert to fp8e4m3fn lowest().
TEST_CASE(test_infinity_2)
{
    float neg_inf = -std::numeric_limits<float>::infinity();

    // fp8e4m3fn has no infinity, so the value gets clipped to lowest()
    migraphx_fp8::fp8e4m3fn clipped(neg_inf);
    EXPECT(bool{clipped == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::lowest()});
}
// The largest finite float is expected to convert to fp8e4m3fn max().
TEST_CASE(test_numeric_max_1)
{
    float largest_float = std::numeric_limits<float>::max();
    migraphx_fp8::fp8e4m3fn clipped(largest_float);

    EXPECT(clipped == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::max());
}
// Twice the fp8e4m3fn max() is out of range and is expected to clip to max().
TEST_CASE(test_numeric_max_2)
{
    // gets clipped to max
    float twice_max = 2 * std::numeric_limits<migraphx_fp8::fp8e4m3fn>::max();
    migraphx_fp8::fp8e4m3fn clipped(twice_max);

    EXPECT(clipped == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::max());
}
// The most negative finite float is expected to convert to fp8e4m3fn lowest().
TEST_CASE(test_numeric_lowest_1)
{
    float lowest_float = std::numeric_limits<float>::lowest();
    migraphx_fp8::fp8e4m3fn clipped(lowest_float);

    EXPECT(clipped == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::lowest());
}
// Twice the fp8e4m3fn lowest() is out of range and is expected to clip to
// lowest().
TEST_CASE(test_numeric_lowest_2)
{
    // gets clipped to lowest
    float twice_lowest = 2.0 * std::numeric_limits<migraphx_fp8::fp8e4m3fn>::lowest();
    migraphx_fp8::fp8e4m3fn clipped(twice_lowest);

    EXPECT(clipped == std::numeric_limits<migraphx_fp8::fp8e4m3fn>::lowest());
}
// Checks that lowest() is the exact negation of max() for fp8e4m3fn.
// The original body was empty, so the test passed without asserting anything.
TEST_CASE(test_max_eq_lowest)
{
    // Compare as floats so the check does not depend on fp8 arithmetic operators.
    float max_as_float    = float(std::numeric_limits<migraphx_fp8::fp8e4m3fn>::max());
    float lowest_as_float = float(std::numeric_limits<migraphx_fp8::fp8e4m3fn>::lowest());

    EXPECT(migraphx::float_equal(lowest_as_float, -max_as_float));
}
// Test binary entry point: forwards the command line to the test runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    return 0;
}
test/fp8e4m3fnuz.cpp
View file @
4e9d51f0
...
@@ -176,25 +176,17 @@ TEST_CASE(test_nan_2)
...
@@ -176,25 +176,17 @@ TEST_CASE(test_nan_2)
TEST_CASE
(
test_infinity_1
)
TEST_CASE
(
test_infinity_1
)
{
{
float
finf
=
std
::
numeric_limits
<
float
>::
infinity
();
float
finf
=
std
::
numeric_limits
<
float
>::
infinity
();
// no inf in fp8e4m3fnuz
// no inf in fp8e4m3fnuz
it gets clipped to Nans
migraphx_fp8
::
fp8e4m3fnuz
fp8_nan
(
finf
);
migraphx_fp8
::
fp8e4m3fnuz
fp8_nan
(
finf
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
}
}
TEST_CASE
(
test_infinity_2
)
TEST_CASE
(
test_infinity_2
)
{
// no inf in fp8e4m3fnuz, it gets converted to NaNs
migraphx_fp8
::
fp8e4m3fnuz
fp8_nan
(
std
::
numeric_limits
<
migraphx_fp8
::
fp8e4m3fnuz
>::
infinity
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
}
TEST_CASE
(
test_infinity_3
)
{
{
// neg inf
// neg inf
float
finf
=
-
1.0
*
std
::
numeric_limits
<
float
>::
infinity
();
float
finf
=
-
1.0
*
std
::
numeric_limits
<
float
>::
infinity
();
// no inf in fp8e4m3fnuz
// no inf in fp8e4m3fnuz
it gets clipped to NaNs
migraphx_fp8
::
fp8e4m3fnuz
fp8_nan
(
finf
);
migraphx_fp8
::
fp8e4m3fnuz
fp8_nan
(
finf
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment