Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
490ad9ba
Commit
490ad9ba
authored
Nov 06, 2023
by
Umang Yadav
Browse files
Keep return type of math ops in float and do not add math overloads
parent
0003e8a6
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
99 additions
and
74 deletions
+99
-74
src/include/migraphx/fp8e4m3fnuz.hpp
src/include/migraphx/fp8e4m3fnuz.hpp
+99
-74
No files found.
src/include/migraphx/fp8e4m3fnuz.hpp
View file @
490ad9ba
...
...
@@ -55,11 +55,6 @@
#include <migraphx/config.hpp>
#include <string>
#include <utility>
#if !defined(__HIP_NO_F8_CONVERSIONS__)
#include <migraphx/requires.hpp>
#else
#include <migraphx/kernels/type_traits.hpp>
#endif
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIP_PLATFORM_HCC__)
// MIGraphX by default does not have device code in the regular compilation paths,
...
...
@@ -78,6 +73,7 @@
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wreserved-identifier"
#pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wabsolute-value"
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
detail
{
...
...
@@ -282,83 +278,81 @@ inline MIGRAPHX_HIP_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f)
return
result
;
}
struct
expr
{
/// Conversion constructor.
/// \param f single-precision value to convert
explicit
constexpr
expr
(
float
f
)
noexcept
:
value_
(
f
)
{}
/// Conversion to single-precision.
/// \return single precision value representing expression value
constexpr
operator
float
()
const
noexcept
{
return
value_
;
}
private:
/// Internal expression value stored in single-precision.
float
value_
;
};
}
// namespace detail
/*
overloads using migraphx::fp8e4m3fnuz may not be necessary since they can be implicitly casted to
float that is how half.hpp is implementing it.
this operators can't be friend since it leads to conflicting candidates with inbuilt operators (due
to implict cast to other types probably)
*/
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_
BI
NARY_OP(
op, bi
nary_op) \
constexpr migraphx::fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator
op(
\
#define MIGRAPHX_FP8_
U
NARY_OP(
u
nary_op)
\
constexpr migraphx::fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator
unary_op(
\
const migraphx::fp8e4m3fnuz& rhs) \
{ \
float y = float(
x);
\
y op float(rhs); \
x
= detail::fp8e4m3fnuz_from_fp32_value(y); \
float y = float(
data_);
\
y
unary_
op float(rhs);
\
data_
= detail::fp8e4m3fnuz_from_fp32_value(y);
\
return *this; \
} \
template <class U, MIGRAPHX_REQUIRES(migraphx::is_convertible<U, float>{})> \
constexpr migraphx::fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator op(const U& rhs) \
constexpr migraphx::fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const float& rhs) \
{ \
float y = float(
x);
\
y
op float(rhs);
\
x
= detail::fp8e4m3fnuz_from_fp32_value(y); \
float y = float(
data_);
\
y
unary_op rhs;
\
data_
= detail::fp8e4m3fnuz_from_fp32_value(y);
\
return *this; \
} \
friend constexpr float MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \
const migraphx::fp8e4m3fnuz& lhs, const migraphx::fp8e4m3fnuz& rhs) \
{ \
return (float(lhs) binary_op float(rhs)); \
} \
template <class U, MIGRAPHX_REQUIRES(migraphx::is_convertible<U, float>{})> \
friend constexpr float MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \
const migraphx::fp8e4m3fnuz& lhs, const U& rhs) \
{ \
return (float(lhs) binary_op float(rhs)); \
} \
template <class U, MIGRAPHX_REQUIRES(migraphx::is_convertible<U, float>{})> \
friend constexpr float MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \
const U& lhs, const migraphx::fp8e4m3fnuz& rhs) \
{ \
return (float(lhs) binary_op float(rhs)); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_
COMP_OP(comp_op)
\
friend constexpr
bool
MIGRAPHX_HIP_HOST_DEVICE operator
comp
_op(
\
#define MIGRAPHX_FP8_
BINARY_OP(binary_op, T)
\
friend constexpr
T
MIGRAPHX_HIP_HOST_DEVICE operator
binary
_op( \
const migraphx::fp8e4m3fnuz& lhs, const migraphx::fp8e4m3fnuz& rhs) \
{ \
return ((float)(lhs)comp_op(float)(rhs)); \
} \
template <class U, MIGRAPHX_REQUIRES(migraphx::is_convertible<U, float>{})> \
friend constexpr bool MIGRAPHX_HIP_HOST_DEVICE operator comp_op( \
const migraphx::fp8e4m3fnuz& lhs, const U& rhs) \
{ \
return (float(lhs) comp_op float(rhs)); \
} \
template <class U, MIGRAPHX_REQUIRES(migraphx::is_convertible<U, float>{})> \
friend constexpr bool MIGRAPHX_HIP_HOST_DEVICE operator comp_op( \
const U& lhs, const migraphx::fp8e4m3fnuz& rhs) \
return T(float(lhs) binary_op float(rhs)); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_MATH(name, fname) \
migraphx::fp8e4m3fnuz MIGRAPHX_HIP_HOST_DEVICE name(migraphx::fp8e4m3fnuz x) \
{ \
return
(float(lhs) comp_op float(rhs));
\
return
migraphx::fp8e4m3fnuz(fname(float(x)));
\
}
}
// namespace MIGRAPHX_INLINE_NS
struct
alignas
(
1
)
fp8e4m3fnuz
{
uint8_t
x
;
uint8_t
data_
;
struct
from_bits_t
{
};
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
from_bits_t
from_bits
()
{
return
from_bits_t
();
}
MIGRAPHX_HIP_HOST_DEVICE
fp8e4m3fnuz
()
:
x
(
0
)
{}
MIGRAPHX_HIP_HOST_DEVICE
fp8e4m3fnuz
()
:
data_
(
0
)
{}
MIGRAPHX_HIP_HOST_DEVICE
constexpr
fp8e4m3fnuz
(
uint8_t
bits
,
from_bits_t
)
:
x
(
bits
)
{}
MIGRAPHX_HIP_HOST_DEVICE
constexpr
fp8e4m3fnuz
(
uint8_t
bits
,
from_bits_t
)
:
data_
(
bits
)
{}
MIGRAPHX_HIP_HOST_DEVICE
fp8e4m3fnuz
(
const
fp8e4m3fnuz
&
y
)
=
default
;
inline
explicit
MIGRAPHX_HIP_HOST_DEVICE
fp8e4m3fnuz
(
float
value
)
:
x
(
detail
::
fp8e4m3fnuz_from_fp32_value
(
value
))
inline
explicit
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e4m3fnuz
(
float
value
)
:
data_
(
detail
::
fp8e4m3fnuz_from_fp32_value
(
value
))
{
}
...
...
@@ -367,37 +361,68 @@ struct alignas(1) fp8e4m3fnuz
// any type to any other type and that results in conflicts in candidate overload resolutions.
fp8e4m3fnuz
&
MIGRAPHX_HIP_HOST_DEVICE
operator
=
(
float
rhs
)
{
x
=
detail
::
fp8e4m3fnuz_from_fp32_value
(
rhs
);
data_
=
detail
::
fp8e4m3fnuz_from_fp32_value
(
rhs
);
return
*
this
;
}
#endif
inline
constexpr
MIGRAPHX_HIP_HOST_DEVICE
operator
float
()
const
{
return
detail
::
fp8e4m3fnuz_to_fp32_value
(
x
);
return
detail
::
fp8e4m3fnuz_to_fp32_value
(
data_
);
}
fp8e4m3fnuz
&
MIGRAPHX_HIP_HOST_DEVICE
operator
=
(
const
fp8e4m3fnuz
&
rhs
)
=
default
;
fp8e4m3fnuz
&
MIGRAPHX_HIP_HOST_DEVICE
operator
=
(
fp8e4m3fnuz
&&
rhs
)
=
default
;
inline
bool
MIGRAPHX_HIP_HOST_DEVICE
isnan
()
const
{
return
x
==
0b10000000
;
}
MIGRAPHX_FP8_BINARY_OP
(
+=
,
+
)
MIGRAPHX_FP8_BINARY_OP
(
-=
,
-
)
MIGRAPHX_FP8_BINARY_OP
(
*=
,
*
)
MIGRAPHX_FP8_BINARY_OP
(
/=
,
/
)
inline
bool
MIGRAPHX_HIP_HOST_DEVICE
isnan
()
const
{
return
data_
==
0b10000000
;
}
MIGRAPHX_FP8_COMP_OP
(
==
)
MIGRAPHX_FP8_COMP_OP
(
!=
)
MIGRAPHX_FP8_COMP_OP
(
>=
)
MIGRAPHX_FP8_COMP_OP
(
<=
)
MIGRAPHX_FP8_COMP_OP
(
>
)
MIGRAPHX_FP8_COMP_OP
(
<
)
MIGRAPHX_FP8_UNARY_OP
(
+=
)
MIGRAPHX_FP8_UNARY_OP
(
-=
)
MIGRAPHX_FP8_UNARY_OP
(
*=
)
MIGRAPHX_FP8_UNARY_OP
(
/=
)
friend
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
fp8e4m3fnuz
&
value
)
{
out
<<
(
float
)(
value
);
return
out
;
}
// what should be the return type ?
MIGRAPHX_FP8_BINARY_OP
(
+
,
migraphx
::
fp8e4m3fnuz
)
MIGRAPHX_FP8_BINARY_OP
(
-
,
migraphx
::
fp8e4m3fnuz
)
MIGRAPHX_FP8_BINARY_OP
(
*
,
migraphx
::
fp8e4m3fnuz
)
MIGRAPHX_FP8_BINARY_OP
(
/
,
migraphx
::
fp8e4m3fnuz
)
MIGRAPHX_FP8_BINARY_OP
(
==
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
!=
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
>=
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
<=
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
>
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
<
,
bool
)
// implicit conversion should take care of these for the HOST side, half implementation doesn't
// have 'std' implementation MIGRAPHX_FP8_MATH(abs, ::abs) MIGRAPHX_FP8_MATH(acos, ::acos)
// if need to enable these functions, how to put them into std:: namespace ?
// MIGRAPHX_FP8_MATH(acosh, ::acosh)
// MIGRAPHX_FP8_MATH(asin, ::asin)
// MIGRAPHX_FP8_MATH(asinh, ::asinh)
// MIGRAPHX_FP8_MATH(atan, ::atan)
// MIGRAPHX_FP8_MATH(atanh, ::atanh)
// MIGRAPHX_FP8_MATH(ceil, ::ceil)
// MIGRAPHX_FP8_MATH(cos, ::cos)
// MIGRAPHX_FP8_MATH(cosh, ::cosh)
// MIGRAPHX_FP8_MATH(erf, ::erf)
// MIGRAPHX_FP8_MATH(exp, ::exp)
// MIGRAPHX_FP8_MATH(floor, ::floor)
// // MIGRAPHX_FP8_MATH(isnan, ::isnan)
// // MIGRAPHX_FP8_MATH(log, ::log)
// // MIGRAPHX_FP8_MATH(pow, ::pow)
// // MIGRAPHX_FP8_MATH(remainder, ::remainder)
// // MIGRAPHX_FP8_MATH(round, ::round)
// // MIGRAPHX_FP8_MATH(rsqrt, ::rsqrt)
// MIGRAPHX_FP8_MATH(sin, ::sin)
// MIGRAPHX_FP8_MATH(sinh, ::sinh)
// MIGRAPHX_FP8_MATH(sqrt, ::sqrt)
// MIGRAPHX_FP8_MATH(tan, ::tan)
// MIGRAPHX_FP8_MATH(tanh, ::tanh)
// // MIGRAPHX_FP8_MATH(fmod, ::fmod)
};
}
// namespace migraphx
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment