Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
78ec77ec
Commit
78ec77ec
authored
Nov 17, 2023
by
Umang Yadav
Browse files
only compile for device
parent
60942349
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
65 deletions
+54
-65
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
+52
-63
src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp
...gets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp
+2
-2
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
View file @
78ec77ec
...
...
@@ -30,19 +30,12 @@
#pragma clang diagnostic ignored "-Wc++20-extensions"
#endif // __clang__
#if(defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__))
// need to include hip_runtime.h otherwise it complains about __host__ and __device__
#if defined(MIGRAPHX_JIT_USE_HIPRTC)
#include <migraphx/kernels/hip.hpp>
#else
#include <hip/hip_runtime.h>
#endif
#define MIGRAPHX_HIP_HOST_DEVICE __host__ __device__
#define MIGRAPHX_HIP_HOST __host__
#else
#define MIGRAPHX_HIP_HOST_DEVICE
#define MIGRAPHX_HIP_HOST
#endif // HIP_PLATFORM_AMD
#define MIGRAPHX_HIP_DEVICE __device__
...
...
@@ -91,15 +84,15 @@ struct float8
{
uint8_t
data
;
// default constructor
MIGRAPHX_HIP_
HOST_
DEVICE
constexpr
float8
()
=
default
;
MIGRAPHX_HIP_DEVICE
constexpr
float8
()
=
default
;
// default copy constructor
MIGRAPHX_HIP_
HOST_
DEVICE
constexpr
float8
(
const
float8
&
y
)
=
default
;
MIGRAPHX_HIP_DEVICE
constexpr
float8
(
const
float8
&
y
)
=
default
;
struct
from_bits_t
{
};
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
from_bits_t
from_bits
()
{
return
from_bits_t
();
}
static
constexpr
MIGRAPHX_HIP_DEVICE
from_bits_t
from_bits
()
{
return
from_bits_t
();
}
MIGRAPHX_HIP_
HOST_
DEVICE
explicit
constexpr
float8
(
uint8_t
bits
,
from_bits_t
)
:
data
(
bits
)
{}
MIGRAPHX_HIP_DEVICE
explicit
constexpr
float8
(
uint8_t
bits
,
from_bits_t
)
:
data
(
bits
)
{}
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// device specific optimized F8 down-conversion code
...
...
@@ -176,12 +169,9 @@ struct float8
else
data
=
cast_to_f8_from_f32
<
false
>
(
v
);
}
// Host only implementation using s/w simulation
explicit
MIGRAPHX_HIP_HOST
#else
//
both Host and
DEVICE for non-gfx940 using s/w simulation
explicit
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
// DEVICE for non-gfx940 using s/w simulation
explicit
constexpr
MIGRAPHX_HIP_DEVICE
#endif
float8
(
float
v
,
migraphx
::
fp8
::
rounding_mode
rm
=
migraphx
::
fp8
::
rounding_mode
::
standard
,
...
...
@@ -215,7 +205,7 @@ struct float8
/*
// Constructor from half
explicit constexpr MIGRAPHX_HIP_
HOST_
DEVICE
explicit constexpr MIGRAPHX_HIP_DEVICE
float8(migraphx::half v,
migraphx::fp8::rounding_mode rm =
migraphx::fp8::rounding_mode::standard,
...
...
@@ -225,7 +215,7 @@ struct float8
}
// constructor from int
explicit constexpr MIGRAPHX_HIP_
HOST_
DEVICE
explicit constexpr MIGRAPHX_HIP_DEVICE
float8(int v,
migraphx::fp8::rounding_mode rm =
migraphx::fp8::rounding_mode::standard,
...
...
@@ -235,7 +225,7 @@ struct float8
}
// constructor from double
explicit constexpr MIGRAPHX_HIP_
HOST_
DEVICE
explicit constexpr MIGRAPHX_HIP_DEVICE
float8(double v,
migraphx::fp8::rounding_mode rm =
migraphx::fp8::rounding_mode::standard,
...
...
@@ -267,9 +257,8 @@ struct float8
return fval;
}
inline constexpr MIGRAPHX_HIP_HOST operator float() const
#else
// non gfx940
inline
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
operator
float
()
const
inline
constexpr
MIGRAPHX_HIP_DEVICE
operator
float
()
const
#endif
{
if
constexpr
(
T
==
migraphx
::
fp8
::
f8_type
::
fp8
)
...
...
@@ -281,14 +270,14 @@ struct float8
/*
// convert to half
explicit inline MIGRAPHX_HIP_
HOST_
DEVICE operator migraphx::half() const
explicit inline MIGRAPHX_HIP_DEVICE operator migraphx::half() const
{
return migraphx::half(float(*this)); // convert to float, then convert to f16
}
*/
// check for zero
inline
MIGRAPHX_HIP_
HOST_
DEVICE
constexpr
bool
is_zero
()
const
inline
MIGRAPHX_HIP_DEVICE
constexpr
bool
is_zero
()
const
{
if
constexpr
(
FNUZ
)
{
...
...
@@ -301,7 +290,7 @@ struct float8
}
// check for nan
inline
MIGRAPHX_HIP_
HOST_
DEVICE
constexpr
bool
is_nan
()
const
inline
MIGRAPHX_HIP_DEVICE
constexpr
bool
is_nan
()
const
{
if
constexpr
(
FNUZ
)
{
...
...
@@ -325,7 +314,7 @@ struct float8
}
// check for inf
inline
MIGRAPHX_HIP_
HOST_
DEVICE
constexpr
bool
is_inf
()
const
inline
MIGRAPHX_HIP_DEVICE
constexpr
bool
is_inf
()
const
{
if
constexpr
(
FNUZ
)
{
...
...
@@ -345,13 +334,13 @@ struct float8
}
#define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op) \
constexpr float8& MIGRAPHX_HIP_
HOST_
DEVICE operator unary_op(const float8& rhs) \
constexpr float8& MIGRAPHX_HIP_DEVICE operator unary_op(const float8& rhs)
\
{ \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<float8>(tmp); \
return *this; \
} \
constexpr float8& MIGRAPHX_HIP_
HOST_
DEVICE operator unary_op(const float& rhs) \
constexpr float8& MIGRAPHX_HIP_DEVICE operator unary_op(const float& rhs)
\
{ \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<float8>(tmp); \
...
...
@@ -363,20 +352,20 @@ struct float8
MIGRAPHX_FP8_UNARY_OP
(
+=
,
+
)
MIGRAPHX_FP8_UNARY_OP
(
/=
,
/
)
inline
MIGRAPHX_HIP_
HOST_
DEVICE
constexpr
float8
&
operator
=
(
const
float8
&
rhs
)
=
default
;
inline
MIGRAPHX_HIP_
HOST_
DEVICE
constexpr
float8
&
operator
=
(
float8
&&
rhs
)
=
default
;
inline
MIGRAPHX_HIP_DEVICE
constexpr
float8
&
operator
=
(
const
float8
&
rhs
)
=
default
;
inline
MIGRAPHX_HIP_DEVICE
constexpr
float8
&
operator
=
(
float8
&&
rhs
)
=
default
;
#if !defined(__HIP_NO_F8_CONVERSIONS__)
// for the device kernels, this needs to be disabled since implicit_conversion op can type cast
// any type to any other type and that results in conflicts in candidate overload resolutions.
inline
constexpr
float8
&
MIGRAPHX_HIP_
HOST_
DEVICE
operator
=
(
float
rhs
)
inline
constexpr
float8
&
MIGRAPHX_HIP_DEVICE
operator
=
(
float
rhs
)
{
*
this
=
static_cast
<
float8
>
(
rhs
);
return
*
this
;
}
#endif
inline
MIGRAPHX_HIP_
HOST_
DEVICE
constexpr
bool
operator
==
(
const
float8
&
rhs
)
const
inline
MIGRAPHX_HIP_DEVICE
constexpr
bool
operator
==
(
const
float8
&
rhs
)
const
{
if
((
rhs
.
is_zero
()
&&
this
->
is_zero
())
||
(
fabs
(
rhs
-
*
this
)
<
migraphx
::
fp8
::
numeric_limits
<
float8
<
T
>>::
epsilon
()))
...
...
@@ -387,14 +376,14 @@ struct float8
return
false
;
}
inline
MIGRAPHX_HIP_
HOST_
DEVICE
constexpr
bool
operator
<
(
const
float8
&
rhs
)
const
inline
MIGRAPHX_HIP_DEVICE
constexpr
bool
operator
<
(
const
float8
&
rhs
)
const
{
const
auto
we
=
static_cast
<
float
>
(
*
this
);
const
auto
them
=
static_cast
<
float
>
(
rhs
);
return
we
<
them
;
}
inline
MIGRAPHX_HIP_
HOST_
DEVICE
constexpr
bool
operator
>
(
const
float8
&
rhs
)
const
inline
MIGRAPHX_HIP_DEVICE
constexpr
bool
operator
>
(
const
float8
&
rhs
)
const
{
const
auto
we
=
static_cast
<
float
>
(
*
this
);
const
auto
them
=
static_cast
<
float
>
(
rhs
);
...
...
@@ -412,12 +401,12 @@ inline std::ostream& operator<<(std::ostream& os, const migraphx::fp8::float8<T>
#endif
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \
template <migraphx::fp8::f8_type T> \
inline constexpr U MIGRAPHX_HIP_
HOST_
DEVICE operator binary_op(
\
const migraphx::fp8::float8<T>& lhs,
const migraphx::fp8::float8<T>& rhs) \
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U)
\
template <migraphx::fp8::f8_type T>
\
inline constexpr U MIGRAPHX_HIP_DEVICE operator binary_op(
const migraphx::fp8::float8<T>& lhs,
\
const migraphx::fp8::float8<T>& rhs) \
{
\
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs));
\
}
// TODO: these should return floats
...
...
@@ -434,20 +423,20 @@ MIGRAPHX_FP8_BINARY_OP(<, bool)
MIGRAPHX_FP8_BINARY_OP
(
!=
,
bool
)
template
<
migraphx
::
fp8
::
f8_type
T
>
inline
MIGRAPHX_HIP_
HOST_
DEVICE
migraphx
::
fp8
::
float8
<
T
>
fabs
(
migraphx
::
fp8
::
float8
<
T
>
v
)
inline
MIGRAPHX_HIP_DEVICE
migraphx
::
fp8
::
float8
<
T
>
fabs
(
migraphx
::
fp8
::
float8
<
T
>
v
)
{
v
.
data
=
v
.
data
&
0x7f
;
return
v
;
}
template
<
class
T
>
MIGRAPHX_HIP_
HOST_
DEVICE
constexpr
T
F8_Max
()
MIGRAPHX_HIP_DEVICE
constexpr
T
F8_Max
()
{
return
T
{
0x7F
,
T
::
from_bits
()};
}
template
<
class
T
>
MIGRAPHX_HIP_
HOST_
DEVICE
constexpr
T
F8_Lowest
()
MIGRAPHX_HIP_DEVICE
constexpr
T
F8_Lowest
()
{
return
T
{
0xFF
,
T
::
from_bits
()};
}
...
...
@@ -462,27 +451,27 @@ class numeric_limits<fp8e4m3fnuz>
{
public:
static
constexpr
bool
has_infinity
=
false
;
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e4m3fnuz
epsilon
()
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e4m3fnuz
epsilon
()
{
return
fp8e4m3fnuz
(
0x28
,
fp8e4m3fnuz
::
from_bits
());
}
// NOLINTNEXTLINE
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e4m3fnuz
quiet_NaN
()
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e4m3fnuz
quiet_NaN
()
{
return
fp8e4m3fnuz
(
0x80
,
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e4m3fnuz
max
()
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e4m3fnuz
max
()
{
return
fp8e4m3fnuz
(
0x7F
,
fp8e4m3fnuz
::
from_bits
());
}
// this is min value that is not DeNorm. DeNorm min is 0x01
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e4m3fnuz
min
()
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e4m3fnuz
min
()
{
return
fp8e4m3fnuz
(
0x08
,
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e4m3fnuz
lowest
()
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e4m3fnuz
lowest
()
{
return
fp8e4m3fnuz
(
0xFF
,
fp8e4m3fnuz
::
from_bits
());
}
...
...
@@ -493,27 +482,27 @@ class numeric_limits<fp8e4m3fn>
{
public:
static
constexpr
bool
has_infinity
=
false
;
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e4m3fn
epsilon
()
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e4m3fn
epsilon
()
{
return
fp8e4m3fn
(
0x20
,
fp8e4m3fn
::
from_bits
());
}
// NOLINTNEXTLINE
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e4m3fn
quiet_NaN
()
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e4m3fn
quiet_NaN
()
{
return
fp8e4m3fn
(
0x7F
,
fp8e4m3fn
::
from_bits
());
}
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e4m3fn
max
()
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e4m3fn
max
()
{
return
fp8e4m3fn
(
0x7E
,
fp8e4m3fn
::
from_bits
());
}
// this is min value that is not DeNorm. DeNorm min is 0x01
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e4m3fn
min
()
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e4m3fn
min
()
{
return
fp8e4m3fn
(
0x08
,
fp8e4m3fn
::
from_bits
());
}
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e4m3fn
lowest
()
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e4m3fn
lowest
()
{
return
fp8e4m3fn
(
0xFE
,
fp8e4m3fn
::
from_bits
());
}
...
...
@@ -524,28 +513,28 @@ class numeric_limits<fp8e5m2fnuz>
{
public:
static
constexpr
bool
has_infinity
=
false
;
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e5m2fnuz
epsilon
()
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2fnuz
epsilon
()
{
return
fp8e5m2fnuz
(
0x34
,
fp8e5m2fnuz
::
from_bits
());
}
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e5m2fnuz
quiet_NaN
()
// NOLINT
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2fnuz
quiet_NaN
()
// NOLINT
{
return
fp8e5m2fnuz
(
0x80
,
fp8e5m2fnuz
::
from_bits
());
}
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e5m2fnuz
max
()
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2fnuz
max
()
{
return
fp8e5m2fnuz
(
0x7F
,
fp8e5m2fnuz
::
from_bits
());
}
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e5m2fnuz
min
()
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2fnuz
min
()
{
return
fp8e5m2fnuz
(
0x4
,
fp8e5m2fnuz
::
from_bits
());
}
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e5m2fnuz
lowest
()
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2fnuz
lowest
()
{
return
fp8e5m2fnuz
(
0xFF
,
fp8e5m2fnuz
::
from_bits
());
}
...
...
@@ -556,33 +545,33 @@ class numeric_limits<fp8e5m2>
{
public:
static
constexpr
bool
has_infinity
=
true
;
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e5m2
epsilon
()
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2
epsilon
()
{
return
fp8e5m2
(
0x34
,
fp8e5m2
::
from_bits
());
}
// 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e5m2
quiet_NaN
()
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2
quiet_NaN
()
{
return
fp8e5m2
(
0xFF
,
fp8e5m2
::
from_bits
());
}
// NOLINT
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e5m2
max
()
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2
max
()
{
return
fp8e5m2
(
0x7B
,
fp8e5m2
::
from_bits
());
}
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e5m2
min
()
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2
min
()
{
return
fp8e5m2
(
0x4
,
fp8e5m2
::
from_bits
());
}
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e5m2
lowest
()
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2
lowest
()
{
return
fp8e5m2
(
0xFB
,
fp8e5m2
::
from_bits
());
}
// 7C and FC both are infinity
static
constexpr
MIGRAPHX_HIP_
HOST_
DEVICE
fp8e5m2
infinity
()
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2
infinity
()
{
return
fp8e5m2
(
0x7C
,
fp8e5m2
::
from_bits
());
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp
View file @
78ec77ec
...
...
@@ -48,7 +48,7 @@ namespace fp8 {
namespace
impl
{
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
,
bool
clip
>
MIGRAPHX_HIP_HOST_DEVICE
constexpr
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
,
uint32_t
rng
)
__device__
constexpr
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
,
uint32_t
rng
)
{
static_assert
(
wm
+
we
==
7
,
"wm+we==7"
);
...
...
@@ -240,7 +240,7 @@ this case, the fp16 mantissa should be shift left by 1 */
}
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
>
MIGRAPHX_HIP_HOST_DEVICE
constexpr
T
cast_from_f8
(
uint8_t
x
)
__device__
constexpr
T
cast_from_f8
(
uint8_t
x
)
{
constexpr
int
weo
=
8
;
constexpr
int
wmo
=
23
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment