Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a298c926
Commit
a298c926
authored
Nov 08, 2023
by
Umang Yadav
Browse files
deprecate pytorch implementation
parent
40c2df86
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
0 additions
and
525 deletions
+0
-525
src/include/migraphx/fp8e4m3fnuz.hpp
src/include/migraphx/fp8e4m3fnuz.hpp
+0
-525
No files found.
src/include/migraphx/fp8e4m3fnuz.hpp
deleted
100644 → 0
View file @
40c2df86
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_FP8E4M3FNUZ_HPP
#define MIGRAPHX_GUARD_RTGLIB_FP8E4M3FNUZ_HPP
/// Defines the Float8_e4m3fnuz type (8-bit floating-point) including
/// conversions to standard C types and basic arithmetic operations. Note that
/// arithmetic operations are implemented by converting to floating point and
/// performing the operation in float32.
///
/// Binary configuration remains the same as Float8_e4m3fn:
/// s eeee mmm
/// 1 sign bit
/// 4 exponent bits
/// 3 mantissa bits
///
/// The key differences versus Float8_e4m3fn are:
/// bias = 8
/// no infinities or negative zero
/// NaN only when sign bit is 1, rest all 0s
///
/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and
/// the existing Float8_e4m3fn implementation.
#include <type_traits>
#include <cmath>
#include <cstdint>
#include <climits>
#include <cstring>
#include <iosfwd>
#include <limits>
#include <sstream>
#include <iostream>
#include <migraphx/config.hpp>
#include <string>
#include <utility>
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIP_PLATFORM_HCC__)
// MIGraphX by default does not have device code in the regular compilation paths,
// therefore, when this file is used from the host side, compilation takes much
// longer. By guarding the __device__ directive we can control that such compilation
// only happens for kernels which include this file.
// need to include hip_runtime.h otherwise it complains about __host__ and __device__
#include <hip/hip_runtime.h>
#define MIGRAPHX_HIP_HOST_DEVICE __host__ __device__
#else
#define MIGRAPHX_HIP_HOST_DEVICE
#endif
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wimplicit-int-float-conversion"
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wreserved-identifier"
#pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wabsolute-value"
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
detail
{
inline
MIGRAPHX_HIP_HOST_DEVICE
float
fp32_from_bits
(
uint32_t
w
)
{
union
{
uint32_t
as_bits
;
float
as_value
;
}
fp32
=
{
w
};
return
fp32
.
as_value
;
}
inline
MIGRAPHX_HIP_HOST_DEVICE
uint32_t
fp32_to_bits
(
float
f
)
{
union
{
float
as_value
;
uint32_t
as_bits
;
}
fp32
=
{
f
};
return
fp32
.
as_bits
;
}
/*
* Convert a 8-bit floating-point number in fp8 E4M3FNUZ format, in bit
* representation, to a 32-bit floating-point number in IEEE single-precision
* format, in bit representation.
*
* @note The implementation doesn't use any floating-point operations.
*/
inline
MIGRAPHX_HIP_HOST_DEVICE
constexpr
float
fp8e4m3fnuz_to_fp32_value
(
uint8_t
input
)
{
constexpr
float
e4m3fnuz_lut
[
256
]
=
{
0.0
f
,
0.0009765625
f
,
0.001953125
f
,
0.0029296875
f
,
0.00390625
f
,
0.0048828125
f
,
0.005859375
f
,
0.0068359375
f
,
0.0078125
f
,
0.0087890625
f
,
0.009765625
f
,
0.0107421875
f
,
0.01171875
f
,
0.0126953125
f
,
0.013671875
f
,
0.0146484375
f
,
0.015625
f
,
0.017578125
f
,
0.01953125
f
,
0.021484375
f
,
0.0234375
f
,
0.025390625
f
,
0.02734375
f
,
0.029296875
f
,
0.03125
f
,
0.03515625
f
,
0.0390625
f
,
0.04296875
f
,
0.046875
f
,
0.05078125
f
,
0.0546875
f
,
0.05859375
f
,
0.0625
f
,
0.0703125
f
,
0.078125
f
,
0.0859375
f
,
0.09375
f
,
0.1015625
f
,
0.109375
f
,
0.1171875
f
,
0.125
f
,
0.140625
f
,
0.15625
f
,
0.171875
f
,
0.1875
f
,
0.203125
f
,
0.21875
f
,
0.234375
f
,
0.25
f
,
0.28125
f
,
0.3125
f
,
0.34375
f
,
0.375
f
,
0.40625
f
,
0.4375
f
,
0.46875
f
,
0.5
f
,
0.5625
f
,
0.625
f
,
0.6875
f
,
0.75
f
,
0.8125
f
,
0.875
f
,
0.9375
f
,
1.0
f
,
1.125
f
,
1.25
f
,
1.375
f
,
1.5
f
,
1.625
f
,
1.75
f
,
1.875
f
,
2.0
f
,
2.25
f
,
2.5
f
,
2.75
f
,
3.0
f
,
3.25
f
,
3.5
f
,
3.75
f
,
4.0
f
,
4.5
f
,
5.0
f
,
5.5
f
,
6.0
f
,
6.5
f
,
7.0
f
,
7.5
f
,
8.0
f
,
9.0
f
,
10.0
f
,
11.0
f
,
12.0
f
,
13.0
f
,
14.0
f
,
15.0
f
,
16.0
f
,
18.0
f
,
20.0
f
,
22.0
f
,
24.0
f
,
26.0
f
,
28.0
f
,
30.0
f
,
32.0
f
,
36.0
f
,
40.0
f
,
44.0
f
,
48.0
f
,
52.0
f
,
56.0
f
,
60.0
f
,
64.0
f
,
72.0
f
,
80.0
f
,
88.0
f
,
96.0
f
,
104.0
f
,
112.0
f
,
120.0
f
,
128.0
f
,
144.0
f
,
160.0
f
,
176.0
f
,
192.0
f
,
208.0
f
,
224.0
f
,
240.0
f
,
std
::
numeric_limits
<
float
>::
quiet_NaN
(),
-
0.0009765625
f
,
-
0.001953125
f
,
-
0.0029296875
f
,
-
0.00390625
f
,
-
0.0048828125
f
,
-
0.005859375
f
,
-
0.0068359375
f
,
-
0.0078125
f
,
-
0.0087890625
f
,
-
0.009765625
f
,
-
0.0107421875
f
,
-
0.01171875
f
,
-
0.0126953125
f
,
-
0.013671875
f
,
-
0.0146484375
f
,
-
0.015625
f
,
-
0.017578125
f
,
-
0.01953125
f
,
-
0.021484375
f
,
-
0.0234375
f
,
-
0.025390625
f
,
-
0.02734375
f
,
-
0.029296875
f
,
-
0.03125
f
,
-
0.03515625
f
,
-
0.0390625
f
,
-
0.04296875
f
,
-
0.046875
f
,
-
0.05078125
f
,
-
0.0546875
f
,
-
0.05859375
f
,
-
0.0625
f
,
-
0.0703125
f
,
-
0.078125
f
,
-
0.0859375
f
,
-
0.09375
f
,
-
0.1015625
f
,
-
0.109375
f
,
-
0.1171875
f
,
-
0.125
f
,
-
0.140625
f
,
-
0.15625
f
,
-
0.171875
f
,
-
0.1875
f
,
-
0.203125
f
,
-
0.21875
f
,
-
0.234375
f
,
-
0.25
f
,
-
0.28125
f
,
-
0.3125
f
,
-
0.34375
f
,
-
0.375
f
,
-
0.40625
f
,
-
0.4375
f
,
-
0.46875
f
,
-
0.5
f
,
-
0.5625
f
,
-
0.625
f
,
-
0.6875
f
,
-
0.75
f
,
-
0.8125
f
,
-
0.875
f
,
-
0.9375
f
,
-
1.0
f
,
-
1.125
f
,
-
1.25
f
,
-
1.375
f
,
-
1.5
f
,
-
1.625
f
,
-
1.75
f
,
-
1.875
f
,
-
2.0
f
,
-
2.25
f
,
-
2.5
f
,
-
2.75
f
,
-
3.0
f
,
-
3.25
f
,
-
3.5
f
,
-
3.75
f
,
-
4.0
f
,
-
4.5
f
,
-
5.0
f
,
-
5.5
f
,
-
6.0
f
,
-
6.5
f
,
-
7.0
f
,
-
7.5
f
,
-
8.0
f
,
-
9.0
f
,
-
10.0
f
,
-
11.0
f
,
-
12.0
f
,
-
13.0
f
,
-
14.0
f
,
-
15.0
f
,
-
16.0
f
,
-
18.0
f
,
-
20.0
f
,
-
22.0
f
,
-
24.0
f
,
-
26.0
f
,
-
28.0
f
,
-
30.0
f
,
-
32.0
f
,
-
36.0
f
,
-
40.0
f
,
-
44.0
f
,
-
48.0
f
,
-
52.0
f
,
-
56.0
f
,
-
60.0
f
,
-
64.0
f
,
-
72.0
f
,
-
80.0
f
,
-
88.0
f
,
-
96.0
f
,
-
104.0
f
,
-
112.0
f
,
-
120.0
f
,
-
128.0
f
,
-
144.0
f
,
-
160.0
f
,
-
176.0
f
,
-
192.0
f
,
-
208.0
f
,
-
224.0
f
,
-
240.0
f
,
};
return
e4m3fnuz_lut
[
input
];
}
/*
* Convert a 32-bit floating-point number in IEEE single-precision format to a
* 8-bit floating-point number in fp8 E4M3FNUZ format, in bit representation.
*/
inline
MIGRAPHX_HIP_HOST_DEVICE
uint8_t
fp8e4m3fnuz_from_fp32_value
(
float
f
)
{
/*
* Binary representation of 256.0f, which is the first value not representable
* (i.e. the first value which would overflow in to the sign bit, resulting in
* a NaN) in fp8e4m3fnuz range:
* 1 0000 000 - fp8e4m3fnuz
* 0 10000110 00000000000000000000000 - fp32
*/
constexpr
uint32_t
fnuz_max
=
UINT32_C
(
0x87
)
<<
23
;
/*
* A mask for converting fp32 numbers lower than fp8e4m3fnuz normal range
* into denormalized representation.
* magic number: ((127 - 8) + (23 - 3) + 1)
*/
constexpr
uint32_t
denorm_mask
=
UINT32_C
(
0x8C
)
<<
23
;
uint32_t
f_bits
=
fp32_to_bits
(
f
);
uint32_t
result
=
0u
;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const
uint32_t
sign
=
f_bits
&
UINT32_C
(
0x80000000
);
/*
* Set sign bit to 0
*/
f_bits
^=
sign
;
if
(
f_bits
>=
fnuz_max
)
{
// NaN -- sign bit set to 1, rest 0s
return
0x80
;
}
if
(
f_bits
<
(
UINT32_C
(
121
)
<<
23
))
{
// Input number is smaller than 2^(-7), which is the smallest
// fp8e4m3fnuz normal number
f_bits
=
fp32_to_bits
(
fp32_from_bits
(
f_bits
)
+
fp32_from_bits
(
denorm_mask
));
result
=
static_cast
<
uint8_t
>
(
f_bits
-
denorm_mask
);
if
(
result
==
0
)
{
// fnuz types don't have negative zero.
return
0
;
}
}
else
{
// resulting mantissa is odd
uint8_t
mant_odd
=
(
f_bits
>>
20
)
&
1
;
// update exponent, rounding bias part 1
f_bits
+=
((
uint32_t
)(
8
-
127
)
<<
23
)
+
0x7FFFF
;
// rounding bias part 2
f_bits
+=
mant_odd
;
// take the bits!
result
=
static_cast
<
uint8_t
>
(
f_bits
>>
20
);
}
result
|=
sign
>>
24
;
return
result
;
}
struct
expr
{
/// Conversion constructor.
/// \param f single-precision value to convert
explicit
constexpr
expr
(
float
f
)
noexcept
:
value_
(
f
)
{}
/// Conversion to single-precision.
/// \return single precision value representing expression value
constexpr
operator
float
()
const
noexcept
{
return
value_
;
}
private:
/// Internal expression value stored in single-precision.
float
value_
;
};
}
// namespace detail
/*
overloads using migraphx::fp8e4m3fnuz may not be necessary since they can be implicitly casted to
float that is how half.hpp is implementing it.
this operators can't be friend since it leads to conflicting candidates with inbuilt operators (due
to implict cast to other types probably)
*/
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_UNARY_OP(unary_op) \
constexpr migraphx::fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator unary_op( \
const migraphx::fp8e4m3fnuz& rhs) \
{ \
float y = float(data_); \
y unary_op float(rhs); \
data_ = detail::fp8e4m3fnuz_from_fp32_value(y); \
return *this; \
} \
constexpr migraphx::fp8e4m3fnuz& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const float& rhs) \
{ \
float y = float(data_); \
y unary_op rhs; \
data_ = detail::fp8e4m3fnuz_from_fp32_value(y); \
return *this; \
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, T) \
friend constexpr T MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \
const migraphx::fp8e4m3fnuz& lhs, const migraphx::fp8e4m3fnuz& rhs) \
{ \
return T(float(lhs) binary_op float(rhs)); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_MATH(name, fname) \
migraphx::fp8e4m3fnuz MIGRAPHX_HIP_HOST_DEVICE name(migraphx::fp8e4m3fnuz x) \
{ \
return migraphx::fp8e4m3fnuz(fname(float(x))); \
}
}
// namespace MIGRAPHX_INLINE_NS
struct
alignas
(
1
)
fp8e4m3fnuz
{
uint8_t
data_
;
struct
from_bits_t
{
};
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
from_bits_t
from_bits
()
{
return
from_bits_t
();
}
MIGRAPHX_HIP_HOST_DEVICE
fp8e4m3fnuz
()
:
data_
(
0
)
{}
MIGRAPHX_HIP_HOST_DEVICE
constexpr
fp8e4m3fnuz
(
uint8_t
bits
,
from_bits_t
)
:
data_
(
bits
)
{}
MIGRAPHX_HIP_HOST_DEVICE
fp8e4m3fnuz
(
const
fp8e4m3fnuz
&
y
)
=
default
;
inline
explicit
constexpr
MIGRAPHX_HIP_HOST_DEVICE
fp8e4m3fnuz
(
float
value
)
:
data_
(
detail
::
fp8e4m3fnuz_from_fp32_value
(
value
))
{
}
#if !defined(__HIP_NO_F8_CONVERSIONS__)
// for the device kernels, this needs to be disabled since implicit_conversion op can type cast
// any type to any other type and that results in conflicts in candidate overload resolutions.
fp8e4m3fnuz
&
MIGRAPHX_HIP_HOST_DEVICE
operator
=
(
float
rhs
)
{
data_
=
detail
::
fp8e4m3fnuz_from_fp32_value
(
rhs
);
return
*
this
;
}
#endif
inline
constexpr
MIGRAPHX_HIP_HOST_DEVICE
operator
float
()
const
{
return
detail
::
fp8e4m3fnuz_to_fp32_value
(
data_
);
}
fp8e4m3fnuz
&
MIGRAPHX_HIP_HOST_DEVICE
operator
=
(
const
fp8e4m3fnuz
&
rhs
)
=
default
;
fp8e4m3fnuz
&
MIGRAPHX_HIP_HOST_DEVICE
operator
=
(
fp8e4m3fnuz
&&
rhs
)
=
default
;
inline
bool
MIGRAPHX_HIP_HOST_DEVICE
isnan
()
const
{
return
data_
==
0b10000000
;
}
MIGRAPHX_FP8_UNARY_OP
(
+=
)
MIGRAPHX_FP8_UNARY_OP
(
-=
)
MIGRAPHX_FP8_UNARY_OP
(
*=
)
MIGRAPHX_FP8_UNARY_OP
(
/=
)
friend
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
fp8e4m3fnuz
&
value
)
{
out
<<
(
float
)(
value
);
return
out
;
}
// what should be the return type ?
MIGRAPHX_FP8_BINARY_OP
(
+
,
migraphx
::
fp8e4m3fnuz
)
MIGRAPHX_FP8_BINARY_OP
(
-
,
migraphx
::
fp8e4m3fnuz
)
MIGRAPHX_FP8_BINARY_OP
(
*
,
migraphx
::
fp8e4m3fnuz
)
MIGRAPHX_FP8_BINARY_OP
(
/
,
migraphx
::
fp8e4m3fnuz
)
MIGRAPHX_FP8_BINARY_OP
(
==
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
!=
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
>=
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
<=
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
>
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
<
,
bool
)
// implicit conversion should take care of these for the HOST side, half implementation doesn't
// have 'std' implementation MIGRAPHX_FP8_MATH(abs, ::abs) MIGRAPHX_FP8_MATH(acos, ::acos)
// if need to enable these functions, how to put them into std:: namespace ?
// MIGRAPHX_FP8_MATH(acosh, ::acosh)
// MIGRAPHX_FP8_MATH(asin, ::asin)
// MIGRAPHX_FP8_MATH(asinh, ::asinh)
// MIGRAPHX_FP8_MATH(atan, ::atan)
// MIGRAPHX_FP8_MATH(atanh, ::atanh)
// MIGRAPHX_FP8_MATH(ceil, ::ceil)
// MIGRAPHX_FP8_MATH(cos, ::cos)
// MIGRAPHX_FP8_MATH(cosh, ::cosh)
// MIGRAPHX_FP8_MATH(erf, ::erf)
// MIGRAPHX_FP8_MATH(exp, ::exp)
// MIGRAPHX_FP8_MATH(floor, ::floor)
// // MIGRAPHX_FP8_MATH(isnan, ::isnan)
// // MIGRAPHX_FP8_MATH(log, ::log)
// // MIGRAPHX_FP8_MATH(pow, ::pow)
// // MIGRAPHX_FP8_MATH(remainder, ::remainder)
// // MIGRAPHX_FP8_MATH(round, ::round)
// // MIGRAPHX_FP8_MATH(rsqrt, ::rsqrt)
// MIGRAPHX_FP8_MATH(sin, ::sin)
// MIGRAPHX_FP8_MATH(sinh, ::sinh)
// MIGRAPHX_FP8_MATH(sqrt, ::sqrt)
// MIGRAPHX_FP8_MATH(tan, ::tan)
// MIGRAPHX_FP8_MATH(tanh, ::tanh)
// // MIGRAPHX_FP8_MATH(fmod, ::fmod)
};
}
// namespace migraphx
namespace
std
{
template
<
>
class
numeric_limits
<
migraphx
::
fp8e4m3fnuz
>
{
public:
static
constexpr
bool
is_specialized
=
true
;
static
constexpr
bool
is_signed
=
true
;
static
constexpr
bool
is_integer
=
false
;
static
constexpr
bool
is_exact
=
false
;
static
constexpr
bool
has_infinity
=
false
;
static
constexpr
bool
has_quiet_NaN
=
true
;
static
constexpr
bool
has_signaling_NaN
=
false
;
static
constexpr
auto
has_denorm
=
true
;
static
constexpr
auto
has_denorm_loss
=
true
;
static
constexpr
auto
round_style
=
numeric_limits
<
float
>::
round_style
;
static
constexpr
bool
is_iec559
=
false
;
static
constexpr
bool
is_bounded
=
true
;
static
constexpr
bool
is_modulo
=
false
;
static
constexpr
int
digits
=
4
;
static
constexpr
int
digits10
=
0
;
static
constexpr
int
max_digits10
=
3
;
static
constexpr
int
radix
=
2
;
static
constexpr
int
min_exponent
=
-
5
;
static
constexpr
int
min_exponent10
=
-
1
;
static
constexpr
int
max_exponent
=
8
;
static
constexpr
int
max_exponent10
=
2
;
static
constexpr
auto
traps
=
numeric_limits
<
float
>::
traps
;
static
constexpr
auto
tinyness_before
=
false
;
static
constexpr
migraphx
::
fp8e4m3fnuz
min
()
{
return
migraphx
::
fp8e4m3fnuz
(
0x08
,
migraphx
::
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
migraphx
::
fp8e4m3fnuz
lowest
()
{
return
migraphx
::
fp8e4m3fnuz
(
0xFF
,
migraphx
::
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
migraphx
::
fp8e4m3fnuz
max
()
{
return
migraphx
::
fp8e4m3fnuz
(
0x7F
,
migraphx
::
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
migraphx
::
fp8e4m3fnuz
epsilon
()
{
return
migraphx
::
fp8e4m3fnuz
(
0x28
,
migraphx
::
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
migraphx
::
fp8e4m3fnuz
round_error
()
{
return
migraphx
::
fp8e4m3fnuz
(
0x38
,
migraphx
::
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
migraphx
::
fp8e4m3fnuz
infinity
()
{
// NaN (no infinities)
return
migraphx
::
fp8e4m3fnuz
(
0x80
,
migraphx
::
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
migraphx
::
fp8e4m3fnuz
quiet_NaN
()
{
return
migraphx
::
fp8e4m3fnuz
(
0x80
,
migraphx
::
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
migraphx
::
fp8e4m3fnuz
denorm_min
()
{
return
migraphx
::
fp8e4m3fnuz
(
0x01
,
migraphx
::
fp8e4m3fnuz
::
from_bits
());
}
};
template
<
class
T
>
struct
common_type
<
migraphx
::
fp8e4m3fnuz
,
T
>
:
std
::
common_type
<
float
,
T
>
// NOLINT
{
};
template
<
class
T
>
struct
common_type
<
T
,
migraphx
::
fp8e4m3fnuz
>
:
std
::
common_type
<
float
,
T
>
// NOLINT
{
};
template
<
>
struct
common_type
<
migraphx
::
fp8e4m3fnuz
,
migraphx
::
fp8e4m3fnuz
>
{
using
type
=
float
;
};
// template <>
// struct common_type<migraphx::fp8e4m3fnuz, migraphx::half>
// {
// using type = float;
// };
// template <>
// struct common_type<migraphx::half, migraphx::fp8e4m3fnuz>
// {
// using type = float;
// };
}
// namespace std
#pragma clang diagnostic pop
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment