Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ab653aff
Commit
ab653aff
authored
Nov 13, 2023
by
Umang Yadav
Browse files
Review updates
parent
183db78a
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
261 additions
and
241 deletions
+261
-241
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+1
-1
src/include/migraphx/bit_cast.hpp
src/include/migraphx/bit_cast.hpp
+42
-0
src/include/migraphx/float8.hpp
src/include/migraphx/float8.hpp
+66
-72
src/include/migraphx/float8_impl.hpp
src/include/migraphx/float8_impl.hpp
+42
-48
src/include/migraphx/half.hpp
src/include/migraphx/half.hpp
+3
-3
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+2
-2
src/include/migraphx/type_traits.hpp
src/include/migraphx/type_traits.hpp
+4
-14
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+2
-2
test/CMakeLists.txt
test/CMakeLists.txt
+1
-1
test/float_equal.cpp
test/float_equal.cpp
+9
-9
test/fp8e4m3fn.cpp
test/fp8e4m3fn.cpp
+23
-23
test/fp8e4m3fnuz.cpp
test/fp8e4m3fnuz.cpp
+21
-21
test/fp8e5m2.cpp
test/fp8e5m2.cpp
+23
-23
test/fp8e5m2fnuz.cpp
test/fp8e5m2fnuz.cpp
+21
-21
tools/api/migraphx.h
tools/api/migraphx.h
+1
-1
No files found.
src/api/include/migraphx/migraphx.h
View file @
ab653aff
...
@@ -45,7 +45,7 @@
...
@@ -45,7 +45,7 @@
m(int64_type, int64_t) \
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx
_
fp8::fp8e4m3fnuz)
m(fp8e4m3fnuz_type, migraphx
::
fp8::fp8e4m3fnuz)
// clang-format on
// clang-format on
#ifdef __cplusplus
#ifdef __cplusplus
...
...
src/include/migraphx/bit_cast.hpp
0 → 100644
View file @
ab653aff
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#define MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#include <migraphx/config.hpp>
// Wraps an expression in a __builtin_constant_p conditional; both branches yield `x`, so the
// value is unchanged. Only used by the legacy-GCC path of bit_cast below.
// NOTE(review): presumably this is the GCC idiom that lets the compiler fold an expression
// that is not formally a constant expression (here, a reinterpret_cast) — confirm.
#define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x))

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/// Reinterprets the object representation of `fr` as a value of type `To`.
///
/// @tparam To   Destination type; must have the same size as `From`.
/// @tparam From Source type (deduced from the argument).
/// @param fr    Value whose bits are copied into the result.
/// @return      A `To` holding the same bits as `fr`.
///
/// Compilers that provide `__builtin_bit_cast` (clang, and GCC from version 11 on) use the
/// builtin, which is well-defined and usable in constant expressions. Only older GCC releases
/// fall back to the pointer-cast form, which type-puns through `reinterpret_cast`.
template <typename To, typename From>
inline constexpr To bit_cast(From fr) noexcept
{
    static_assert(sizeof(To) == sizeof(From),
                  "bit_cast requires To and From to have the same size");
#if defined(__GNUC__) and !defined(__clang__) and (__GNUC__ < 11)
    // Legacy GCC has no __builtin_bit_cast; keep the original pointer-cast fallback there.
    return MIGRAPHX_CONST_FOLD(*reinterpret_cast<To*>(&fr));
#else
    return __builtin_bit_cast(To, fr);
#endif
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
src/include/migraphx/
migraphx_
float8.hpp
→
src/include/migraphx/float8.hpp
View file @
ab653aff
...
@@ -44,20 +44,12 @@
...
@@ -44,20 +44,12 @@
#include <iostream>
#include <iostream>
#include <string>
#include <string>
#include <utility>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/float8_impl.hpp>
namespace
migraphx_f8_impl
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
,
bool
clip
>
namespace
fp8
{
constexpr
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
=
false
,
uint32_t
rng
=
0
);
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
>
constexpr
T
cast_from_f8
(
uint8_t
x
);
}
// namespace migraphx_f8_impl
#include <migraphx/migraphx_f8_impl.hpp>
namespace
migraphx_fp8
{
enum
class
migraphx_f8_rounding_mode
enum
class
migraphx_f8_rounding_mode
{
{
...
@@ -74,7 +66,7 @@ enum class f8_type
...
@@ -74,7 +66,7 @@ enum class f8_type
template
<
typename
T
,
bool
FNUZ
=
true
>
template
<
typename
T
,
bool
FNUZ
=
true
>
class
numeric_limits
;
class
numeric_limits
;
template
<
migraphx
_
fp8
::
f8_type
T
=
migraphx
_
fp8
::
f8_type
::
fp8
,
bool
FNUZ
=
true
>
template
<
migraphx
::
fp8
::
f8_type
T
=
migraphx
::
fp8
::
f8_type
::
fp8
,
bool
FNUZ
=
true
>
struct
float8
struct
float8
{
{
uint8_t
data
=
0x00
;
uint8_t
data
=
0x00
;
...
@@ -90,43 +82,43 @@ struct float8
...
@@ -90,43 +82,43 @@ struct float8
explicit
constexpr
float8
(
uint8_t
bits
,
from_bits_t
)
:
data
(
bits
)
{}
explicit
constexpr
float8
(
uint8_t
bits
,
from_bits_t
)
:
data
(
bits
)
{}
explicit
constexpr
float8
(
float
v
,
explicit
constexpr
float8
(
float
v
,
migraphx
_
fp8
::
migraphx_f8_rounding_mode
rm
=
migraphx
::
fp8
::
migraphx_f8_rounding_mode
rm
=
migraphx
_
fp8
::
migraphx_f8_rounding_mode
::
standard
,
migraphx
::
fp8
::
migraphx_f8_rounding_mode
::
standard
,
uint32_t
rng
=
0
)
uint32_t
rng
=
0
)
{
{
if
constexpr
(
T
==
migraphx
_
fp8
::
f8_type
::
fp8
)
if
constexpr
(
T
==
migraphx
::
fp8
::
f8_type
::
fp8
)
{
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
_f8_
impl
::
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
cast_to_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
,
(
rm
==
migraphx
_
fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
v
,
(
rm
==
migraphx
::
fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
_f8_
impl
::
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
cast_to_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
v
,
(
rm
==
migraphx
_
fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
v
,
(
rm
==
migraphx
::
fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
}
}
else
else
{
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
_f8_
impl
::
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
cast_to_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
,
(
rm
==
migraphx
_
fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
v
,
(
rm
==
migraphx
::
fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
_f8_
impl
::
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
cast_to_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
v
,
(
rm
==
migraphx
_
fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
v
,
(
rm
==
migraphx
::
fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
#endif // rocblas_F8_downcast_clipping}
#endif // rocblas_F8_downcast_clipping}
}
}
}
}
inline
constexpr
operator
float
()
const
inline
constexpr
operator
float
()
const
{
{
if
constexpr
(
T
==
migraphx
_
fp8
::
f8_type
::
fp8
)
if
constexpr
(
T
==
migraphx
::
fp8
::
f8_type
::
fp8
)
{
{
return
migraphx
_f8_
impl
::
cast_from_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
>
(
data
);
return
migraphx
::
fp8
::
impl
::
cast_from_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
>
(
data
);
}
// else
}
// else
return
migraphx
_f8_
impl
::
cast_from_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
>
(
data
);
return
migraphx
::
fp8
::
impl
::
cast_from_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
>
(
data
);
}
}
inline
constexpr
bool
is_zero
()
const
inline
constexpr
bool
is_zero
()
const
...
@@ -149,7 +141,7 @@ struct float8
...
@@ -149,7 +141,7 @@ struct float8
}
}
else
else
{
{
if
(
T
==
migraphx
_
fp8
::
f8_type
::
bf8
)
if
(
T
==
migraphx
::
fp8
::
f8_type
::
bf8
)
{
{
return
(
data
==
0x7D
)
or
(
data
==
0x7E
)
or
(
data
==
0x7F
)
or
(
data
==
0xFD
)
or
return
(
data
==
0x7D
)
or
(
data
==
0x7E
)
or
(
data
==
0x7F
)
or
(
data
==
0xFD
)
or
(
data
==
0xFE
)
or
(
data
==
0xFF
);
(
data
==
0xFE
)
or
(
data
==
0xFF
);
...
@@ -169,7 +161,7 @@ struct float8
...
@@ -169,7 +161,7 @@ struct float8
}
}
else
else
{
{
if
(
T
==
migraphx
_
fp8
::
f8_type
::
bf8
)
if
(
T
==
migraphx
::
fp8
::
f8_type
::
bf8
)
{
{
return
(
data
==
0x7C
)
or
(
data
==
0xFC
);
return
(
data
==
0x7C
)
or
(
data
==
0xFC
);
}
}
...
@@ -236,26 +228,26 @@ struct float8
...
@@ -236,26 +228,26 @@ struct float8
};
};
// Special operator overloading
// Special operator overloading
template
<
migraphx
_
fp8
::
f8_type
T
>
template
<
migraphx
::
fp8
::
f8_type
T
>
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
migraphx
_
fp8
::
float8
<
T
>&
rhs
)
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
migraphx
::
fp8
::
float8
<
T
>&
rhs
)
{
{
return
os
<<
static_cast
<
float
>
(
rhs
);
return
os
<<
static_cast
<
float
>
(
rhs
);
}
}
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \
template <migraphx
_
fp8::f8_type T> \
template <migraphx
::
fp8::f8_type T> \
inline constexpr U operator binary_op(const migraphx
_
fp8::float8<T>& lhs, \
inline constexpr U operator binary_op(const migraphx
::
fp8::float8<T>& lhs, \
const migraphx
_
fp8::float8<T>& rhs) \
const migraphx
::
fp8::float8<T>& rhs) \
{ \
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
}
// TODO: these should return floats
// TODO: these should return floats
MIGRAPHX_FP8_BINARY_OP
(
*
,
migraphx
_
fp8
::
float8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
*
,
migraphx
::
fp8
::
float8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
-
,
migraphx
_
fp8
::
float8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
-
,
migraphx
::
fp8
::
float8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
/
,
migraphx
_
fp8
::
float8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
/
,
migraphx
::
fp8
::
float8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
+
,
migraphx
_
fp8
::
float8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
+
,
migraphx
::
fp8
::
float8
<
T
>
)
// TODO: Comparison ops shouldn't convert to float, need to check if need to take care of rounding
// TODO: Comparison ops shouldn't convert to float, need to check if need to take care of rounding
// effects.
// effects.
MIGRAPHX_FP8_BINARY_OP
(
==
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
==
,
bool
)
...
@@ -265,18 +257,18 @@ MIGRAPHX_FP8_BINARY_OP(>, bool)
...
@@ -265,18 +257,18 @@ MIGRAPHX_FP8_BINARY_OP(>, bool)
MIGRAPHX_FP8_BINARY_OP
(
<
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
<
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
!=
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
!=
,
bool
)
template
<
migraphx
_
fp8
::
f8_type
T
>
template
<
migraphx
::
fp8
::
f8_type
T
>
inline
migraphx
_
fp8
::
float8
<
T
>
fabs
(
migraphx
_
fp8
::
float8
<
T
>
v
)
inline
migraphx
::
fp8
::
float8
<
T
>
fabs
(
migraphx
::
fp8
::
float8
<
T
>
v
)
{
{
v
.
data
=
v
.
data
&
0x7f
;
v
.
data
=
v
.
data
&
0x7f
;
return
v
;
return
v
;
}
}
// https://onnx.ai/onnx/technical/float8.html
// https://onnx.ai/onnx/technical/float8.html
using
fp8e4m3fn
=
float8
<
migraphx
_
fp8
::
f8_type
::
fp8
,
false
>
;
using
fp8e4m3fn
=
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
,
false
>
;
using
fp8e5m2
=
float8
<
migraphx
_
fp8
::
f8_type
::
bf8
,
false
>
;
using
fp8e5m2
=
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
,
false
>
;
using
fp8e4m3fnuz
=
float8
<
migraphx
_
fp8
::
f8_type
::
fp8
,
true
>
;
using
fp8e4m3fnuz
=
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
,
true
>
;
using
fp8e5m2fnuz
=
float8
<
migraphx
_
fp8
::
f8_type
::
bf8
,
true
>
;
using
fp8e5m2fnuz
=
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
,
true
>
;
template
<
>
template
<
>
class
numeric_limits
<
fp8e4m3fnuz
>
class
numeric_limits
<
fp8e4m3fnuz
>
...
@@ -347,7 +339,9 @@ class numeric_limits<fp8e5m2>
...
@@ -347,7 +339,9 @@ class numeric_limits<fp8e5m2>
// 7C and FC both are infinity
// 7C and FC both are infinity
static
constexpr
fp8e5m2
infinity
()
{
return
fp8e5m2
(
0x7C
,
fp8e5m2
::
from_bits
());
}
static
constexpr
fp8e5m2
infinity
()
{
return
fp8e5m2
(
0x7C
,
fp8e5m2
::
from_bits
());
}
};
};
}
// namespace migraphx_fp8
}
// namespace fp8
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
// =================================================================================================
// =================================================================================================
// define numeric limits for the new data type
// define numeric limits for the new data type
...
@@ -357,7 +351,7 @@ namespace std {
...
@@ -357,7 +351,7 @@ namespace std {
inline bool isfinite(T x) { return x.is_inf(); } \
inline bool isfinite(T x) { return x.is_inf(); } \
inline bool isnan(T x) { return x.is_nan(); } \
inline bool isnan(T x) { return x.is_nan(); } \
template <> \
template <> \
class numeric_limits<T> : public migraphx
_
fp8::numeric_limits<T> \
class numeric_limits<T> : public migraphx
::
fp8::numeric_limits<T> \
{ \
{ \
}; \
}; \
template <class U> \
template <class U> \
...
@@ -374,10 +368,10 @@ namespace std {
...
@@ -374,10 +368,10 @@ namespace std {
using type = T; \
using type = T; \
};
};
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
_
fp8
::
fp8e4m3fn
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
::
fp8
::
fp8e4m3fn
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
_
fp8
::
fp8e5m2
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
::
fp8
::
fp8e5m2
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
_
fp8
::
fp8e4m3fnuz
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
::
fp8
::
fp8e4m3fnuz
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
_
fp8
::
fp8e5m2fnuz
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
::
fp8
::
fp8e5m2fnuz
)
}
// namespace std
}
// namespace std
// =================================================================================================
// =================================================================================================
...
...
src/include/migraphx/
migraphx_f
8_impl.hpp
→
src/include/migraphx/
float
8_impl.hpp
View file @
ab653aff
...
@@ -20,49 +20,32 @@
...
@@ -20,49 +20,32 @@
*
*
* ************************************************************************ */
* ************************************************************************ */
#ifndef MIGRAPHX_FP8_IMPL_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP
#define MIGRAPHX_FP8_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP
#include <type_traits>
#define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x))
#include <migraphx/config.hpp>
namespace
migraphx_f8_impl
{
#include <migraphx/bit_cast.hpp>
namespace
detail
{
namespace
migraphx
{
template
<
bool
B
,
class
T
,
class
F
>
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
conditional
namespace
fp8
{
{
namespace
impl
{
using
type
=
T
;
};
template
<
class
T
,
class
F
>
struct
conditional
<
false
,
T
,
F
>
{
using
type
=
F
;
};
template
<
typename
To
,
typename
From
>
inline
constexpr
To
bit_cast
(
From
fr
)
noexcept
{
static_assert
(
sizeof
(
To
)
==
sizeof
(
From
));
#if defined(__GNUC__) and !defined(__clang__)
return
MIGRAPHX_CONST_FOLD
(
*
reinterpret_cast
<
To
*>
(
&
fr
));
#else
return
__builtin_bit_cast
(
To
,
fr
);
#endif
}
}
// namespace detail
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
,
bool
clip
>
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
,
bool
clip
>
constexpr
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
,
uint32_t
rng
)
constexpr
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
=
false
,
uint32_t
rng
=
0
)
{
{
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
// half is not supported for now
constexpr
bool
is_half
=
false
;
static_assert
(
wm
+
we
==
7
,
"wm+we==7"
);
static_assert
(
wm
+
we
==
7
,
"wm+we==7"
);
static_assert
(
is_float
or
is_half
,
"Only float can be cast to f8"
);
const
int
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
const
int
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
typename
detail
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
x
;
typename
std
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
x
;
if
constexpr
(
sizeof
(
T
)
==
4
)
if
constexpr
(
sizeof
(
T
)
==
4
)
x
=
detail
::
bit_cast
<
uint32_t
>
(
_x
);
x
=
migraphx
::
bit_cast
<
uint32_t
>
(
_x
);
else
else
x
=
detail
::
bit_cast
<
uint16_t
>
(
_x
);
x
=
migraphx
::
bit_cast
<
uint16_t
>
(
_x
);
uint32_t
head
,
mantissa
;
uint32_t
head
,
mantissa
;
int
exponent
,
bias
;
int
exponent
,
bias
;
...
@@ -271,19 +254,27 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
...
@@ -271,19 +254,27 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
>
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
>
constexpr
T
cast_from_f8
(
uint8_t
x
)
constexpr
T
cast_from_f8
(
uint8_t
x
)
{
{
constexpr
int
weo
=
8
;
// half is not supported for now
constexpr
int
wmo
=
23
;
constexpr
bool
is_half
=
false
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
static_assert
(
is_float
or
is_half
,
"Only float are supported"
);
constexpr
int
weo
=
is_half
?
5
:
8
;
constexpr
int
wmo
=
is_half
?
10
:
(
is_float
?
23
:
7
);
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
uint32_t
ifInf
=
0x7F800000
;
uint32_t
ifNegInf
=
0xFF800000
;
uint32_t
ifNaN
=
0x7F800001
;
uint32_t
ifNeg0
=
0x80000000
;
fInf
=
detail
::
bit_cast
<
float
>
(
ifInf
);
if
constexpr
(
is_float
)
fNegInf
=
detail
::
bit_cast
<
float
>
(
ifNegInf
);
{
fNaN
=
detail
::
bit_cast
<
float
>
(
ifNaN
);
const
uint32_t
ifInf
=
0x7F800000
;
fNeg0
=
detail
::
bit_cast
<
float
>
(
ifNeg0
);
const
uint32_t
ifNegInf
=
0xFF800000
;
const
uint32_t
ifNaN
=
0x7F800001
;
const
uint32_t
ifNeg0
=
0x80000000
;
fInf
=
migraphx
::
bit_cast
<
float
>
(
ifInf
);
fNegInf
=
migraphx
::
bit_cast
<
float
>
(
ifNegInf
);
fNaN
=
migraphx
::
bit_cast
<
float
>
(
ifNaN
);
fNeg0
=
migraphx
::
bit_cast
<
float
>
(
ifNeg0
);
}
if
(
x
==
0
)
if
(
x
==
0
)
return
0
;
return
0
;
...
@@ -305,7 +296,7 @@ constexpr T cast_from_f8(uint8_t x)
...
@@ -305,7 +296,7 @@ constexpr T cast_from_f8(uint8_t x)
else
if
(
wm
==
3
and
(
x
==
0x7F
or
x
==
0xFF
))
else
if
(
wm
==
3
and
(
x
==
0x7F
or
x
==
0xFF
))
return
fNaN
;
return
fNaN
;
}
}
typename
detail
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
retval
;
typename
std
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
retval
;
const
int
exp_low_cutoff
=
(
1
<<
(
weo
-
1
))
-
(
1
<<
(
we
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
const
int
exp_low_cutoff
=
(
1
<<
(
weo
-
1
))
-
(
1
<<
(
we
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
...
@@ -333,8 +324,11 @@ constexpr T cast_from_f8(uint8_t x)
...
@@ -333,8 +324,11 @@ constexpr T cast_from_f8(uint8_t x)
retval
=
(
sign
<<
15
)
|
(
exponent
<<
10
)
|
mantissa
;
retval
=
(
sign
<<
15
)
|
(
exponent
<<
10
)
|
mantissa
;
else
else
retval
=
(
sign
<<
31
)
|
(
exponent
<<
23
)
|
mantissa
;
retval
=
(
sign
<<
31
)
|
(
exponent
<<
23
)
|
mantissa
;
return
detail
::
bit_cast
<
T
>
(
retval
);
return
migraphx
::
bit_cast
<
T
>
(
retval
);
}
}
}
// namespace migraphx_f8_impl
}
// namespace impl
#endif // MIGRAPHX_FP8_IMPL_HPP
}
// namespace fp8
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL
src/include/migraphx/half.hpp
View file @
ab653aff
...
@@ -27,7 +27,7 @@
...
@@ -27,7 +27,7 @@
#include <half/half.hpp>
#include <half/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/
migraphx_
float8.hpp>
#include <migraphx/float8.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -69,13 +69,13 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
...
@@ -69,13 +69,13 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
};
};
template
<
>
template
<
>
struct
common_type
<
migraphx
_
fp8
::
fp8e4m3fnuz
,
migraphx
::
half
>
struct
common_type
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
migraphx
::
half
>
{
{
using
type
=
float
;
using
type
=
float
;
};
};
template
<
>
template
<
>
struct
common_type
<
migraphx
::
half
,
migraphx
_
fp8
::
fp8e4m3fnuz
>
struct
common_type
<
migraphx
::
half
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
{
{
using
type
=
float
;
using
type
=
float
;
};
};
...
...
src/include/migraphx/shape.hpp
View file @
ab653aff
...
@@ -34,7 +34,7 @@
...
@@ -34,7 +34,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/
migraphx_
float8.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
...
@@ -62,7 +62,7 @@ struct MIGRAPHX_EXPORT shape
...
@@ -62,7 +62,7 @@ struct MIGRAPHX_EXPORT shape
m(int64_type, int64_t) \
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx
_
fp8::fp8e4m3fnuz)
m(fp8e4m3fnuz_type, migraphx
::
fp8::fp8e4m3fnuz)
// clang-format on
// clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
...
...
src/include/migraphx/type_traits.hpp
View file @
ab653aff
...
@@ -28,7 +28,7 @@
...
@@ -28,7 +28,7 @@
#include <type_traits>
#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/
migraphx_
float8.hpp>
#include <migraphx/float8.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -49,23 +49,13 @@ MIGRAPHX_DETAIL_DEFINE_TRAIT(is_floating_point);
...
@@ -49,23 +49,13 @@ MIGRAPHX_DETAIL_DEFINE_TRAIT(is_floating_point);
MIGRAPHX_DETAIL_DEFINE_TRAIT
(
is_arithmetic
);
MIGRAPHX_DETAIL_DEFINE_TRAIT
(
is_arithmetic
);
MIGRAPHX_DETAIL_DEFINE_TRAIT
(
is_signed
);
MIGRAPHX_DETAIL_DEFINE_TRAIT
(
is_signed
);
template
<
class
T
,
class
U
>
struct
is_same
:
std
::
is_same
<
T
,
U
>
{
};
template
<
bool
B
,
class
T
,
class
U
>
struct
conditional
:
std
::
conditional
<
B
,
T
,
U
>
{
};
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
migraphx
_
fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
migraphx
::
fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
migraphx
_
fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
migraphx
::
fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
migraphx
_
fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
migraphx
::
fp8
::
fp8e4m3fnuz
)
template
<
class
T
>
template
<
class
T
>
using
accumulator_type
=
using
accumulator_type
=
...
...
src/py/migraphx_py.cpp
View file @
ab653aff
...
@@ -40,7 +40,7 @@
...
@@ -40,7 +40,7 @@
#include <migraphx/json.hpp>
#include <migraphx/json.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/
migraphx_
float8.hpp>
#include <migraphx/float8.hpp>
#ifdef HAVE_GPU
#ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/hip.hpp>
#endif
#endif
...
@@ -145,7 +145,7 @@ struct npy_format_descriptor<half>
...
@@ -145,7 +145,7 @@ struct npy_format_descriptor<half>
};
};
template
<
>
template
<
>
struct
npy_format_descriptor
<
migraphx
_
fp8
::
fp8e4m3fnuz
>
struct
npy_format_descriptor
<
migraphx
::
fp8
::
fp8e4m3fnuz
>
{
{
static
std
::
string
format
()
static
std
::
string
format
()
{
{
...
...
test/CMakeLists.txt
View file @
ab653aff
...
@@ -150,7 +150,7 @@ function(test_headers PREFIX)
...
@@ -150,7 +150,7 @@ function(test_headers PREFIX)
list
(
REMOVE_ITEM HEADERS
list
(
REMOVE_ITEM HEADERS
${
CMAKE_SOURCE_DIR
}
/src/targets/gpu/include/migraphx/gpu/ck.hpp
)
${
CMAKE_SOURCE_DIR
}
/src/targets/gpu/include/migraphx/gpu/ck.hpp
)
endif
()
endif
()
list
(
REMOVE_ITEM HEADERS
${
CMAKE_SOURCE_DIR
}
/src/include/migraphx/
migraphx_f
8_impl.hpp
)
list
(
REMOVE_ITEM HEADERS
${
CMAKE_SOURCE_DIR
}
/src/include/migraphx/
float
8_impl.hpp
)
foreach
(
HEADER
${
HEADERS
}
)
foreach
(
HEADER
${
HEADERS
}
)
file
(
RELATIVE_PATH HEADER_REL
${
CMAKE_SOURCE_DIR
}
${
HEADER
}
)
file
(
RELATIVE_PATH HEADER_REL
${
CMAKE_SOURCE_DIR
}
${
HEADER
}
)
string
(
MAKE_C_IDENTIFIER
${
HEADER_REL
}
TEST_NAME
)
string
(
MAKE_C_IDENTIFIER
${
HEADER_REL
}
TEST_NAME
)
...
...
test/float_equal.cpp
View file @
ab653aff
...
@@ -22,7 +22,7 @@
...
@@ -22,7 +22,7 @@
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <migraphx/float_equal.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/
migraphx_
float8.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include "test.hpp"
#include "test.hpp"
...
@@ -72,12 +72,12 @@ void test_equality()
...
@@ -72,12 +72,12 @@ void test_equality()
TEST_CASE_REGISTER
(
test_equality
<
double
,
float
>
);
TEST_CASE_REGISTER
(
test_equality
<
double
,
float
>
);
TEST_CASE_REGISTER
(
test_equality
<
double
,
int
>
);
TEST_CASE_REGISTER
(
test_equality
<
double
,
int
>
);
TEST_CASE_REGISTER
(
test_equality
<
double
,
migraphx
::
half
>
);
TEST_CASE_REGISTER
(
test_equality
<
double
,
migraphx
::
half
>
);
TEST_CASE_REGISTER
(
test_equality
<
double
,
migraphx
_
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_equality
<
double
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_equality
<
float
,
int
>
);
TEST_CASE_REGISTER
(
test_equality
<
float
,
int
>
);
TEST_CASE_REGISTER
(
test_equality
<
float
,
migraphx
_
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_equality
<
float
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_equality
<
migraphx
::
half
,
int
>
);
TEST_CASE_REGISTER
(
test_equality
<
migraphx
::
half
,
int
>
);
TEST_CASE_REGISTER
(
test_equality
<
migraphx
::
half
,
migraphx
_
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_equality
<
migraphx
::
half
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_equality
<
migraphx
_
fp8
::
fp8e4m3fnuz
,
int
>
);
TEST_CASE_REGISTER
(
test_equality
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
int
>
);
template
<
class
T
,
class
U
>
template
<
class
T
,
class
U
>
void
test_limits
()
void
test_limits
()
...
@@ -115,12 +115,12 @@ void test_limits()
...
@@ -115,12 +115,12 @@ void test_limits()
TEST_CASE_REGISTER
(
test_limits
<
double
,
float
>
);
TEST_CASE_REGISTER
(
test_limits
<
double
,
float
>
);
TEST_CASE_REGISTER
(
test_limits
<
double
,
int
>
);
TEST_CASE_REGISTER
(
test_limits
<
double
,
int
>
);
TEST_CASE_REGISTER
(
test_limits
<
double
,
migraphx
::
half
>
);
TEST_CASE_REGISTER
(
test_limits
<
double
,
migraphx
::
half
>
);
TEST_CASE_REGISTER
(
test_limits
<
double
,
migraphx
_
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_limits
<
double
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_limits
<
float
,
int
>
);
TEST_CASE_REGISTER
(
test_limits
<
float
,
int
>
);
TEST_CASE_REGISTER
(
test_limits
<
float
,
migraphx
_
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_limits
<
float
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_limits
<
int
,
migraphx
::
half
>
);
TEST_CASE_REGISTER
(
test_limits
<
int
,
migraphx
::
half
>
);
TEST_CASE_REGISTER
(
test_limits
<
int
,
migraphx
_
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_limits
<
int
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_limits
<
migraphx
_
fp8
::
fp8e4m3fnuz
,
migraphx
::
half
>
);
TEST_CASE_REGISTER
(
test_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
migraphx
::
half
>
);
#ifndef _WIN32
#ifndef _WIN32
// On Windows, types int and long have the same min and max values.
// On Windows, types int and long have the same min and max values.
...
...
test/fp8e4m3fn.cpp
View file @
ab653aff
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
*/
*/
#include <cmath>
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/
migraphx_
float8.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include "test.hpp"
...
@@ -108,7 +108,7 @@ TEST_CASE(test_fp8_cast_to_float)
...
@@ -108,7 +108,7 @@ TEST_CASE(test_fp8_cast_to_float)
std
::
vector
<
uint8_t
>
bit_vals
(
256
);
std
::
vector
<
uint8_t
>
bit_vals
(
256
);
std
::
iota
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
0
);
std
::
iota
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
0
);
EXPECT
(
bool
{
std
::
all_of
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
[](
uint8_t
bit_val
)
{
EXPECT
(
bool
{
std
::
all_of
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
[](
uint8_t
bit_val
)
{
migraphx
_
fp8
::
fp8e4m3fn
fp8_val
(
bit_val
,
migraphx
_
fp8
::
fp8e4m3fn
::
from_bits
());
migraphx
::
fp8
::
fp8e4m3fn
fp8_val
(
bit_val
,
migraphx
::
fp8
::
fp8e4m3fn
::
from_bits
());
if
(
std
::
isnan
(
float
(
fp8_val
))
and
std
::
isnan
(
fp8e4m3fn_to_fp32_value
(
bit_val
)))
if
(
std
::
isnan
(
float
(
fp8_val
))
and
std
::
isnan
(
fp8e4m3fn_to_fp32_value
(
bit_val
)))
{
{
return
true
;
return
true
;
...
@@ -120,7 +120,7 @@ TEST_CASE(test_fp8_cast_to_float)
...
@@ -120,7 +120,7 @@ TEST_CASE(test_fp8_cast_to_float)
TEST_CASE
(
test_positive_zero
)
TEST_CASE
(
test_positive_zero
)
{
{
float
zero
=
0.0
;
float
zero
=
0.0
;
migraphx
_
fp8
::
fp8e4m3fn
fp8_zero
(
zero
);
migraphx
::
fp8
::
fp8e4m3fn
fp8_zero
(
zero
);
EXPECT
(
fp8_zero
.
is_zero
());
EXPECT
(
fp8_zero
.
is_zero
());
EXPECT
(
migraphx
::
float_equal
(
zero
,
float
(
fp8_zero
)));
EXPECT
(
migraphx
::
float_equal
(
zero
,
float
(
fp8_zero
)));
}
}
...
@@ -128,7 +128,7 @@ TEST_CASE(test_positive_zero)
...
@@ -128,7 +128,7 @@ TEST_CASE(test_positive_zero)
TEST_CASE
(
test_negative_zero
)
TEST_CASE
(
test_negative_zero
)
{
{
float
nzero
=
-
0.0
;
float
nzero
=
-
0.0
;
migraphx
_
fp8
::
fp8e4m3fn
fp8_nzero
(
nzero
);
migraphx
::
fp8
::
fp8e4m3fn
fp8_nzero
(
nzero
);
EXPECT
(
fp8_nzero
.
is_zero
());
EXPECT
(
fp8_nzero
.
is_zero
());
// negative zero is preserved for fp8e4m3fn
// negative zero is preserved for fp8e4m3fn
EXPECT
(
migraphx
::
float_equal
(
nzero
,
float
(
fp8_nzero
)));
EXPECT
(
migraphx
::
float_equal
(
nzero
,
float
(
fp8_nzero
)));
...
@@ -137,15 +137,15 @@ TEST_CASE(test_negative_zero)
...
@@ -137,15 +137,15 @@ TEST_CASE(test_negative_zero)
TEST_CASE
(
test_nan_1
)
TEST_CASE
(
test_nan_1
)
{
{
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
migraphx
_
fp8
::
fp8e4m3fn
fp8_nan
(
fnan
);
migraphx
::
fp8
::
fp8e4m3fn
fp8_nan
(
fnan
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
fp8_nan
));
}
}
TEST_CASE
(
test_nan_2
)
TEST_CASE
(
test_nan_2
)
{
{
auto
fnan
=
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
quiet_NaN
();
auto
fnan
=
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
quiet_NaN
();
migraphx
_
fp8
::
fp8e4m3fn
fp8_nan
(
fnan
.
data
,
migraphx
_
fp8
::
fp8e4m3fn
::
from_bits
());
migraphx
::
fp8
::
fp8e4m3fn
fp8_nan
(
fnan
.
data
,
migraphx
::
fp8
::
fp8e4m3fn
::
from_bits
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
...
@@ -155,8 +155,8 @@ TEST_CASE(test_infinity_1)
...
@@ -155,8 +155,8 @@ TEST_CASE(test_infinity_1)
{
{
float
finf
=
std
::
numeric_limits
<
float
>::
infinity
();
float
finf
=
std
::
numeric_limits
<
float
>::
infinity
();
// no inf in fp8e4m3fn, it gets clipped to max()
// no inf in fp8e4m3fn, it gets clipped to max()
migraphx
_
fp8
::
fp8e4m3fn
fp8_max
(
finf
);
migraphx
::
fp8
::
fp8e4m3fn
fp8_max
(
finf
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
max
());
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
max
());
}
}
TEST_CASE
(
test_infinity_2
)
TEST_CASE
(
test_infinity_2
)
...
@@ -164,43 +164,43 @@ TEST_CASE(test_infinity_2)
...
@@ -164,43 +164,43 @@ TEST_CASE(test_infinity_2)
// neg inf
// neg inf
float
finf
=
-
1.0
*
std
::
numeric_limits
<
float
>::
infinity
();
float
finf
=
-
1.0
*
std
::
numeric_limits
<
float
>::
infinity
();
// no inf in fp8e4m3fn, it gets clipped to lowest
// no inf in fp8e4m3fn, it gets clipped to lowest
migraphx
_
fp8
::
fp8e4m3fn
fp8_lowest
(
finf
);
migraphx
::
fp8
::
fp8e4m3fn
fp8_lowest
(
finf
);
EXPECT
(
bool
{
fp8_lowest
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
lowest
()});
EXPECT
(
bool
{
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
lowest
()});
}
}
TEST_CASE
(
test_numeric_max_1
)
TEST_CASE
(
test_numeric_max_1
)
{
{
float
fmax
=
std
::
numeric_limits
<
float
>::
max
();
float
fmax
=
std
::
numeric_limits
<
float
>::
max
();
migraphx
_
fp8
::
fp8e4m3fn
fp8_max
(
fmax
);
migraphx
::
fp8
::
fp8e4m3fn
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
max
());
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
max
());
}
}
TEST_CASE
(
test_numeric_max_2
)
TEST_CASE
(
test_numeric_max_2
)
{
{
// gets clipped to max
// gets clipped to max
float
fmax
=
2
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
max
();
float
fmax
=
2
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
max
();
migraphx
_
fp8
::
fp8e4m3fn
fp8_max
(
fmax
);
migraphx
::
fp8
::
fp8e4m3fn
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
max
());
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
max
());
}
}
TEST_CASE
(
test_numeric_lowest_1
)
TEST_CASE
(
test_numeric_lowest_1
)
{
{
float
flowest
=
std
::
numeric_limits
<
float
>::
lowest
();
float
flowest
=
std
::
numeric_limits
<
float
>::
lowest
();
migraphx
_
fp8
::
fp8e4m3fn
fp8_lowest
(
flowest
);
migraphx
::
fp8
::
fp8e4m3fn
fp8_lowest
(
flowest
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
lowest
());
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
lowest
());
}
}
TEST_CASE
(
test_numeric_lowest_2
)
TEST_CASE
(
test_numeric_lowest_2
)
{
{
// gets clipped to lowest
// gets clipped to lowest
float
fmin
=
2.0
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
lowest
();
float
fmin
=
2.0
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
lowest
();
migraphx
_
fp8
::
fp8e4m3fn
fp8_lowest
(
fmin
);
migraphx
::
fp8
::
fp8e4m3fn
fp8_lowest
(
fmin
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
lowest
());
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
lowest
());
}
}
TEST_CASE
(
test_max_eq_lowest
)
TEST_CASE
(
test_max_eq_lowest
)
{
{
EXPECT
(
migraphx
::
float_equal
(
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
lowest
(),
EXPECT
(
migraphx
::
float_equal
(
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
lowest
(),
-
1
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
max
()));
-
1
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
max
()));
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/fp8e4m3fnuz.cpp
View file @
ab653aff
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
*/
*/
#include <cmath>
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/
migraphx_
float8.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include "test.hpp"
...
@@ -129,7 +129,7 @@ TEST_CASE(test_fp8_cast_to_float)
...
@@ -129,7 +129,7 @@ TEST_CASE(test_fp8_cast_to_float)
std
::
vector
<
uint8_t
>
bit_vals
(
256
);
std
::
vector
<
uint8_t
>
bit_vals
(
256
);
std
::
iota
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
0
);
std
::
iota
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
0
);
EXPECT
(
bool
{
std
::
all_of
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
[](
uint8_t
bit_val
)
{
EXPECT
(
bool
{
std
::
all_of
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
[](
uint8_t
bit_val
)
{
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_val
(
bit_val
,
migraphx
_
fp8
::
fp8e4m3fnuz
::
from_bits
());
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_val
(
bit_val
,
migraphx
::
fp8
::
fp8e4m3fnuz
::
from_bits
());
if
(
std
::
isnan
(
float
(
fp8_val
))
and
std
::
isnan
(
fp8e4m3fnuz_to_fp32_value
(
bit_val
)))
if
(
std
::
isnan
(
float
(
fp8_val
))
and
std
::
isnan
(
fp8e4m3fnuz_to_fp32_value
(
bit_val
)))
{
{
return
true
;
return
true
;
...
@@ -141,7 +141,7 @@ TEST_CASE(test_fp8_cast_to_float)
...
@@ -141,7 +141,7 @@ TEST_CASE(test_fp8_cast_to_float)
TEST_CASE
(
test_positive_zero
)
TEST_CASE
(
test_positive_zero
)
{
{
float
zero
=
0.0
;
float
zero
=
0.0
;
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_zero
(
zero
);
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_zero
(
zero
);
EXPECT
(
fp8_zero
.
is_zero
());
EXPECT
(
fp8_zero
.
is_zero
());
EXPECT
(
migraphx
::
float_equal
(
zero
,
float
(
fp8_zero
)));
EXPECT
(
migraphx
::
float_equal
(
zero
,
float
(
fp8_zero
)));
}
}
...
@@ -150,7 +150,7 @@ TEST_CASE(test_negative_zero)
...
@@ -150,7 +150,7 @@ TEST_CASE(test_negative_zero)
{
{
float
nzero
=
-
0.0
;
float
nzero
=
-
0.0
;
float
pzero
=
0.0
;
float
pzero
=
0.0
;
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_nzero
(
nzero
);
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_nzero
(
nzero
);
EXPECT
(
fp8_nzero
.
is_zero
());
EXPECT
(
fp8_nzero
.
is_zero
());
// negative zero gets converted to positive zero
// negative zero gets converted to positive zero
EXPECT
(
migraphx
::
float_equal
(
pzero
,
float
(
fp8_nzero
)));
EXPECT
(
migraphx
::
float_equal
(
pzero
,
float
(
fp8_nzero
)));
...
@@ -159,15 +159,15 @@ TEST_CASE(test_negative_zero)
...
@@ -159,15 +159,15 @@ TEST_CASE(test_negative_zero)
TEST_CASE
(
test_nan_1
)
TEST_CASE
(
test_nan_1
)
{
{
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_nan
(
fnan
);
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_nan
(
fnan
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
fp8_nan
));
}
}
TEST_CASE
(
test_nan_2
)
TEST_CASE
(
test_nan_2
)
{
{
auto
fnan
=
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fnuz
>::
quiet_NaN
();
auto
fnan
=
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
quiet_NaN
();
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_nan
(
fnan
.
data
,
migraphx
_
fp8
::
fp8e4m3fnuz
::
from_bits
());
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_nan
(
fnan
.
data
,
migraphx
::
fp8
::
fp8e4m3fnuz
::
from_bits
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
...
@@ -177,7 +177,7 @@ TEST_CASE(test_infinity_1)
...
@@ -177,7 +177,7 @@ TEST_CASE(test_infinity_1)
{
{
float
finf
=
std
::
numeric_limits
<
float
>::
infinity
();
float
finf
=
std
::
numeric_limits
<
float
>::
infinity
();
// no inf in fp8e4m3fnuz it gets clipped to Nans
// no inf in fp8e4m3fnuz it gets clipped to Nans
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_nan
(
finf
);
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_nan
(
finf
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
}
}
...
@@ -187,7 +187,7 @@ TEST_CASE(test_infinity_2)
...
@@ -187,7 +187,7 @@ TEST_CASE(test_infinity_2)
// neg inf
// neg inf
float
finf
=
-
1.0
*
std
::
numeric_limits
<
float
>::
infinity
();
float
finf
=
-
1.0
*
std
::
numeric_limits
<
float
>::
infinity
();
// no inf in fp8e4m3fnuz it gets clipped to NaNs
// no inf in fp8e4m3fnuz it gets clipped to NaNs
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_nan
(
finf
);
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_nan
(
finf
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
}
}
...
@@ -195,36 +195,36 @@ TEST_CASE(test_infinity_2)
...
@@ -195,36 +195,36 @@ TEST_CASE(test_infinity_2)
TEST_CASE
(
test_numeric_max_1
)
TEST_CASE
(
test_numeric_max_1
)
{
{
float
fmax
=
std
::
numeric_limits
<
float
>::
max
();
float
fmax
=
std
::
numeric_limits
<
float
>::
max
();
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_max
(
fmax
);
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fnuz
>::
max
());
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
max
());
}
}
TEST_CASE
(
test_numeric_max_2
)
TEST_CASE
(
test_numeric_max_2
)
{
{
// gets clipped to max
// gets clipped to max
float
fmax
=
2
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fnuz
>::
max
();
float
fmax
=
2
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
max
();
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_max
(
fmax
);
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fnuz
>::
max
());
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
max
());
}
}
TEST_CASE
(
test_numeric_lowest_1
)
TEST_CASE
(
test_numeric_lowest_1
)
{
{
float
flowest
=
std
::
numeric_limits
<
float
>::
lowest
();
float
flowest
=
std
::
numeric_limits
<
float
>::
lowest
();
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_lowest
(
flowest
);
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_lowest
(
flowest
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fnuz
>::
lowest
());
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
lowest
());
}
}
TEST_CASE
(
test_numeric_lowest_2
)
TEST_CASE
(
test_numeric_lowest_2
)
{
{
// gets clipped to lowest
// gets clipped to lowest
float
fmin
=
2.0
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fnuz
>::
lowest
();
float
fmin
=
2.0
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
lowest
();
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_lowest
(
fmin
);
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_lowest
(
fmin
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fnuz
>::
lowest
());
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
lowest
());
}
}
TEST_CASE
(
test_max_eq_lowest
)
TEST_CASE
(
test_max_eq_lowest
)
{
{
EXPECT
(
migraphx
::
float_equal
(
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fnuz
>::
lowest
(),
EXPECT
(
migraphx
::
float_equal
(
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
lowest
(),
-
1
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fnuz
>::
max
()));
-
1
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
max
()));
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/fp8e5m2.cpp
View file @
ab653aff
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
*/
*/
#include <cmath>
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/
migraphx_
float8.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include "test.hpp"
...
@@ -301,7 +301,7 @@ TEST_CASE(test_fp8_cast_to_float)
...
@@ -301,7 +301,7 @@ TEST_CASE(test_fp8_cast_to_float)
std
::
vector
<
uint8_t
>
bit_vals
(
256
);
std
::
vector
<
uint8_t
>
bit_vals
(
256
);
std
::
iota
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
0
);
std
::
iota
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
0
);
EXPECT
(
bool
{
std
::
all_of
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
[](
uint8_t
bit_val
)
{
EXPECT
(
bool
{
std
::
all_of
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
[](
uint8_t
bit_val
)
{
migraphx
_
fp8
::
fp8e5m2
fp8_val
(
bit_val
,
migraphx
_
fp8
::
fp8e5m2
::
from_bits
());
migraphx
::
fp8
::
fp8e5m2
fp8_val
(
bit_val
,
migraphx
::
fp8
::
fp8e5m2
::
from_bits
());
if
(
std
::
isnan
(
float
(
fp8_val
))
and
std
::
isnan
(
fp8e5m2_to_fp32_value
(
bit_val
)))
if
(
std
::
isnan
(
float
(
fp8_val
))
and
std
::
isnan
(
fp8e5m2_to_fp32_value
(
bit_val
)))
{
{
return
true
;
return
true
;
...
@@ -317,7 +317,7 @@ TEST_CASE(test_fp8_cast_to_float)
...
@@ -317,7 +317,7 @@ TEST_CASE(test_fp8_cast_to_float)
TEST_CASE
(
test_positive_zero
)
TEST_CASE
(
test_positive_zero
)
{
{
float
zero
=
0.0
;
float
zero
=
0.0
;
migraphx
_
fp8
::
fp8e5m2
fp8_zero
(
zero
);
migraphx
::
fp8
::
fp8e5m2
fp8_zero
(
zero
);
EXPECT
(
fp8_zero
.
is_zero
());
EXPECT
(
fp8_zero
.
is_zero
());
EXPECT
(
migraphx
::
float_equal
(
zero
,
float
(
fp8_zero
)));
EXPECT
(
migraphx
::
float_equal
(
zero
,
float
(
fp8_zero
)));
}
}
...
@@ -325,7 +325,7 @@ TEST_CASE(test_positive_zero)
...
@@ -325,7 +325,7 @@ TEST_CASE(test_positive_zero)
TEST_CASE
(
test_negative_zero
)
TEST_CASE
(
test_negative_zero
)
{
{
float
nzero
=
-
0.0
;
float
nzero
=
-
0.0
;
migraphx
_
fp8
::
fp8e5m2
fp8_nzero
(
nzero
);
migraphx
::
fp8
::
fp8e5m2
fp8_nzero
(
nzero
);
EXPECT
(
fp8_nzero
.
is_zero
());
EXPECT
(
fp8_nzero
.
is_zero
());
// negative zero is preserved for fp8e5m2
// negative zero is preserved for fp8e5m2
EXPECT
(
migraphx
::
float_equal
(
nzero
,
float
(
fp8_nzero
)));
EXPECT
(
migraphx
::
float_equal
(
nzero
,
float
(
fp8_nzero
)));
...
@@ -334,15 +334,15 @@ TEST_CASE(test_negative_zero)
...
@@ -334,15 +334,15 @@ TEST_CASE(test_negative_zero)
TEST_CASE
(
test_nan_1
)
TEST_CASE
(
test_nan_1
)
{
{
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
migraphx
_
fp8
::
fp8e5m2
fp8_nan
(
fnan
);
migraphx
::
fp8
::
fp8e5m2
fp8_nan
(
fnan
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
fp8_nan
));
}
}
TEST_CASE
(
test_nan_2
)
TEST_CASE
(
test_nan_2
)
{
{
auto
fnan
=
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
quiet_NaN
();
auto
fnan
=
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
quiet_NaN
();
migraphx
_
fp8
::
fp8e5m2
fp8_nan
(
fnan
.
data
,
migraphx
_
fp8
::
fp8e5m2
::
from_bits
());
migraphx
::
fp8
::
fp8e5m2
fp8_nan
(
fnan
.
data
,
migraphx
::
fp8
::
fp8e5m2
::
from_bits
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
...
@@ -352,8 +352,8 @@ TEST_CASE(test_infinity_1)
...
@@ -352,8 +352,8 @@ TEST_CASE(test_infinity_1)
{
{
// float infinity should get clipped to max
// float infinity should get clipped to max
float
finf
=
std
::
numeric_limits
<
float
>::
infinity
();
float
finf
=
std
::
numeric_limits
<
float
>::
infinity
();
migraphx
_
fp8
::
fp8e5m2
fp8_max
(
finf
);
migraphx
::
fp8
::
fp8e5m2
fp8_max
(
finf
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
max
());
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
max
());
}
}
TEST_CASE
(
test_infinity_2
)
TEST_CASE
(
test_infinity_2
)
...
@@ -361,43 +361,43 @@ TEST_CASE(test_infinity_2)
...
@@ -361,43 +361,43 @@ TEST_CASE(test_infinity_2)
// neg inf
// neg inf
float
finf
=
-
1.0
*
std
::
numeric_limits
<
float
>::
infinity
();
float
finf
=
-
1.0
*
std
::
numeric_limits
<
float
>::
infinity
();
// no inf in fp8e5m2, it gets clipped to lowest
// no inf in fp8e5m2, it gets clipped to lowest
migraphx
_
fp8
::
fp8e5m2
fp8_lowest
(
finf
);
migraphx
::
fp8
::
fp8e5m2
fp8_lowest
(
finf
);
EXPECT
(
bool
{
fp8_lowest
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
lowest
()});
EXPECT
(
bool
{
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
lowest
()});
}
}
TEST_CASE
(
test_numeric_max_1
)
TEST_CASE
(
test_numeric_max_1
)
{
{
float
fmax
=
std
::
numeric_limits
<
float
>::
max
();
float
fmax
=
std
::
numeric_limits
<
float
>::
max
();
migraphx
_
fp8
::
fp8e5m2
fp8_max
(
fmax
);
migraphx
::
fp8
::
fp8e5m2
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
max
());
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
max
());
}
}
TEST_CASE
(
test_numeric_max_2
)
TEST_CASE
(
test_numeric_max_2
)
{
{
// gets clipped to max
// gets clipped to max
float
fmax
=
2
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
max
();
float
fmax
=
2
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
max
();
migraphx
_
fp8
::
fp8e5m2
fp8_max
(
fmax
);
migraphx
::
fp8
::
fp8e5m2
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
max
());
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
max
());
}
}
TEST_CASE
(
test_numeric_lowest_1
)
TEST_CASE
(
test_numeric_lowest_1
)
{
{
float
flowest
=
std
::
numeric_limits
<
float
>::
lowest
();
float
flowest
=
std
::
numeric_limits
<
float
>::
lowest
();
migraphx
_
fp8
::
fp8e5m2
fp8_lowest
(
flowest
);
migraphx
::
fp8
::
fp8e5m2
fp8_lowest
(
flowest
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
lowest
());
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
lowest
());
}
}
TEST_CASE
(
test_numeric_lowest_2
)
TEST_CASE
(
test_numeric_lowest_2
)
{
{
// gets clipped to lowest
// gets clipped to lowest
float
fmin
=
2.0
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
lowest
();
float
fmin
=
2.0
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
lowest
();
migraphx
_
fp8
::
fp8e5m2
fp8_lowest
(
fmin
);
migraphx
::
fp8
::
fp8e5m2
fp8_lowest
(
fmin
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
lowest
());
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
lowest
());
}
}
TEST_CASE
(
test_max_eq_lowest
)
TEST_CASE
(
test_max_eq_lowest
)
{
{
EXPECT
(
migraphx
::
float_equal
(
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
lowest
(),
EXPECT
(
migraphx
::
float_equal
(
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
lowest
(),
-
1
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
max
()));
-
1
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
max
()));
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/fp8e5m2fnuz.cpp
View file @
ab653aff
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
*/
*/
#include <cmath>
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/
migraphx_
float8.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include "test.hpp"
...
@@ -299,7 +299,7 @@ TEST_CASE(test_fp8_cast_to_float)
...
@@ -299,7 +299,7 @@ TEST_CASE(test_fp8_cast_to_float)
std
::
vector
<
uint8_t
>
bit_vals
(
256
);
std
::
vector
<
uint8_t
>
bit_vals
(
256
);
std
::
iota
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
0
);
std
::
iota
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
0
);
EXPECT
(
bool
{
std
::
all_of
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
[](
uint8_t
bit_val
)
{
EXPECT
(
bool
{
std
::
all_of
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
[](
uint8_t
bit_val
)
{
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_val
(
bit_val
,
migraphx
_
fp8
::
fp8e5m2fnuz
::
from_bits
());
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_val
(
bit_val
,
migraphx
::
fp8
::
fp8e5m2fnuz
::
from_bits
());
if
(
std
::
isnan
(
float
(
fp8_val
))
and
std
::
isnan
(
fp8e5m2fnuz_to_fp32_value
(
bit_val
)))
if
(
std
::
isnan
(
float
(
fp8_val
))
and
std
::
isnan
(
fp8e5m2fnuz_to_fp32_value
(
bit_val
)))
{
{
return
true
;
return
true
;
...
@@ -311,7 +311,7 @@ TEST_CASE(test_fp8_cast_to_float)
...
@@ -311,7 +311,7 @@ TEST_CASE(test_fp8_cast_to_float)
TEST_CASE
(
test_positive_zero
)
TEST_CASE
(
test_positive_zero
)
{
{
float
zero
=
0.0
;
float
zero
=
0.0
;
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_zero
(
zero
);
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_zero
(
zero
);
EXPECT
(
fp8_zero
.
is_zero
());
EXPECT
(
fp8_zero
.
is_zero
());
EXPECT
(
migraphx
::
float_equal
(
zero
,
float
(
fp8_zero
)));
EXPECT
(
migraphx
::
float_equal
(
zero
,
float
(
fp8_zero
)));
}
}
...
@@ -320,7 +320,7 @@ TEST_CASE(test_negative_zero)
...
@@ -320,7 +320,7 @@ TEST_CASE(test_negative_zero)
{
{
float
nzero
=
-
0.0
;
float
nzero
=
-
0.0
;
float
pzero
=
0.0
;
float
pzero
=
0.0
;
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_nzero
(
nzero
);
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_nzero
(
nzero
);
EXPECT
(
fp8_nzero
.
is_zero
());
EXPECT
(
fp8_nzero
.
is_zero
());
// negative zero gets converted to positive zero
// negative zero gets converted to positive zero
EXPECT
(
migraphx
::
float_equal
(
pzero
,
float
(
fp8_nzero
)));
EXPECT
(
migraphx
::
float_equal
(
pzero
,
float
(
fp8_nzero
)));
...
@@ -329,15 +329,15 @@ TEST_CASE(test_negative_zero)
...
@@ -329,15 +329,15 @@ TEST_CASE(test_negative_zero)
TEST_CASE
(
test_nan_1
)
TEST_CASE
(
test_nan_1
)
{
{
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_nan
(
fnan
);
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_nan
(
fnan
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
fp8_nan
));
}
}
TEST_CASE
(
test_nan_2
)
TEST_CASE
(
test_nan_2
)
{
{
auto
fnan
=
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2fnuz
>::
quiet_NaN
();
auto
fnan
=
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2fnuz
>::
quiet_NaN
();
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_nan
(
fnan
.
data
,
migraphx
_
fp8
::
fp8e5m2fnuz
::
from_bits
());
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_nan
(
fnan
.
data
,
migraphx
::
fp8
::
fp8e5m2fnuz
::
from_bits
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
...
@@ -347,7 +347,7 @@ TEST_CASE(test_infinity_1)
...
@@ -347,7 +347,7 @@ TEST_CASE(test_infinity_1)
{
{
float
finf
=
std
::
numeric_limits
<
float
>::
infinity
();
float
finf
=
std
::
numeric_limits
<
float
>::
infinity
();
// no inf in fp8e5m2fnuz it gets clipped to Nans
// no inf in fp8e5m2fnuz it gets clipped to Nans
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_nan
(
finf
);
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_nan
(
finf
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
}
}
...
@@ -357,7 +357,7 @@ TEST_CASE(test_infinity_2)
...
@@ -357,7 +357,7 @@ TEST_CASE(test_infinity_2)
// neg inf
// neg inf
float
finf
=
-
1.0
*
std
::
numeric_limits
<
float
>::
infinity
();
float
finf
=
-
1.0
*
std
::
numeric_limits
<
float
>::
infinity
();
// no inf in fp8e5m2fnuz it gets clipped to NaNs
// no inf in fp8e5m2fnuz it gets clipped to NaNs
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_nan
(
finf
);
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_nan
(
finf
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
}
}
...
@@ -365,36 +365,36 @@ TEST_CASE(test_infinity_2)
...
@@ -365,36 +365,36 @@ TEST_CASE(test_infinity_2)
TEST_CASE
(
test_numeric_max_1
)
TEST_CASE
(
test_numeric_max_1
)
{
{
float
fmax
=
std
::
numeric_limits
<
float
>::
max
();
float
fmax
=
std
::
numeric_limits
<
float
>::
max
();
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_max
(
fmax
);
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2fnuz
>::
max
());
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2fnuz
>::
max
());
}
}
TEST_CASE
(
test_numeric_max_2
)
TEST_CASE
(
test_numeric_max_2
)
{
{
// gets clipped to max
// gets clipped to max
float
fmax
=
2
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2fnuz
>::
max
();
float
fmax
=
2
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2fnuz
>::
max
();
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_max
(
fmax
);
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2fnuz
>::
max
());
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2fnuz
>::
max
());
}
}
TEST_CASE
(
test_numeric_lowest_1
)
TEST_CASE
(
test_numeric_lowest_1
)
{
{
float
flowest
=
std
::
numeric_limits
<
float
>::
lowest
();
float
flowest
=
std
::
numeric_limits
<
float
>::
lowest
();
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_lowest
(
flowest
);
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_lowest
(
flowest
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2fnuz
>::
lowest
());
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2fnuz
>::
lowest
());
}
}
TEST_CASE
(
test_numeric_lowest_2
)
TEST_CASE
(
test_numeric_lowest_2
)
{
{
// gets clipped to lowest
// gets clipped to lowest
float
fmin
=
2.0
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2fnuz
>::
lowest
();
float
fmin
=
2.0
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2fnuz
>::
lowest
();
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_lowest
(
fmin
);
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_lowest
(
fmin
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2fnuz
>::
lowest
());
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2fnuz
>::
lowest
());
}
}
TEST_CASE
(
test_max_eq_lowest
)
TEST_CASE
(
test_max_eq_lowest
)
{
{
EXPECT
(
migraphx
::
float_equal
(
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2fnuz
>::
lowest
(),
EXPECT
(
migraphx
::
float_equal
(
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2fnuz
>::
lowest
(),
-
1
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2fnuz
>::
max
()));
-
1
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2fnuz
>::
max
()));
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
tools/api/migraphx.h
View file @
ab653aff
...
@@ -45,7 +45,7 @@
...
@@ -45,7 +45,7 @@
m(int64_type, int64_t) \
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx
_
fp8::fp8e4m3fnuz)
m(fp8e4m3fnuz_type, migraphx
::
fp8::fp8e4m3fnuz)
// clang-format on
// clang-format on
#ifdef __cplusplus
#ifdef __cplusplus
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment