Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b2bb524f
Commit
b2bb524f
authored
Nov 06, 2023
by
Umang Yadav
Browse files
Add friend overloads
parent
4a30c2d1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
80 additions
and
47 deletions
+80
-47
src/include/migraphx/fp8e4m3fnuz.hpp
src/include/migraphx/fp8e4m3fnuz.hpp
+77
-47
src/include/migraphx/requires.hpp
src/include/migraphx/requires.hpp
+3
-0
No files found.
src/include/migraphx/fp8e4m3fnuz.hpp
View file @
b2bb524f
...
@@ -55,6 +55,11 @@
...
@@ -55,6 +55,11 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <string>
#include <string>
#include <utility>
#include <utility>
#if !defined(__HIP_NO_F8_CONVERSIONS__)
#include <migraphx/requires.hpp>
#else
#include <migraphx/kernels/type_traits.hpp>
#endif
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIP_PLATFORM_HCC__)
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIP_PLATFORM_HCC__)
// MIGraphX by default does not have device code in the regular compilation paths,
// MIGraphX by default does not have device code in the regular compilation paths,
...
@@ -72,6 +77,7 @@
...
@@ -72,6 +77,7 @@
#pragma clang diagnostic ignored "-Wimplicit-int-float-conversion"
#pragma clang diagnostic ignored "-Wimplicit-int-float-conversion"
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wreserved-identifier"
#pragma clang diagnostic ignored "-Wreserved-identifier"
#pragma clang diagnostic ignored "-Wfloat-equal"
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
detail
{
namespace
detail
{
...
@@ -276,25 +282,65 @@ inline MIGRAPHX_HIP_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f)
...
@@ -276,25 +282,65 @@ inline MIGRAPHX_HIP_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f)
return
result
;
return
result
;
}
}
/// Temporary half-precision expression.
}
// namespace detail
/// This class represents a half-precision expression which just stores a single-precision value
/// internally.
struct
expr
{
/// Conversion constructor.
/// \param f single-precision value to convert
explicit
expr
(
float
f
)
:
value_
(
f
)
{}
/// Conversion to single-precision.
// NOLINTNEXTLINE
/// \return single precision value representing expression value
#define MIGRAPHX_FP8_BINARY_OP(op, binary_op) \
operator
float
()
const
{
return
value_
;
}
constexpr migraphx::fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator op( \
const migraphx::fp8e4m3fnuz& rhs) \
{ \
float y = float(x); \
y op float(rhs); \
x = detail::fp8e4m3fnuz_from_fp32_value(y); \
return *this; \
} \
template <class U, MIGRAPHX_REQUIRES(migraphx::is_convertible<U, float>{})> \
constexpr migraphx::fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator op(const U& rhs) \
{ \
float y = float(x); \
y op float(rhs); \
x = detail::fp8e4m3fnuz_from_fp32_value(y); \
return *this; \
} \
friend constexpr float MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \
const migraphx::fp8e4m3fnuz& lhs, const migraphx::fp8e4m3fnuz& rhs) \
{ \
return (float(lhs) binary_op float(rhs)); \
} \
template <class U, MIGRAPHX_REQUIRES(migraphx::is_convertible<U, float>{})> \
friend constexpr float MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \
const migraphx::fp8e4m3fnuz& lhs, const U& rhs) \
{ \
return (float(lhs) binary_op float(rhs)); \
} \
template <class U, MIGRAPHX_REQUIRES(migraphx::is_convertible<U, float>{})> \
friend constexpr float MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \
const U& lhs, const migraphx::fp8e4m3fnuz& rhs) \
{ \
return (float(lhs) binary_op float(rhs)); \
}
private:
// NOLINTNEXTLINE
/// Internal expression value stored in single-precision.
#define MIGRAPHX_FP8_COMP_OP(comp_op) \
float
value_
;
friend constexpr bool MIGRAPHX_HIP_HOST_DEVICE operator comp_op( \
};
const migraphx::fp8e4m3fnuz& lhs, const migraphx::fp8e4m3fnuz& rhs) \
{ \
return ((float)(lhs)comp_op(float)(rhs)); \
} \
template <class U, MIGRAPHX_REQUIRES(migraphx::is_convertible<U, float>{})> \
friend constexpr bool MIGRAPHX_HIP_HOST_DEVICE operator comp_op( \
const migraphx::fp8e4m3fnuz& lhs, const U& rhs) \
{ \
return (float(lhs) comp_op float(rhs)); \
} \
template <class U, MIGRAPHX_REQUIRES(migraphx::is_convertible<U, float>{})> \
friend constexpr bool MIGRAPHX_HIP_HOST_DEVICE operator comp_op( \
const U& lhs, const migraphx::fp8e4m3fnuz& rhs) \
{ \
return (float(lhs) comp_op float(rhs)); \
}
}
// namespace
detail
}
// namespace
MIGRAPHX_INLINE_NS
struct
alignas
(
1
)
fp8e4m3fnuz
struct
alignas
(
1
)
fp8e4m3fnuz
{
{
...
@@ -335,41 +381,25 @@ struct alignas(1) fp8e4m3fnuz
...
@@ -335,41 +381,25 @@ struct alignas(1) fp8e4m3fnuz
inline
bool
MIGRAPHX_HIP_HOST_DEVICE
isnan
()
const
{
return
x
==
0b10000000
;
}
inline
bool
MIGRAPHX_HIP_HOST_DEVICE
isnan
()
const
{
return
x
==
0b10000000
;
}
fp8e4m3fnuz
&
MIGRAPHX_HIP_HOST_DEVICE
operator
+=
(
float
rhs
)
MIGRAPHX_FP8_BINARY_OP
(
+=
,
+
)
{
MIGRAPHX_FP8_BINARY_OP
(
-=
,
-
)
x
=
detail
::
fp8e4m3fnuz_from_fp32_value
(
rhs
+
float
(
x
));
MIGRAPHX_FP8_BINARY_OP
(
*=
,
*
)
return
*
this
;
MIGRAPHX_FP8_BINARY_OP
(
/=
,
/
)
}
fp8e4m3fnuz
&
MIGRAPHX_HIP_HOST_DEVICE
operator
-=
(
float
rhs
)
{
x
=
detail
::
fp8e4m3fnuz_from_fp32_value
(
rhs
-
float
(
x
));
return
*
this
;
}
fp8e4m3fnuz
&
MIGRAPHX_HIP_HOST_DEVICE
operator
*=
(
float
rhs
)
{
x
=
detail
::
fp8e4m3fnuz_from_fp32_value
(
rhs
*
float
(
x
));
return
*
this
;
}
fp8e4m3fnuz
&
MIGRAPHX_HIP_HOST_DEVICE
operator
/=
(
float
rhs
)
{
x
=
detail
::
fp8e4m3fnuz_from_fp32_value
(
rhs
/
float
(
x
));
return
*
this
;
}
};
MIGRAPHX_HIP_HOST_DEVICE
inline
migraphx
::
fp8e4m3fnuz
operator
+
(
migraphx
::
fp8e4m3fnuz
x
,
MIGRAPHX_FP8_COMP_OP
(
==
)
migraphx
::
fp8e4m3fnuz
y
)
MIGRAPHX_FP8_COMP_OP
(
!=
)
{
MIGRAPHX_FP8_COMP_OP
(
>=
)
return
migraphx
::
fp8e4m3fnuz
(
float
(
x
)
+
float
(
y
));
MIGRAPHX_FP8_COMP_OP
(
<=
)
}
MIGRAPHX_FP8_COMP_OP
(
>
)
MIGRAPHX_FP8_COMP_OP
(
<
)
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
fp8e4m3fnuz
&
value
)
friend
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
fp8e4m3fnuz
&
value
)
{
{
out
<<
(
float
)(
value
);
out
<<
(
float
)(
value
);
return
out
;
return
out
;
}
}
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
namespace
std
{
namespace
std
{
...
...
src/include/migraphx/requires.hpp
View file @
b2bb524f
...
@@ -38,6 +38,9 @@ struct and_ : std::is_same<and_<Bs...>, and_<(Bs or true)...>> // NOLINT
...
@@ -38,6 +38,9 @@ struct and_ : std::is_same<and_<Bs...>, and_<(Bs or true)...>> // NOLINT
template
<
bool
B
>
template
<
bool
B
>
using
bool_c
=
std
::
integral_constant
<
bool
,
B
>
;
using
bool_c
=
std
::
integral_constant
<
bool
,
B
>
;
template
<
class
From
,
class
To
>
using
is_convertible
=
std
::
is_convertible
<
From
,
To
>
;
#define MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y) x##y
#define MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y) x##y
#define MIGRAPHX_REQUIRES_CAT(x, y) MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y)
#define MIGRAPHX_REQUIRES_CAT(x, y) MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment