Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ab653aff
"docs/vscode:/vscode.git/clone" did not exist on "2b981012a6eb27d566f03cf61c06b1ef7a522f27"
Commit
ab653aff
authored
Nov 13, 2023
by
Umang Yadav
Browse files
Review updates
parent
183db78a
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
261 additions
and
241 deletions
+261
-241
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+1
-1
src/include/migraphx/bit_cast.hpp
src/include/migraphx/bit_cast.hpp
+42
-0
src/include/migraphx/float8.hpp
src/include/migraphx/float8.hpp
+66
-72
src/include/migraphx/float8_impl.hpp
src/include/migraphx/float8_impl.hpp
+42
-48
src/include/migraphx/half.hpp
src/include/migraphx/half.hpp
+3
-3
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+2
-2
src/include/migraphx/type_traits.hpp
src/include/migraphx/type_traits.hpp
+4
-14
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+2
-2
test/CMakeLists.txt
test/CMakeLists.txt
+1
-1
test/float_equal.cpp
test/float_equal.cpp
+9
-9
test/fp8e4m3fn.cpp
test/fp8e4m3fn.cpp
+23
-23
test/fp8e4m3fnuz.cpp
test/fp8e4m3fnuz.cpp
+21
-21
test/fp8e5m2.cpp
test/fp8e5m2.cpp
+23
-23
test/fp8e5m2fnuz.cpp
test/fp8e5m2fnuz.cpp
+21
-21
tools/api/migraphx.h
tools/api/migraphx.h
+1
-1
No files found.
src/api/include/migraphx/migraphx.h
View file @
ab653aff
...
@@ -45,7 +45,7 @@
...
@@ -45,7 +45,7 @@
m(int64_type, int64_t) \
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx
_
fp8::fp8e4m3fnuz)
m(fp8e4m3fnuz_type, migraphx
::
fp8::fp8e4m3fnuz)
// clang-format on
// clang-format on
#ifdef __cplusplus
#ifdef __cplusplus
...
...
src/include/migraphx/bit_cast.hpp
0 → 100644
View file @
ab653aff
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#define MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#include <migraphx/config.hpp>
#define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x))
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
typename
To
,
typename
From
>
inline
constexpr
To
bit_cast
(
From
fr
)
noexcept
{
static_assert
(
sizeof
(
To
)
==
sizeof
(
From
));
#if defined(__GNUC__) and !defined(__clang__)
return
MIGRAPHX_CONST_FOLD
(
*
reinterpret_cast
<
To
*>
(
&
fr
));
#else
return
__builtin_bit_cast
(
To
,
fr
);
#endif
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
src/include/migraphx/
migraphx_
float8.hpp
→
src/include/migraphx/float8.hpp
View file @
ab653aff
...
@@ -44,20 +44,12 @@
...
@@ -44,20 +44,12 @@
#include <iostream>
#include <iostream>
#include <string>
#include <string>
#include <utility>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/float8_impl.hpp>
namespace
migraphx_f8_impl
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
,
bool
clip
>
namespace
fp8
{
constexpr
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
=
false
,
uint32_t
rng
=
0
);
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
>
constexpr
T
cast_from_f8
(
uint8_t
x
);
}
// namespace migraphx_f8_impl
#include <migraphx/migraphx_f8_impl.hpp>
namespace
migraphx_fp8
{
enum
class
migraphx_f8_rounding_mode
enum
class
migraphx_f8_rounding_mode
{
{
...
@@ -74,7 +66,7 @@ enum class f8_type
...
@@ -74,7 +66,7 @@ enum class f8_type
template
<
typename
T
,
bool
FNUZ
=
true
>
template
<
typename
T
,
bool
FNUZ
=
true
>
class
numeric_limits
;
class
numeric_limits
;
template
<
migraphx
_
fp8
::
f8_type
T
=
migraphx
_
fp8
::
f8_type
::
fp8
,
bool
FNUZ
=
true
>
template
<
migraphx
::
fp8
::
f8_type
T
=
migraphx
::
fp8
::
f8_type
::
fp8
,
bool
FNUZ
=
true
>
struct
float8
struct
float8
{
{
uint8_t
data
=
0x00
;
uint8_t
data
=
0x00
;
...
@@ -90,43 +82,43 @@ struct float8
...
@@ -90,43 +82,43 @@ struct float8
explicit
constexpr
float8
(
uint8_t
bits
,
from_bits_t
)
:
data
(
bits
)
{}
explicit
constexpr
float8
(
uint8_t
bits
,
from_bits_t
)
:
data
(
bits
)
{}
explicit
constexpr
float8
(
float
v
,
explicit
constexpr
float8
(
float
v
,
migraphx
_
fp8
::
migraphx_f8_rounding_mode
rm
=
migraphx
::
fp8
::
migraphx_f8_rounding_mode
rm
=
migraphx
_
fp8
::
migraphx_f8_rounding_mode
::
standard
,
migraphx
::
fp8
::
migraphx_f8_rounding_mode
::
standard
,
uint32_t
rng
=
0
)
uint32_t
rng
=
0
)
{
{
if
constexpr
(
T
==
migraphx
_
fp8
::
f8_type
::
fp8
)
if
constexpr
(
T
==
migraphx
::
fp8
::
f8_type
::
fp8
)
{
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
_f8_
impl
::
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
cast_to_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
,
(
rm
==
migraphx
_
fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
v
,
(
rm
==
migraphx
::
fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
_f8_
impl
::
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
cast_to_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
v
,
(
rm
==
migraphx
_
fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
v
,
(
rm
==
migraphx
::
fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
}
}
else
else
{
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
_f8_
impl
::
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
cast_to_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
,
(
rm
==
migraphx
_
fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
v
,
(
rm
==
migraphx
::
fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
_f8_
impl
::
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
cast_to_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
v
,
(
rm
==
migraphx
_
fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
v
,
(
rm
==
migraphx
::
fp8
::
migraphx_f8_rounding_mode
::
stochastic
),
rng
);
#endif // rocblas_F8_downcast_clipping}
#endif // rocblas_F8_downcast_clipping}
}
}
}
}
inline
constexpr
operator
float
()
const
inline
constexpr
operator
float
()
const
{
{
if
constexpr
(
T
==
migraphx
_
fp8
::
f8_type
::
fp8
)
if
constexpr
(
T
==
migraphx
::
fp8
::
f8_type
::
fp8
)
{
{
return
migraphx
_f8_
impl
::
cast_from_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
>
(
data
);
return
migraphx
::
fp8
::
impl
::
cast_from_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
>
(
data
);
}
// else
}
// else
return
migraphx
_f8_
impl
::
cast_from_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
>
(
data
);
return
migraphx
::
fp8
::
impl
::
cast_from_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
>
(
data
);
}
}
inline
constexpr
bool
is_zero
()
const
inline
constexpr
bool
is_zero
()
const
...
@@ -149,7 +141,7 @@ struct float8
...
@@ -149,7 +141,7 @@ struct float8
}
}
else
else
{
{
if
(
T
==
migraphx
_
fp8
::
f8_type
::
bf8
)
if
(
T
==
migraphx
::
fp8
::
f8_type
::
bf8
)
{
{
return
(
data
==
0x7D
)
or
(
data
==
0x7E
)
or
(
data
==
0x7F
)
or
(
data
==
0xFD
)
or
return
(
data
==
0x7D
)
or
(
data
==
0x7E
)
or
(
data
==
0x7F
)
or
(
data
==
0xFD
)
or
(
data
==
0xFE
)
or
(
data
==
0xFF
);
(
data
==
0xFE
)
or
(
data
==
0xFF
);
...
@@ -169,7 +161,7 @@ struct float8
...
@@ -169,7 +161,7 @@ struct float8
}
}
else
else
{
{
if
(
T
==
migraphx
_
fp8
::
f8_type
::
bf8
)
if
(
T
==
migraphx
::
fp8
::
f8_type
::
bf8
)
{
{
return
(
data
==
0x7C
)
or
(
data
==
0xFC
);
return
(
data
==
0x7C
)
or
(
data
==
0xFC
);
}
}
...
@@ -236,26 +228,26 @@ struct float8
...
@@ -236,26 +228,26 @@ struct float8
};
};
// Special operator overloading
// Special operator overloading
template
<
migraphx
_
fp8
::
f8_type
T
>
template
<
migraphx
::
fp8
::
f8_type
T
>
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
migraphx
_
fp8
::
float8
<
T
>&
rhs
)
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
migraphx
::
fp8
::
float8
<
T
>&
rhs
)
{
{
return
os
<<
static_cast
<
float
>
(
rhs
);
return
os
<<
static_cast
<
float
>
(
rhs
);
}
}
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U)
\
template <migraphx
_
fp8::f8_type T> \
template <migraphx
::
fp8::f8_type T> \
inline constexpr U operator binary_op(const migraphx
_
fp8::float8<T>& lhs, \
inline constexpr U operator binary_op(const migraphx
::
fp8::float8<T>& lhs, \
const migraphx
_
fp8::float8<T>& rhs) \
const migraphx
::
fp8::float8<T>& rhs) \
{ \
{
\
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs));
\
}
}
// TODO: these should return floats
// TODO: these should return floats
MIGRAPHX_FP8_BINARY_OP
(
*
,
migraphx
_
fp8
::
float8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
*
,
migraphx
::
fp8
::
float8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
-
,
migraphx
_
fp8
::
float8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
-
,
migraphx
::
fp8
::
float8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
/
,
migraphx
_
fp8
::
float8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
/
,
migraphx
::
fp8
::
float8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
+
,
migraphx
_
fp8
::
float8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
+
,
migraphx
::
fp8
::
float8
<
T
>
)
// TODO: Comparison ops shouldn't convert to float, need to check if need to take care of rounding
// TODO: Comparison ops shouldn't convert to float, need to check if need to take care of rounding
// effects.
// effects.
MIGRAPHX_FP8_BINARY_OP
(
==
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
==
,
bool
)
...
@@ -265,18 +257,18 @@ MIGRAPHX_FP8_BINARY_OP(>, bool)
...
@@ -265,18 +257,18 @@ MIGRAPHX_FP8_BINARY_OP(>, bool)
MIGRAPHX_FP8_BINARY_OP
(
<
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
<
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
!=
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
!=
,
bool
)
template
<
migraphx
_
fp8
::
f8_type
T
>
template
<
migraphx
::
fp8
::
f8_type
T
>
inline
migraphx
_
fp8
::
float8
<
T
>
fabs
(
migraphx
_
fp8
::
float8
<
T
>
v
)
inline
migraphx
::
fp8
::
float8
<
T
>
fabs
(
migraphx
::
fp8
::
float8
<
T
>
v
)
{
{
v
.
data
=
v
.
data
&
0x7f
;
v
.
data
=
v
.
data
&
0x7f
;
return
v
;
return
v
;
}
}
// https://onnx.ai/onnx/technical/float8.html
// https://onnx.ai/onnx/technical/float8.html
using
fp8e4m3fn
=
float8
<
migraphx
_
fp8
::
f8_type
::
fp8
,
false
>
;
using
fp8e4m3fn
=
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
,
false
>
;
using
fp8e5m2
=
float8
<
migraphx
_
fp8
::
f8_type
::
bf8
,
false
>
;
using
fp8e5m2
=
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
,
false
>
;
using
fp8e4m3fnuz
=
float8
<
migraphx
_
fp8
::
f8_type
::
fp8
,
true
>
;
using
fp8e4m3fnuz
=
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
,
true
>
;
using
fp8e5m2fnuz
=
float8
<
migraphx
_
fp8
::
f8_type
::
bf8
,
true
>
;
using
fp8e5m2fnuz
=
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
,
true
>
;
template
<
>
template
<
>
class
numeric_limits
<
fp8e4m3fnuz
>
class
numeric_limits
<
fp8e4m3fnuz
>
...
@@ -347,37 +339,39 @@ class numeric_limits<fp8e5m2>
...
@@ -347,37 +339,39 @@ class numeric_limits<fp8e5m2>
// 7C and FC both are infinity
// 7C and FC both are infinity
static
constexpr
fp8e5m2
infinity
()
{
return
fp8e5m2
(
0x7C
,
fp8e5m2
::
from_bits
());
}
static
constexpr
fp8e5m2
infinity
()
{
return
fp8e5m2
(
0x7C
,
fp8e5m2
::
from_bits
());
}
};
};
}
// namespace migraphx_fp8
}
// namespace fp8
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
// =================================================================================================
// =================================================================================================
// define numeric limits for the new data type
// define numeric limits for the new data type
namespace
std
{
namespace
std
{
#define MIGRAPHX_FP8_STD_OVERLOADS(T) \
#define MIGRAPHX_FP8_STD_OVERLOADS(T)
\
inline bool isfinite(T x) { return x.is_inf(); } \
inline bool isfinite(T x) { return x.is_inf(); }
\
inline bool isnan(T x) { return x.is_nan(); } \
inline bool isnan(T x) { return x.is_nan(); }
\
template <> \
template <>
\
class numeric_limits<T> : public migraphx
_
fp8::numeric_limits<T> \
class numeric_limits<T> : public migraphx
::
fp8::numeric_limits<T> \
{ \
{
\
}; \
};
\
template <class U> \
template <class U>
\
struct common_type<T, U> : std::common_type<float, U> \
struct common_type<T, U> : std::common_type<float, U>
\
{ \
{
\
}; \
};
\
template <class U> \
template <class U>
\
struct common_type<U, T> : std::common_type<float, U> \
struct common_type<U, T> : std::common_type<float, U>
\
{ \
{
\
}; \
};
\
template <> \
template <>
\
struct common_type<T, T> \
struct common_type<T, T>
\
{ \
{
\
using type = T; \
using type = T;
\
};
};
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
_
fp8
::
fp8e4m3fn
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
::
fp8
::
fp8e4m3fn
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
_
fp8
::
fp8e5m2
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
::
fp8
::
fp8e5m2
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
_
fp8
::
fp8e4m3fnuz
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
::
fp8
::
fp8e4m3fnuz
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
_
fp8
::
fp8e5m2fnuz
)
MIGRAPHX_FP8_STD_OVERLOADS
(
migraphx
::
fp8
::
fp8e5m2fnuz
)
}
// namespace std
}
// namespace std
// =================================================================================================
// =================================================================================================
...
...
src/include/migraphx/
migraphx_f
8_impl.hpp
→
src/include/migraphx/
float
8_impl.hpp
View file @
ab653aff
...
@@ -20,49 +20,32 @@
...
@@ -20,49 +20,32 @@
*
*
* ************************************************************************ */
* ************************************************************************ */
#ifndef MIGRAPHX_FP8_IMPL_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP
#define MIGRAPHX_FP8_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP
#include <type_traits>
#define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x))
#include <migraphx/config.hpp>
namespace
migraphx_f8_impl
{
#include <migraphx/bit_cast.hpp>
namespace
detail
{
namespace
migraphx
{
template
<
bool
B
,
class
T
,
class
F
>
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
conditional
namespace
fp8
{
{
namespace
impl
{
using
type
=
T
;
};
template
<
class
T
,
class
F
>
struct
conditional
<
false
,
T
,
F
>
{
using
type
=
F
;
};
template
<
typename
To
,
typename
From
>
inline
constexpr
To
bit_cast
(
From
fr
)
noexcept
{
static_assert
(
sizeof
(
To
)
==
sizeof
(
From
));
#if defined(__GNUC__) and !defined(__clang__)
return
MIGRAPHX_CONST_FOLD
(
*
reinterpret_cast
<
To
*>
(
&
fr
));
#else
return
__builtin_bit_cast
(
To
,
fr
);
#endif
}
}
// namespace detail
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
,
bool
clip
>
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
,
bool
clip
>
constexpr
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
,
uint32_t
rng
)
constexpr
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
=
false
,
uint32_t
rng
=
0
)
{
{
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
// half is not supported for now
constexpr
bool
is_half
=
false
;
static_assert
(
wm
+
we
==
7
,
"wm+we==7"
);
static_assert
(
wm
+
we
==
7
,
"wm+we==7"
);
static_assert
(
is_float
or
is_half
,
"Only float can be cast to f8"
);
const
int
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
const
int
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
typename
detail
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
x
;
typename
std
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
x
;
if
constexpr
(
sizeof
(
T
)
==
4
)
if
constexpr
(
sizeof
(
T
)
==
4
)
x
=
detail
::
bit_cast
<
uint32_t
>
(
_x
);
x
=
migraphx
::
bit_cast
<
uint32_t
>
(
_x
);
else
else
x
=
detail
::
bit_cast
<
uint16_t
>
(
_x
);
x
=
migraphx
::
bit_cast
<
uint16_t
>
(
_x
);
uint32_t
head
,
mantissa
;
uint32_t
head
,
mantissa
;
int
exponent
,
bias
;
int
exponent
,
bias
;
...
@@ -271,19 +254,27 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
...
@@ -271,19 +254,27 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
>
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
>
constexpr
T
cast_from_f8
(
uint8_t
x
)
constexpr
T
cast_from_f8
(
uint8_t
x
)
{
{
constexpr
int
weo
=
8
;
// half is not supported for now
constexpr
int
wmo
=
23
;
constexpr
bool
is_half
=
false
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
static_assert
(
is_float
or
is_half
,
"Only float are supported"
);
constexpr
int
weo
=
is_half
?
5
:
8
;
constexpr
int
wmo
=
is_half
?
10
:
(
is_float
?
23
:
7
);
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
uint32_t
ifInf
=
0x7F800000
;
uint32_t
ifNegInf
=
0xFF800000
;
uint32_t
ifNaN
=
0x7F800001
;
uint32_t
ifNeg0
=
0x80000000
;
fInf
=
detail
::
bit_cast
<
float
>
(
ifInf
);
if
constexpr
(
is_float
)
fNegInf
=
detail
::
bit_cast
<
float
>
(
ifNegInf
);
{
fNaN
=
detail
::
bit_cast
<
float
>
(
ifNaN
);
const
uint32_t
ifInf
=
0x7F800000
;
fNeg0
=
detail
::
bit_cast
<
float
>
(
ifNeg0
);
const
uint32_t
ifNegInf
=
0xFF800000
;
const
uint32_t
ifNaN
=
0x7F800001
;
const
uint32_t
ifNeg0
=
0x80000000
;
fInf
=
migraphx
::
bit_cast
<
float
>
(
ifInf
);
fNegInf
=
migraphx
::
bit_cast
<
float
>
(
ifNegInf
);
fNaN
=
migraphx
::
bit_cast
<
float
>
(
ifNaN
);
fNeg0
=
migraphx
::
bit_cast
<
float
>
(
ifNeg0
);
}
if
(
x
==
0
)
if
(
x
==
0
)
return
0
;
return
0
;
...
@@ -305,7 +296,7 @@ constexpr T cast_from_f8(uint8_t x)
...
@@ -305,7 +296,7 @@ constexpr T cast_from_f8(uint8_t x)
else
if
(
wm
==
3
and
(
x
==
0x7F
or
x
==
0xFF
))
else
if
(
wm
==
3
and
(
x
==
0x7F
or
x
==
0xFF
))
return
fNaN
;
return
fNaN
;
}
}
typename
detail
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
retval
;
typename
std
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
retval
;
const
int
exp_low_cutoff
=
(
1
<<
(
weo
-
1
))
-
(
1
<<
(
we
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
const
int
exp_low_cutoff
=
(
1
<<
(
weo
-
1
))
-
(
1
<<
(
we
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
...
@@ -333,8 +324,11 @@ constexpr T cast_from_f8(uint8_t x)
...
@@ -333,8 +324,11 @@ constexpr T cast_from_f8(uint8_t x)
retval
=
(
sign
<<
15
)
|
(
exponent
<<
10
)
|
mantissa
;
retval
=
(
sign
<<
15
)
|
(
exponent
<<
10
)
|
mantissa
;
else
else
retval
=
(
sign
<<
31
)
|
(
exponent
<<
23
)
|
mantissa
;
retval
=
(
sign
<<
31
)
|
(
exponent
<<
23
)
|
mantissa
;
return
detail
::
bit_cast
<
T
>
(
retval
);
return
migraphx
::
bit_cast
<
T
>
(
retval
);
}
}
}
// namespace migraphx_f8_impl
}
// namespace impl
#endif // MIGRAPHX_FP8_IMPL_HPP
}
// namespace fp8
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL
src/include/migraphx/half.hpp
View file @
ab653aff
...
@@ -27,7 +27,7 @@
...
@@ -27,7 +27,7 @@
#include <half/half.hpp>
#include <half/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/
migraphx_
float8.hpp>
#include <migraphx/float8.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -69,13 +69,13 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
...
@@ -69,13 +69,13 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
};
};
template
<
>
template
<
>
struct
common_type
<
migraphx
_
fp8
::
fp8e4m3fnuz
,
migraphx
::
half
>
struct
common_type
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
migraphx
::
half
>
{
{
using
type
=
float
;
using
type
=
float
;
};
};
template
<
>
template
<
>
struct
common_type
<
migraphx
::
half
,
migraphx
_
fp8
::
fp8e4m3fnuz
>
struct
common_type
<
migraphx
::
half
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
{
{
using
type
=
float
;
using
type
=
float
;
};
};
...
...
src/include/migraphx/shape.hpp
View file @
ab653aff
...
@@ -34,7 +34,7 @@
...
@@ -34,7 +34,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/
migraphx_
float8.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
...
@@ -62,7 +62,7 @@ struct MIGRAPHX_EXPORT shape
...
@@ -62,7 +62,7 @@ struct MIGRAPHX_EXPORT shape
m(int64_type, int64_t) \
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx
_
fp8::fp8e4m3fnuz)
m(fp8e4m3fnuz_type, migraphx
::
fp8::fp8e4m3fnuz)
// clang-format on
// clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
...
...
src/include/migraphx/type_traits.hpp
View file @
ab653aff
...
@@ -28,7 +28,7 @@
...
@@ -28,7 +28,7 @@
#include <type_traits>
#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/
migraphx_
float8.hpp>
#include <migraphx/float8.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -49,23 +49,13 @@ MIGRAPHX_DETAIL_DEFINE_TRAIT(is_floating_point);
...
@@ -49,23 +49,13 @@ MIGRAPHX_DETAIL_DEFINE_TRAIT(is_floating_point);
MIGRAPHX_DETAIL_DEFINE_TRAIT
(
is_arithmetic
);
MIGRAPHX_DETAIL_DEFINE_TRAIT
(
is_arithmetic
);
MIGRAPHX_DETAIL_DEFINE_TRAIT
(
is_signed
);
MIGRAPHX_DETAIL_DEFINE_TRAIT
(
is_signed
);
template
<
class
T
,
class
U
>
struct
is_same
:
std
::
is_same
<
T
,
U
>
{
};
template
<
bool
B
,
class
T
,
class
U
>
struct
conditional
:
std
::
conditional
<
B
,
T
,
U
>
{
};
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
migraphx
_
fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
migraphx
::
fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
migraphx
_
fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
migraphx
::
fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
migraphx
_
fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
migraphx
::
fp8
::
fp8e4m3fnuz
)
template
<
class
T
>
template
<
class
T
>
using
accumulator_type
=
using
accumulator_type
=
...
...
src/py/migraphx_py.cpp
View file @
ab653aff
...
@@ -40,7 +40,7 @@
...
@@ -40,7 +40,7 @@
#include <migraphx/json.hpp>
#include <migraphx/json.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/
migraphx_
float8.hpp>
#include <migraphx/float8.hpp>
#ifdef HAVE_GPU
#ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/hip.hpp>
#endif
#endif
...
@@ -145,7 +145,7 @@ struct npy_format_descriptor<half>
...
@@ -145,7 +145,7 @@ struct npy_format_descriptor<half>
};
};
template
<
>
template
<
>
struct
npy_format_descriptor
<
migraphx
_
fp8
::
fp8e4m3fnuz
>
struct
npy_format_descriptor
<
migraphx
::
fp8
::
fp8e4m3fnuz
>
{
{
static
std
::
string
format
()
static
std
::
string
format
()
{
{
...
...
test/CMakeLists.txt
View file @
ab653aff
...
@@ -150,7 +150,7 @@ function(test_headers PREFIX)
...
@@ -150,7 +150,7 @@ function(test_headers PREFIX)
list
(
REMOVE_ITEM HEADERS
list
(
REMOVE_ITEM HEADERS
${
CMAKE_SOURCE_DIR
}
/src/targets/gpu/include/migraphx/gpu/ck.hpp
)
${
CMAKE_SOURCE_DIR
}
/src/targets/gpu/include/migraphx/gpu/ck.hpp
)
endif
()
endif
()
list
(
REMOVE_ITEM HEADERS
${
CMAKE_SOURCE_DIR
}
/src/include/migraphx/
migraphx_f
8_impl.hpp
)
list
(
REMOVE_ITEM HEADERS
${
CMAKE_SOURCE_DIR
}
/src/include/migraphx/
float
8_impl.hpp
)
foreach
(
HEADER
${
HEADERS
}
)
foreach
(
HEADER
${
HEADERS
}
)
file
(
RELATIVE_PATH HEADER_REL
${
CMAKE_SOURCE_DIR
}
${
HEADER
}
)
file
(
RELATIVE_PATH HEADER_REL
${
CMAKE_SOURCE_DIR
}
${
HEADER
}
)
string
(
MAKE_C_IDENTIFIER
${
HEADER_REL
}
TEST_NAME
)
string
(
MAKE_C_IDENTIFIER
${
HEADER_REL
}
TEST_NAME
)
...
...
test/float_equal.cpp
View file @
ab653aff
...
@@ -22,7 +22,7 @@
...
@@ -22,7 +22,7 @@
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <migraphx/float_equal.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/
migraphx_
float8.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include "test.hpp"
#include "test.hpp"
...
@@ -72,12 +72,12 @@ void test_equality()
...
@@ -72,12 +72,12 @@ void test_equality()
TEST_CASE_REGISTER
(
test_equality
<
double
,
float
>
);
TEST_CASE_REGISTER
(
test_equality
<
double
,
float
>
);
TEST_CASE_REGISTER
(
test_equality
<
double
,
int
>
);
TEST_CASE_REGISTER
(
test_equality
<
double
,
int
>
);
TEST_CASE_REGISTER
(
test_equality
<
double
,
migraphx
::
half
>
);
TEST_CASE_REGISTER
(
test_equality
<
double
,
migraphx
::
half
>
);
TEST_CASE_REGISTER
(
test_equality
<
double
,
migraphx
_
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_equality
<
double
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_equality
<
float
,
int
>
);
TEST_CASE_REGISTER
(
test_equality
<
float
,
int
>
);
TEST_CASE_REGISTER
(
test_equality
<
float
,
migraphx
_
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_equality
<
float
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_equality
<
migraphx
::
half
,
int
>
);
TEST_CASE_REGISTER
(
test_equality
<
migraphx
::
half
,
int
>
);
TEST_CASE_REGISTER
(
test_equality
<
migraphx
::
half
,
migraphx
_
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_equality
<
migraphx
::
half
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_equality
<
migraphx
_
fp8
::
fp8e4m3fnuz
,
int
>
);
TEST_CASE_REGISTER
(
test_equality
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
int
>
);
template
<
class
T
,
class
U
>
template
<
class
T
,
class
U
>
void
test_limits
()
void
test_limits
()
...
@@ -115,12 +115,12 @@ void test_limits()
...
@@ -115,12 +115,12 @@ void test_limits()
TEST_CASE_REGISTER
(
test_limits
<
double
,
float
>
);
TEST_CASE_REGISTER
(
test_limits
<
double
,
float
>
);
TEST_CASE_REGISTER
(
test_limits
<
double
,
int
>
);
TEST_CASE_REGISTER
(
test_limits
<
double
,
int
>
);
TEST_CASE_REGISTER
(
test_limits
<
double
,
migraphx
::
half
>
);
TEST_CASE_REGISTER
(
test_limits
<
double
,
migraphx
::
half
>
);
TEST_CASE_REGISTER
(
test_limits
<
double
,
migraphx
_
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_limits
<
double
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_limits
<
float
,
int
>
);
TEST_CASE_REGISTER
(
test_limits
<
float
,
int
>
);
TEST_CASE_REGISTER
(
test_limits
<
float
,
migraphx
_
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_limits
<
float
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_limits
<
int
,
migraphx
::
half
>
);
TEST_CASE_REGISTER
(
test_limits
<
int
,
migraphx
::
half
>
);
TEST_CASE_REGISTER
(
test_limits
<
int
,
migraphx
_
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_limits
<
int
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_limits
<
migraphx
_
fp8
::
fp8e4m3fnuz
,
migraphx
::
half
>
);
TEST_CASE_REGISTER
(
test_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
migraphx
::
half
>
);
#ifndef _WIN32
#ifndef _WIN32
// On Windows, types int and long have the same min and max values.
// On Windows, types int and long have the same min and max values.
...
...
test/fp8e4m3fn.cpp
View file @
ab653aff
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
*/
*/
#include <cmath>
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/
migraphx_
float8.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include "test.hpp"
...
@@ -108,7 +108,7 @@ TEST_CASE(test_fp8_cast_to_float)
...
@@ -108,7 +108,7 @@ TEST_CASE(test_fp8_cast_to_float)
std
::
vector
<
uint8_t
>
bit_vals
(
256
);
std
::
vector
<
uint8_t
>
bit_vals
(
256
);
std
::
iota
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
0
);
std
::
iota
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
0
);
EXPECT
(
bool
{
std
::
all_of
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
[](
uint8_t
bit_val
)
{
EXPECT
(
bool
{
std
::
all_of
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
[](
uint8_t
bit_val
)
{
migraphx
_
fp8
::
fp8e4m3fn
fp8_val
(
bit_val
,
migraphx
_
fp8
::
fp8e4m3fn
::
from_bits
());
migraphx
::
fp8
::
fp8e4m3fn
fp8_val
(
bit_val
,
migraphx
::
fp8
::
fp8e4m3fn
::
from_bits
());
if
(
std
::
isnan
(
float
(
fp8_val
))
and
std
::
isnan
(
fp8e4m3fn_to_fp32_value
(
bit_val
)))
if
(
std
::
isnan
(
float
(
fp8_val
))
and
std
::
isnan
(
fp8e4m3fn_to_fp32_value
(
bit_val
)))
{
{
return
true
;
return
true
;
...
@@ -120,7 +120,7 @@ TEST_CASE(test_fp8_cast_to_float)
...
@@ -120,7 +120,7 @@ TEST_CASE(test_fp8_cast_to_float)
TEST_CASE
(
test_positive_zero
)
TEST_CASE
(
test_positive_zero
)
{
{
float
zero
=
0.0
;
float
zero
=
0.0
;
migraphx
_
fp8
::
fp8e4m3fn
fp8_zero
(
zero
);
migraphx
::
fp8
::
fp8e4m3fn
fp8_zero
(
zero
);
EXPECT
(
fp8_zero
.
is_zero
());
EXPECT
(
fp8_zero
.
is_zero
());
EXPECT
(
migraphx
::
float_equal
(
zero
,
float
(
fp8_zero
)));
EXPECT
(
migraphx
::
float_equal
(
zero
,
float
(
fp8_zero
)));
}
}
...
@@ -128,7 +128,7 @@ TEST_CASE(test_positive_zero)
...
@@ -128,7 +128,7 @@ TEST_CASE(test_positive_zero)
TEST_CASE
(
test_negative_zero
)
TEST_CASE
(
test_negative_zero
)
{
{
float
nzero
=
-
0.0
;
float
nzero
=
-
0.0
;
migraphx
_
fp8
::
fp8e4m3fn
fp8_nzero
(
nzero
);
migraphx
::
fp8
::
fp8e4m3fn
fp8_nzero
(
nzero
);
EXPECT
(
fp8_nzero
.
is_zero
());
EXPECT
(
fp8_nzero
.
is_zero
());
// negative zero is preserved for fp8e4m3fn
// negative zero is preserved for fp8e4m3fn
EXPECT
(
migraphx
::
float_equal
(
nzero
,
float
(
fp8_nzero
)));
EXPECT
(
migraphx
::
float_equal
(
nzero
,
float
(
fp8_nzero
)));
...
@@ -137,15 +137,15 @@ TEST_CASE(test_negative_zero)
...
@@ -137,15 +137,15 @@ TEST_CASE(test_negative_zero)
TEST_CASE
(
test_nan_1
)
TEST_CASE
(
test_nan_1
)
{
{
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
migraphx
_
fp8
::
fp8e4m3fn
fp8_nan
(
fnan
);
migraphx
::
fp8
::
fp8e4m3fn
fp8_nan
(
fnan
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
fp8_nan
));
}
}
TEST_CASE
(
test_nan_2
)
TEST_CASE
(
test_nan_2
)
{
{
auto
fnan
=
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
quiet_NaN
();
auto
fnan
=
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
quiet_NaN
();
migraphx
_
fp8
::
fp8e4m3fn
fp8_nan
(
fnan
.
data
,
migraphx
_
fp8
::
fp8e4m3fn
::
from_bits
());
migraphx
::
fp8
::
fp8e4m3fn
fp8_nan
(
fnan
.
data
,
migraphx
::
fp8
::
fp8e4m3fn
::
from_bits
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
...
@@ -155,8 +155,8 @@ TEST_CASE(test_infinity_1)
...
@@ -155,8 +155,8 @@ TEST_CASE(test_infinity_1)
{
{
float
finf
=
std
::
numeric_limits
<
float
>::
infinity
();
float
finf
=
std
::
numeric_limits
<
float
>::
infinity
();
// no inf in fp8e4m3fn, it gets clipped to max()
// no inf in fp8e4m3fn, it gets clipped to max()
migraphx
_
fp8
::
fp8e4m3fn
fp8_max
(
finf
);
migraphx
::
fp8
::
fp8e4m3fn
fp8_max
(
finf
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
max
());
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
max
());
}
}
TEST_CASE
(
test_infinity_2
)
TEST_CASE
(
test_infinity_2
)
...
@@ -164,43 +164,43 @@ TEST_CASE(test_infinity_2)
...
@@ -164,43 +164,43 @@ TEST_CASE(test_infinity_2)
// neg inf
// neg inf
float
finf
=
-
1.0
*
std
::
numeric_limits
<
float
>::
infinity
();
float
finf
=
-
1.0
*
std
::
numeric_limits
<
float
>::
infinity
();
// no inf in fp8e4m3fn, it gets clipped to lowest
// no inf in fp8e4m3fn, it gets clipped to lowest
migraphx
_
fp8
::
fp8e4m3fn
fp8_lowest
(
finf
);
migraphx
::
fp8
::
fp8e4m3fn
fp8_lowest
(
finf
);
EXPECT
(
bool
{
fp8_lowest
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
lowest
()});
EXPECT
(
bool
{
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
lowest
()});
}
}
TEST_CASE
(
test_numeric_max_1
)
TEST_CASE
(
test_numeric_max_1
)
{
{
float
fmax
=
std
::
numeric_limits
<
float
>::
max
();
float
fmax
=
std
::
numeric_limits
<
float
>::
max
();
migraphx
_
fp8
::
fp8e4m3fn
fp8_max
(
fmax
);
migraphx
::
fp8
::
fp8e4m3fn
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
max
());
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
max
());
}
}
TEST_CASE
(
test_numeric_max_2
)
TEST_CASE
(
test_numeric_max_2
)
{
{
// gets clipped to max
// gets clipped to max
float
fmax
=
2
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
max
();
float
fmax
=
2
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
max
();
migraphx
_
fp8
::
fp8e4m3fn
fp8_max
(
fmax
);
migraphx
::
fp8
::
fp8e4m3fn
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
max
());
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
max
());
}
}
TEST_CASE
(
test_numeric_lowest_1
)
TEST_CASE
(
test_numeric_lowest_1
)
{
{
float
flowest
=
std
::
numeric_limits
<
float
>::
lowest
();
float
flowest
=
std
::
numeric_limits
<
float
>::
lowest
();
migraphx
_
fp8
::
fp8e4m3fn
fp8_lowest
(
flowest
);
migraphx
::
fp8
::
fp8e4m3fn
fp8_lowest
(
flowest
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
lowest
());
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
lowest
());
}
}
TEST_CASE
(
test_numeric_lowest_2
)
TEST_CASE
(
test_numeric_lowest_2
)
{
{
// gets clipped to lowest
// gets clipped to lowest
float
fmin
=
2.0
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
lowest
();
float
fmin
=
2.0
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
lowest
();
migraphx
_
fp8
::
fp8e4m3fn
fp8_lowest
(
fmin
);
migraphx
::
fp8
::
fp8e4m3fn
fp8_lowest
(
fmin
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
lowest
());
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
lowest
());
}
}
TEST_CASE
(
test_max_eq_lowest
)
TEST_CASE
(
test_max_eq_lowest
)
{
{
EXPECT
(
migraphx
::
float_equal
(
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
lowest
(),
EXPECT
(
migraphx
::
float_equal
(
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
lowest
(),
-
1
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fn
>::
max
()));
-
1
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
max
()));
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/fp8e4m3fnuz.cpp
View file @
ab653aff
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
*/
*/
#include <cmath>
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/
migraphx_
float8.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include "test.hpp"
...
@@ -129,7 +129,7 @@ TEST_CASE(test_fp8_cast_to_float)
...
@@ -129,7 +129,7 @@ TEST_CASE(test_fp8_cast_to_float)
std
::
vector
<
uint8_t
>
bit_vals
(
256
);
std
::
vector
<
uint8_t
>
bit_vals
(
256
);
std
::
iota
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
0
);
std
::
iota
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
0
);
EXPECT
(
bool
{
std
::
all_of
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
[](
uint8_t
bit_val
)
{
EXPECT
(
bool
{
std
::
all_of
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
[](
uint8_t
bit_val
)
{
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_val
(
bit_val
,
migraphx
_
fp8
::
fp8e4m3fnuz
::
from_bits
());
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_val
(
bit_val
,
migraphx
::
fp8
::
fp8e4m3fnuz
::
from_bits
());
if
(
std
::
isnan
(
float
(
fp8_val
))
and
std
::
isnan
(
fp8e4m3fnuz_to_fp32_value
(
bit_val
)))
if
(
std
::
isnan
(
float
(
fp8_val
))
and
std
::
isnan
(
fp8e4m3fnuz_to_fp32_value
(
bit_val
)))
{
{
return
true
;
return
true
;
...
@@ -141,7 +141,7 @@ TEST_CASE(test_fp8_cast_to_float)
...
@@ -141,7 +141,7 @@ TEST_CASE(test_fp8_cast_to_float)
TEST_CASE
(
test_positive_zero
)
TEST_CASE
(
test_positive_zero
)
{
{
float
zero
=
0.0
;
float
zero
=
0.0
;
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_zero
(
zero
);
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_zero
(
zero
);
EXPECT
(
fp8_zero
.
is_zero
());
EXPECT
(
fp8_zero
.
is_zero
());
EXPECT
(
migraphx
::
float_equal
(
zero
,
float
(
fp8_zero
)));
EXPECT
(
migraphx
::
float_equal
(
zero
,
float
(
fp8_zero
)));
}
}
...
@@ -150,7 +150,7 @@ TEST_CASE(test_negative_zero)
...
@@ -150,7 +150,7 @@ TEST_CASE(test_negative_zero)
{
{
float
nzero
=
-
0.0
;
float
nzero
=
-
0.0
;
float
pzero
=
0.0
;
float
pzero
=
0.0
;
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_nzero
(
nzero
);
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_nzero
(
nzero
);
EXPECT
(
fp8_nzero
.
is_zero
());
EXPECT
(
fp8_nzero
.
is_zero
());
// negative zero gets converted to positive zero
// negative zero gets converted to positive zero
EXPECT
(
migraphx
::
float_equal
(
pzero
,
float
(
fp8_nzero
)));
EXPECT
(
migraphx
::
float_equal
(
pzero
,
float
(
fp8_nzero
)));
...
@@ -159,15 +159,15 @@ TEST_CASE(test_negative_zero)
...
@@ -159,15 +159,15 @@ TEST_CASE(test_negative_zero)
TEST_CASE
(
test_nan_1
)
TEST_CASE
(
test_nan_1
)
{
{
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_nan
(
fnan
);
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_nan
(
fnan
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
fp8_nan
));
}
}
TEST_CASE
(
test_nan_2
)
TEST_CASE
(
test_nan_2
)
{
{
auto
fnan
=
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fnuz
>::
quiet_NaN
();
auto
fnan
=
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
quiet_NaN
();
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_nan
(
fnan
.
data
,
migraphx
_
fp8
::
fp8e4m3fnuz
::
from_bits
());
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_nan
(
fnan
.
data
,
migraphx
::
fp8
::
fp8e4m3fnuz
::
from_bits
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
...
@@ -177,7 +177,7 @@ TEST_CASE(test_infinity_1)
...
@@ -177,7 +177,7 @@ TEST_CASE(test_infinity_1)
{
{
float
finf
=
std
::
numeric_limits
<
float
>::
infinity
();
float
finf
=
std
::
numeric_limits
<
float
>::
infinity
();
// no inf in fp8e4m3fnuz it gets clipped to Nans
// no inf in fp8e4m3fnuz it gets clipped to Nans
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_nan
(
finf
);
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_nan
(
finf
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
}
}
...
@@ -187,7 +187,7 @@ TEST_CASE(test_infinity_2)
...
@@ -187,7 +187,7 @@ TEST_CASE(test_infinity_2)
// neg inf
// neg inf
float
finf
=
-
1.0
*
std
::
numeric_limits
<
float
>::
infinity
();
float
finf
=
-
1.0
*
std
::
numeric_limits
<
float
>::
infinity
();
// no inf in fp8e4m3fnuz it gets clipped to NaNs
// no inf in fp8e4m3fnuz it gets clipped to NaNs
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_nan
(
finf
);
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_nan
(
finf
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
}
}
...
@@ -195,36 +195,36 @@ TEST_CASE(test_infinity_2)
...
@@ -195,36 +195,36 @@ TEST_CASE(test_infinity_2)
TEST_CASE
(
test_numeric_max_1
)
TEST_CASE
(
test_numeric_max_1
)
{
{
float
fmax
=
std
::
numeric_limits
<
float
>::
max
();
float
fmax
=
std
::
numeric_limits
<
float
>::
max
();
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_max
(
fmax
);
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fnuz
>::
max
());
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
max
());
}
}
TEST_CASE
(
test_numeric_max_2
)
TEST_CASE
(
test_numeric_max_2
)
{
{
// gets clipped to max
// gets clipped to max
float
fmax
=
2
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fnuz
>::
max
();
float
fmax
=
2
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
max
();
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_max
(
fmax
);
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fnuz
>::
max
());
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
max
());
}
}
TEST_CASE
(
test_numeric_lowest_1
)
TEST_CASE
(
test_numeric_lowest_1
)
{
{
float
flowest
=
std
::
numeric_limits
<
float
>::
lowest
();
float
flowest
=
std
::
numeric_limits
<
float
>::
lowest
();
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_lowest
(
flowest
);
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_lowest
(
flowest
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fnuz
>::
lowest
());
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
lowest
());
}
}
TEST_CASE
(
test_numeric_lowest_2
)
TEST_CASE
(
test_numeric_lowest_2
)
{
{
// gets clipped to lowest
// gets clipped to lowest
float
fmin
=
2.0
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fnuz
>::
lowest
();
float
fmin
=
2.0
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
lowest
();
migraphx
_
fp8
::
fp8e4m3fnuz
fp8_lowest
(
fmin
);
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_lowest
(
fmin
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fnuz
>::
lowest
());
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
lowest
());
}
}
TEST_CASE
(
test_max_eq_lowest
)
TEST_CASE
(
test_max_eq_lowest
)
{
{
EXPECT
(
migraphx
::
float_equal
(
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fnuz
>::
lowest
(),
EXPECT
(
migraphx
::
float_equal
(
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
lowest
(),
-
1
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e4m3fnuz
>::
max
()));
-
1
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
max
()));
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/fp8e5m2.cpp
View file @
ab653aff
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
*/
*/
#include <cmath>
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/
migraphx_
float8.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include "test.hpp"
...
@@ -301,7 +301,7 @@ TEST_CASE(test_fp8_cast_to_float)
...
@@ -301,7 +301,7 @@ TEST_CASE(test_fp8_cast_to_float)
std
::
vector
<
uint8_t
>
bit_vals
(
256
);
std
::
vector
<
uint8_t
>
bit_vals
(
256
);
std
::
iota
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
0
);
std
::
iota
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
0
);
EXPECT
(
bool
{
std
::
all_of
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
[](
uint8_t
bit_val
)
{
EXPECT
(
bool
{
std
::
all_of
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
[](
uint8_t
bit_val
)
{
migraphx
_
fp8
::
fp8e5m2
fp8_val
(
bit_val
,
migraphx
_
fp8
::
fp8e5m2
::
from_bits
());
migraphx
::
fp8
::
fp8e5m2
fp8_val
(
bit_val
,
migraphx
::
fp8
::
fp8e5m2
::
from_bits
());
if
(
std
::
isnan
(
float
(
fp8_val
))
and
std
::
isnan
(
fp8e5m2_to_fp32_value
(
bit_val
)))
if
(
std
::
isnan
(
float
(
fp8_val
))
and
std
::
isnan
(
fp8e5m2_to_fp32_value
(
bit_val
)))
{
{
return
true
;
return
true
;
...
@@ -317,7 +317,7 @@ TEST_CASE(test_fp8_cast_to_float)
...
@@ -317,7 +317,7 @@ TEST_CASE(test_fp8_cast_to_float)
TEST_CASE
(
test_positive_zero
)
TEST_CASE
(
test_positive_zero
)
{
{
float
zero
=
0.0
;
float
zero
=
0.0
;
migraphx
_
fp8
::
fp8e5m2
fp8_zero
(
zero
);
migraphx
::
fp8
::
fp8e5m2
fp8_zero
(
zero
);
EXPECT
(
fp8_zero
.
is_zero
());
EXPECT
(
fp8_zero
.
is_zero
());
EXPECT
(
migraphx
::
float_equal
(
zero
,
float
(
fp8_zero
)));
EXPECT
(
migraphx
::
float_equal
(
zero
,
float
(
fp8_zero
)));
}
}
...
@@ -325,7 +325,7 @@ TEST_CASE(test_positive_zero)
...
@@ -325,7 +325,7 @@ TEST_CASE(test_positive_zero)
TEST_CASE
(
test_negative_zero
)
TEST_CASE
(
test_negative_zero
)
{
{
float
nzero
=
-
0.0
;
float
nzero
=
-
0.0
;
migraphx
_
fp8
::
fp8e5m2
fp8_nzero
(
nzero
);
migraphx
::
fp8
::
fp8e5m2
fp8_nzero
(
nzero
);
EXPECT
(
fp8_nzero
.
is_zero
());
EXPECT
(
fp8_nzero
.
is_zero
());
// negative zero is preserved for fp8e5m2
// negative zero is preserved for fp8e5m2
EXPECT
(
migraphx
::
float_equal
(
nzero
,
float
(
fp8_nzero
)));
EXPECT
(
migraphx
::
float_equal
(
nzero
,
float
(
fp8_nzero
)));
...
@@ -334,15 +334,15 @@ TEST_CASE(test_negative_zero)
...
@@ -334,15 +334,15 @@ TEST_CASE(test_negative_zero)
TEST_CASE
(
test_nan_1
)
TEST_CASE
(
test_nan_1
)
{
{
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
migraphx
_
fp8
::
fp8e5m2
fp8_nan
(
fnan
);
migraphx
::
fp8
::
fp8e5m2
fp8_nan
(
fnan
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
fp8_nan
));
}
}
TEST_CASE
(
test_nan_2
)
TEST_CASE
(
test_nan_2
)
{
{
auto
fnan
=
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
quiet_NaN
();
auto
fnan
=
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
quiet_NaN
();
migraphx
_
fp8
::
fp8e5m2
fp8_nan
(
fnan
.
data
,
migraphx
_
fp8
::
fp8e5m2
::
from_bits
());
migraphx
::
fp8
::
fp8e5m2
fp8_nan
(
fnan
.
data
,
migraphx
::
fp8
::
fp8e5m2
::
from_bits
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
...
@@ -352,8 +352,8 @@ TEST_CASE(test_infinity_1)
...
@@ -352,8 +352,8 @@ TEST_CASE(test_infinity_1)
{
{
// float infinity should get clipped to max
// float infinity should get clipped to max
float
finf
=
std
::
numeric_limits
<
float
>::
infinity
();
float
finf
=
std
::
numeric_limits
<
float
>::
infinity
();
migraphx
_
fp8
::
fp8e5m2
fp8_max
(
finf
);
migraphx
::
fp8
::
fp8e5m2
fp8_max
(
finf
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
max
());
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
max
());
}
}
TEST_CASE
(
test_infinity_2
)
TEST_CASE
(
test_infinity_2
)
...
@@ -361,43 +361,43 @@ TEST_CASE(test_infinity_2)
...
@@ -361,43 +361,43 @@ TEST_CASE(test_infinity_2)
// neg inf
// neg inf
float
finf
=
-
1.0
*
std
::
numeric_limits
<
float
>::
infinity
();
float
finf
=
-
1.0
*
std
::
numeric_limits
<
float
>::
infinity
();
// no inf in fp8e5m2, it gets clipped to lowest
// no inf in fp8e5m2, it gets clipped to lowest
migraphx
_
fp8
::
fp8e5m2
fp8_lowest
(
finf
);
migraphx
::
fp8
::
fp8e5m2
fp8_lowest
(
finf
);
EXPECT
(
bool
{
fp8_lowest
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
lowest
()});
EXPECT
(
bool
{
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
lowest
()});
}
}
TEST_CASE
(
test_numeric_max_1
)
TEST_CASE
(
test_numeric_max_1
)
{
{
float
fmax
=
std
::
numeric_limits
<
float
>::
max
();
float
fmax
=
std
::
numeric_limits
<
float
>::
max
();
migraphx
_
fp8
::
fp8e5m2
fp8_max
(
fmax
);
migraphx
::
fp8
::
fp8e5m2
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
max
());
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
max
());
}
}
TEST_CASE
(
test_numeric_max_2
)
TEST_CASE
(
test_numeric_max_2
)
{
{
// gets clipped to max
// gets clipped to max
float
fmax
=
2
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
max
();
float
fmax
=
2
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
max
();
migraphx
_
fp8
::
fp8e5m2
fp8_max
(
fmax
);
migraphx
::
fp8
::
fp8e5m2
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
max
());
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
max
());
}
}
TEST_CASE
(
test_numeric_lowest_1
)
TEST_CASE
(
test_numeric_lowest_1
)
{
{
float
flowest
=
std
::
numeric_limits
<
float
>::
lowest
();
float
flowest
=
std
::
numeric_limits
<
float
>::
lowest
();
migraphx
_
fp8
::
fp8e5m2
fp8_lowest
(
flowest
);
migraphx
::
fp8
::
fp8e5m2
fp8_lowest
(
flowest
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
lowest
());
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
lowest
());
}
}
TEST_CASE
(
test_numeric_lowest_2
)
TEST_CASE
(
test_numeric_lowest_2
)
{
{
// gets clipped to lowest
// gets clipped to lowest
float
fmin
=
2.0
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
lowest
();
float
fmin
=
2.0
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
lowest
();
migraphx
_
fp8
::
fp8e5m2
fp8_lowest
(
fmin
);
migraphx
::
fp8
::
fp8e5m2
fp8_lowest
(
fmin
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
lowest
());
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
lowest
());
}
}
TEST_CASE
(
test_max_eq_lowest
)
TEST_CASE
(
test_max_eq_lowest
)
{
{
EXPECT
(
migraphx
::
float_equal
(
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
lowest
(),
EXPECT
(
migraphx
::
float_equal
(
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
lowest
(),
-
1
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2
>::
max
()));
-
1
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2
>::
max
()));
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/fp8e5m2fnuz.cpp
View file @
ab653aff
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
*/
*/
#include <cmath>
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/
migraphx_
float8.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include "test.hpp"
...
@@ -299,7 +299,7 @@ TEST_CASE(test_fp8_cast_to_float)
...
@@ -299,7 +299,7 @@ TEST_CASE(test_fp8_cast_to_float)
std
::
vector
<
uint8_t
>
bit_vals
(
256
);
std
::
vector
<
uint8_t
>
bit_vals
(
256
);
std
::
iota
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
0
);
std
::
iota
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
0
);
EXPECT
(
bool
{
std
::
all_of
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
[](
uint8_t
bit_val
)
{
EXPECT
(
bool
{
std
::
all_of
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
[](
uint8_t
bit_val
)
{
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_val
(
bit_val
,
migraphx
_
fp8
::
fp8e5m2fnuz
::
from_bits
());
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_val
(
bit_val
,
migraphx
::
fp8
::
fp8e5m2fnuz
::
from_bits
());
if
(
std
::
isnan
(
float
(
fp8_val
))
and
std
::
isnan
(
fp8e5m2fnuz_to_fp32_value
(
bit_val
)))
if
(
std
::
isnan
(
float
(
fp8_val
))
and
std
::
isnan
(
fp8e5m2fnuz_to_fp32_value
(
bit_val
)))
{
{
return
true
;
return
true
;
...
@@ -311,7 +311,7 @@ TEST_CASE(test_fp8_cast_to_float)
...
@@ -311,7 +311,7 @@ TEST_CASE(test_fp8_cast_to_float)
TEST_CASE
(
test_positive_zero
)
TEST_CASE
(
test_positive_zero
)
{
{
float
zero
=
0.0
;
float
zero
=
0.0
;
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_zero
(
zero
);
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_zero
(
zero
);
EXPECT
(
fp8_zero
.
is_zero
());
EXPECT
(
fp8_zero
.
is_zero
());
EXPECT
(
migraphx
::
float_equal
(
zero
,
float
(
fp8_zero
)));
EXPECT
(
migraphx
::
float_equal
(
zero
,
float
(
fp8_zero
)));
}
}
...
@@ -320,7 +320,7 @@ TEST_CASE(test_negative_zero)
...
@@ -320,7 +320,7 @@ TEST_CASE(test_negative_zero)
{
{
float
nzero
=
-
0.0
;
float
nzero
=
-
0.0
;
float
pzero
=
0.0
;
float
pzero
=
0.0
;
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_nzero
(
nzero
);
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_nzero
(
nzero
);
EXPECT
(
fp8_nzero
.
is_zero
());
EXPECT
(
fp8_nzero
.
is_zero
());
// negative zero gets converted to positive zero
// negative zero gets converted to positive zero
EXPECT
(
migraphx
::
float_equal
(
pzero
,
float
(
fp8_nzero
)));
EXPECT
(
migraphx
::
float_equal
(
pzero
,
float
(
fp8_nzero
)));
...
@@ -329,15 +329,15 @@ TEST_CASE(test_negative_zero)
...
@@ -329,15 +329,15 @@ TEST_CASE(test_negative_zero)
TEST_CASE
(
test_nan_1
)
TEST_CASE
(
test_nan_1
)
{
{
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_nan
(
fnan
);
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_nan
(
fnan
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
fp8_nan
));
}
}
TEST_CASE
(
test_nan_2
)
TEST_CASE
(
test_nan_2
)
{
{
auto
fnan
=
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2fnuz
>::
quiet_NaN
();
auto
fnan
=
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2fnuz
>::
quiet_NaN
();
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_nan
(
fnan
.
data
,
migraphx
_
fp8
::
fp8e5m2fnuz
::
from_bits
());
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_nan
(
fnan
.
data
,
migraphx
::
fp8
::
fp8e5m2fnuz
::
from_bits
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
...
@@ -347,7 +347,7 @@ TEST_CASE(test_infinity_1)
...
@@ -347,7 +347,7 @@ TEST_CASE(test_infinity_1)
{
{
float
finf
=
std
::
numeric_limits
<
float
>::
infinity
();
float
finf
=
std
::
numeric_limits
<
float
>::
infinity
();
// no inf in fp8e5m2fnuz it gets clipped to Nans
// no inf in fp8e5m2fnuz it gets clipped to Nans
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_nan
(
finf
);
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_nan
(
finf
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
}
}
...
@@ -357,7 +357,7 @@ TEST_CASE(test_infinity_2)
...
@@ -357,7 +357,7 @@ TEST_CASE(test_infinity_2)
// neg inf
// neg inf
float
finf
=
-
1.0
*
std
::
numeric_limits
<
float
>::
infinity
();
float
finf
=
-
1.0
*
std
::
numeric_limits
<
float
>::
infinity
();
// no inf in fp8e5m2fnuz it gets clipped to NaNs
// no inf in fp8e5m2fnuz it gets clipped to NaNs
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_nan
(
finf
);
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_nan
(
finf
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
}
}
...
@@ -365,36 +365,36 @@ TEST_CASE(test_infinity_2)
...
@@ -365,36 +365,36 @@ TEST_CASE(test_infinity_2)
TEST_CASE
(
test_numeric_max_1
)
TEST_CASE
(
test_numeric_max_1
)
{
{
float
fmax
=
std
::
numeric_limits
<
float
>::
max
();
float
fmax
=
std
::
numeric_limits
<
float
>::
max
();
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_max
(
fmax
);
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2fnuz
>::
max
());
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2fnuz
>::
max
());
}
}
TEST_CASE
(
test_numeric_max_2
)
TEST_CASE
(
test_numeric_max_2
)
{
{
// gets clipped to max
// gets clipped to max
float
fmax
=
2
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2fnuz
>::
max
();
float
fmax
=
2
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2fnuz
>::
max
();
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_max
(
fmax
);
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2fnuz
>::
max
());
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2fnuz
>::
max
());
}
}
TEST_CASE
(
test_numeric_lowest_1
)
TEST_CASE
(
test_numeric_lowest_1
)
{
{
float
flowest
=
std
::
numeric_limits
<
float
>::
lowest
();
float
flowest
=
std
::
numeric_limits
<
float
>::
lowest
();
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_lowest
(
flowest
);
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_lowest
(
flowest
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2fnuz
>::
lowest
());
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2fnuz
>::
lowest
());
}
}
TEST_CASE
(
test_numeric_lowest_2
)
TEST_CASE
(
test_numeric_lowest_2
)
{
{
// gets clipped to lowest
// gets clipped to lowest
float
fmin
=
2.0
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2fnuz
>::
lowest
();
float
fmin
=
2.0
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2fnuz
>::
lowest
();
migraphx
_
fp8
::
fp8e5m2fnuz
fp8_lowest
(
fmin
);
migraphx
::
fp8
::
fp8e5m2fnuz
fp8_lowest
(
fmin
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2fnuz
>::
lowest
());
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2fnuz
>::
lowest
());
}
}
TEST_CASE
(
test_max_eq_lowest
)
TEST_CASE
(
test_max_eq_lowest
)
{
{
EXPECT
(
migraphx
::
float_equal
(
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2fnuz
>::
lowest
(),
EXPECT
(
migraphx
::
float_equal
(
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2fnuz
>::
lowest
(),
-
1
*
std
::
numeric_limits
<
migraphx
_
fp8
::
fp8e5m2fnuz
>::
max
()));
-
1
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e5m2fnuz
>::
max
()));
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
tools/api/migraphx.h
View file @
ab653aff
...
@@ -45,7 +45,7 @@
...
@@ -45,7 +45,7 @@
m(int64_type, int64_t) \
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx
_
fp8::fp8e4m3fnuz)
m(fp8e4m3fnuz_type, migraphx
::
fp8::fp8e4m3fnuz)
// clang-format on
// clang-format on
#ifdef __cplusplus
#ifdef __cplusplus
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment