Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6155c782
"include/vscode:/vscode.git/clone" did not exist on "e576c0819f14ec43d5d3b49f1faae719326aa502"
Commit
6155c782
authored
Nov 20, 2023
by
Umang Yadav
Browse files
use __builtin_is_constant_evaluated
parent
7e3444ce
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
122 additions
and
101 deletions
+122
-101
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
+114
-94
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
...argets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
+7
-6
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
+1
-1
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
View file @
6155c782
...
@@ -26,8 +26,8 @@
...
@@ -26,8 +26,8 @@
#pragma clang diagnostic push
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wc++20-extensions"
#endif // __clang__
#endif // __clang__
#define MIGRAPHX_HIP_DEVICE __device__
// We are clipping in down conversion by default
// We are clipping in down conversion by default
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 // NOLINT
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 // NOLINT
...
@@ -58,21 +58,21 @@ struct float8
...
@@ -58,21 +58,21 @@ struct float8
{
{
uint8_t
data
;
uint8_t
data
;
// default constructor
// default constructor
MIGRAPHX_HIP_DEVICE
constexpr
float8
()
=
default
;
__device__
constexpr
float8
()
=
default
;
// default copy constructor
// default copy constructor
MIGRAPHX_HIP_DEVICE
constexpr
float8
(
const
float8
&
y
)
=
default
;
__device__
constexpr
float8
(
const
float8
&
y
)
=
default
;
struct
from_bits_t
struct
from_bits_t
{
{
};
};
static
constexpr
MIGRAPHX_HIP_DEVICE
from_bits_t
from_bits
()
{
return
from_bits_t
();
}
static
constexpr
__device__
from_bits_t
from_bits
()
{
return
from_bits_t
();
}
MIGRAPHX_HIP_DEVICE
explicit
constexpr
float8
(
uint8_t
bits
,
from_bits_t
)
:
data
(
bits
)
{}
__device__
explicit
constexpr
float8
(
uint8_t
bits
,
from_bits_t
)
:
data
(
bits
)
{}
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// device specific optimized F8 down-conversion code
// device specific optimized F8 down-conversion code
template
<
bool
stochastic_rounding
=
false
>
template
<
bool
stochastic_rounding
=
false
>
static
constexpr
MIGRAPHX_HIP_DEVICE
uint8_t
cast_to_f8_from_f32
(
float
v
,
uint32_t
rng
=
0
)
static
__device__
uint8_t
cast_to_f8_from_f32
(
float
v
,
uint32_t
rng
=
0
)
{
{
uint8_t
i8data
=
0x00
;
uint8_t
i8data
=
0x00
;
union
union
...
@@ -132,20 +132,50 @@ struct float8
...
@@ -132,20 +132,50 @@ struct float8
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// NOTE: ON-DEVICE... always optimal bias
// NOTE: ON-DEVICE... always optimal bias
explicit
constexpr
MIGRAPHX_HIP_DEVICE
explicit
constexpr
__device__
float8
(
const
float
v
,
float8
(
const
float
v
,
migraphx
::
fp8
::
rounding_mode
rm
=
migraphx
::
fp8
::
rounding_mode
::
standard
,
migraphx
::
fp8
::
rounding_mode
rm
=
migraphx
::
fp8
::
rounding_mode
::
standard
,
uint32_t
rng
=
0
)
uint32_t
rng
=
0
)
{
{
// runtime branch, use cast_to_f8_from_f32 if want to avoid it
if
(
__builtin_is_constant_evaluated
())
if
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
)
{
data
=
cast_to_f8_from_f32
<
true
>
(
v
,
rng
);
if
constexpr
(
T
==
migraphx
::
fp8
::
f8_type
::
fp8
)
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
,
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
),
rng
);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
v
,
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
),
rng
);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
}
else
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
,
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
),
rng
);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
v
,
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
),
rng
);
#endif // MIGRAPHX_FP8_DOWNCAST_CLIPPING}
}
}
else
else
data
=
cast_to_f8_from_f32
<
false
>
(
v
);
{
// runtime branch, use cast_to_f8_from_f32 if want to avoid it
if
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
)
data
=
cast_to_f8_from_f32
<
true
>
(
v
,
rng
);
else
data
=
cast_to_f8_from_f32
<
false
>
(
v
);
}
}
}
#else
#else
// DEVICE for non-gfx940 using s/w simulation
// DEVICE for non-gfx940 using s/w simulation
explicit
constexpr
MIGRAPHX_HIP_DEVICE
explicit
constexpr
__device__
float8
(
const
float
v
,
float8
(
const
float
v
,
migraphx
::
fp8
::
rounding_mode
rm
=
migraphx
::
fp8
::
rounding_mode
::
standard
,
migraphx
::
fp8
::
rounding_mode
rm
=
migraphx
::
fp8
::
rounding_mode
::
standard
,
uint32_t
rng
=
0
)
uint32_t
rng
=
0
)
...
@@ -178,64 +208,74 @@ struct float8
...
@@ -178,64 +208,74 @@ struct float8
#endif // __gfx940___
#endif // __gfx940___
// Constructor from half
// Constructor from half
explicit
constexpr
MIGRAPHX_HIP_DEVICE
explicit
constexpr
__device__
float8
(
const
_Float16
v
,
rounding_mode
rm
=
rounding_mode
::
standard
,
uint32_t
rng
=
0
)
float8
(
const
_Float16
v
,
rounding_mode
rm
=
rounding_mode
::
standard
,
uint32_t
rng
=
0
)
:
float8
((
float
)
v
,
rm
,
rng
)
:
float8
((
float
)
v
,
rm
,
rng
)
{
{
}
}
// constructor from int
// constructor from int
explicit
constexpr
MIGRAPHX_HIP_DEVICE
explicit
constexpr
__device__
float8
(
const
int
v
,
rounding_mode
rm
=
rounding_mode
::
standard
,
uint32_t
rng
=
0
)
float8
(
const
int
v
,
rounding_mode
rm
=
rounding_mode
::
standard
,
uint32_t
rng
=
0
)
:
float8
((
float
)
v
,
rm
,
rng
)
:
float8
((
float
)
v
,
rm
,
rng
)
{
{
}
}
// constructor from uint
// constructor from uint
explicit
constexpr
MIGRAPHX_HIP_DEVICE
explicit
constexpr
__device__
float8
(
const
uint32_t
v
,
rounding_mode
rm
=
rounding_mode
::
standard
,
uint32_t
rng
=
0
)
float8
(
const
uint32_t
v
,
rounding_mode
rm
=
rounding_mode
::
standard
,
uint32_t
rng
=
0
)
:
float8
((
float
)
v
,
rm
,
rng
)
:
float8
((
float
)
v
,
rm
,
rng
)
{
{
}
}
// constructor from double
// constructor from double
explicit
constexpr
MIGRAPHX_HIP_DEVICE
explicit
constexpr
__device__
float8
(
const
double
v
,
rounding_mode
rm
=
rounding_mode
::
standard
,
uint32_t
rng
=
0
)
float8
(
const
double
v
,
rounding_mode
rm
=
rounding_mode
::
standard
,
uint32_t
rng
=
0
)
:
float8
((
float
)
v
,
rm
,
rng
)
:
float8
((
float
)
v
,
rm
,
rng
)
{
{
}
}
// constructor from bool
// constructor from bool
explicit
constexpr
MIGRAPHX_HIP_DEVICE
explicit
constexpr
__device__
float8
(
const
bool
v
,
rounding_mode
rm
=
rounding_mode
::
standard
,
uint32_t
rng
=
0
)
float8
(
const
bool
v
,
rounding_mode
rm
=
rounding_mode
::
standard
,
uint32_t
rng
=
0
)
:
float8
((
float
)(
v
),
rm
,
rng
)
:
float8
((
float
)(
v
),
rm
,
rng
)
{
{
}
}
// convert to float
// convert to float
// #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // NOLINT
#if 0 // need constexpr operator(). This version can't be constexpr // NOLINT
// upcast using device specific intrinsic
// upcast using device specific intrinsic
inline
MIGRAPHX_HIP_DEVICE
operator float() const
inline
constexpr
__device__
operator
float
()
const
{
{
float fval;
if
(
__builtin_is_constant_evaluated
())
uint32_t i32val = static_cast<uint32_t>(data);
// upcast
if constexpr(T == migraphx::fp8::f8_type::fp8)
{
{
asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
if
constexpr
(
T
==
migraphx
::
fp8
::
f8_type
::
fp8
)
{
return
migraphx
::
fp8
::
impl
::
cast_from_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
>
(
data
);
}
// else
return
migraphx
::
fp8
::
impl
::
cast_from_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
>
(
data
);
}
}
else
else
{
{
asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
float
fval
=
0
;
}
uint32_t
i32val
=
static_cast
<
uint32_t
>
(
data
);
// upcast
if
constexpr
(
T
==
migraphx
::
fp8
::
f8_type
::
fp8
)
{
__asm__
volatile
(
"v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
:
"=v"
(
fval
)
:
"v"
(
i32val
));
}
else
{
__asm__
volatile
(
"v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0"
:
"=v"
(
fval
)
:
"v"
(
i32val
));
}
return fval;
return
fval
;
}
}
}
#else // non gfx940
#else // non gfx940
inline
constexpr
MIGRAPHX_HIP_DEVICE
operator
float
()
const
inline
constexpr
__device__
operator
float
()
const
#endif
{
{
if
constexpr
(
T
==
migraphx
::
fp8
::
f8_type
::
fp8
)
if
constexpr
(
T
==
migraphx
::
fp8
::
f8_type
::
fp8
)
{
{
...
@@ -243,11 +283,12 @@ struct float8
...
@@ -243,11 +283,12 @@ struct float8
}
// else
}
// else
return
migraphx
::
fp8
::
impl
::
cast_from_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
>
(
data
);
return
migraphx
::
fp8
::
impl
::
cast_from_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
>
(
data
);
}
}
#endif
inline
constexpr
explicit
MIGRAPHX_HIP_DEVICE
operator
bool
()
const
{
return
not
is_zero
();
}
inline
constexpr
explicit
__device__
operator
bool
()
const
{
return
not
is_zero
();
}
// check for zero
// check for zero
inline
MIGRAPHX_HIP_DEVICE
constexpr
bool
is_zero
()
const
inline
__device__
constexpr
bool
is_zero
()
const
{
{
if
constexpr
(
FNUZ
)
if
constexpr
(
FNUZ
)
{
{
...
@@ -260,7 +301,7 @@ struct float8
...
@@ -260,7 +301,7 @@ struct float8
}
}
// check for nan
// check for nan
inline
MIGRAPHX_HIP_DEVICE
constexpr
bool
is_nan
()
const
inline
__device__
constexpr
bool
is_nan
()
const
{
{
if
constexpr
(
FNUZ
)
if
constexpr
(
FNUZ
)
{
{
...
@@ -281,7 +322,7 @@ struct float8
...
@@ -281,7 +322,7 @@ struct float8
}
}
// check for inf
// check for inf
inline
MIGRAPHX_HIP_DEVICE
constexpr
bool
is_inf
()
const
inline
__device__
constexpr
bool
is_inf
()
const
{
{
if
constexpr
(
FNUZ
)
if
constexpr
(
FNUZ
)
{
{
...
@@ -303,13 +344,13 @@ struct float8
...
@@ -303,13 +344,13 @@ struct float8
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_SHORT_UNARY_OP(unary_op, binary_op) \
#define MIGRAPHX_FP8_SHORT_UNARY_OP(unary_op, binary_op) \
constexpr float8&
MIGRAPHX_HIP_DEVICE
operator unary_op(const float8& rhs) \
constexpr float8&
__device__
operator unary_op(const float8& rhs)
\
{ \
{ \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<float8>(tmp); \
*this = static_cast<float8>(tmp); \
return *this; \
return *this; \
} \
} \
constexpr float8&
MIGRAPHX_HIP_DEVICE
operator unary_op(const float& rhs) \
constexpr float8&
__device__
operator unary_op(const float& rhs)
\
{ \
{ \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<float8>(tmp); \
*this = static_cast<float8>(tmp); \
...
@@ -321,10 +362,10 @@ struct float8
...
@@ -321,10 +362,10 @@ struct float8
MIGRAPHX_FP8_SHORT_UNARY_OP
(
+=
,
+
)
MIGRAPHX_FP8_SHORT_UNARY_OP
(
+=
,
+
)
MIGRAPHX_FP8_SHORT_UNARY_OP
(
/=
,
/
)
MIGRAPHX_FP8_SHORT_UNARY_OP
(
/=
,
/
)
inline
MIGRAPHX_HIP_DEVICE
constexpr
float8
&
operator
=
(
const
float8
&
rhs
)
=
default
;
inline
__device__
constexpr
float8
&
operator
=
(
const
float8
&
rhs
)
=
default
;
inline
MIGRAPHX_HIP_DEVICE
constexpr
float8
&
operator
=
(
float8
&&
rhs
)
noexcept
=
default
;
inline
__device__
constexpr
float8
&
operator
=
(
float8
&&
rhs
)
noexcept
=
default
;
inline
MIGRAPHX_HIP_DEVICE
constexpr
bool
operator
==
(
const
float8
&
rhs
)
const
inline
__device__
constexpr
bool
operator
==
(
const
float8
&
rhs
)
const
{
{
if
(
rhs
.
is_nan
()
or
rhs
.
is_inf
()
or
this
->
is_nan
()
or
this
->
is_inf
())
if
(
rhs
.
is_nan
()
or
rhs
.
is_inf
()
or
this
->
is_nan
()
or
this
->
is_inf
())
return
false
;
return
false
;
...
@@ -333,14 +374,14 @@ struct float8
...
@@ -333,14 +374,14 @@ struct float8
return
false
;
return
false
;
}
}
inline
MIGRAPHX_HIP_DEVICE
constexpr
bool
operator
<
(
const
float8
&
rhs
)
const
inline
__device__
constexpr
bool
operator
<
(
const
float8
&
rhs
)
const
{
{
const
auto
we
=
static_cast
<
float
>
(
*
this
);
const
auto
we
=
static_cast
<
float
>
(
*
this
);
const
auto
them
=
static_cast
<
float
>
(
rhs
);
const
auto
them
=
static_cast
<
float
>
(
rhs
);
return
we
<
them
;
return
we
<
them
;
}
}
inline
MIGRAPHX_HIP_DEVICE
constexpr
bool
operator
>
(
const
float8
&
rhs
)
const
inline
__device__
constexpr
bool
operator
>
(
const
float8
&
rhs
)
const
{
{
const
auto
we
=
static_cast
<
float
>
(
*
this
);
const
auto
we
=
static_cast
<
float
>
(
*
this
);
const
auto
them
=
static_cast
<
float
>
(
rhs
);
const
auto
them
=
static_cast
<
float
>
(
rhs
);
...
@@ -355,19 +396,19 @@ using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8, true>;
...
@@ -355,19 +396,19 @@ using fp8e4m3fnuz = float8<migraphx::fp8::f8_type::fp8, true>;
using
fp8e5m2fnuz
=
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
,
true
>
;
using
fp8e5m2fnuz
=
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
,
true
>
;
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, T, U)
\
#define MIGRAPHX_FP8_BINARY_OP(binary_op, T, U) \
inline constexpr U
MIGRAPHX_HIP_DEVICE
operator binary_op(const T& lhs, const T& rhs) \
inline constexpr U
__device__
operator binary_op(const T& lhs, const T& rhs) \
{
\
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs));
\
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
}
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_FABS(T)
\
#define MIGRAPHX_FP8_FABS(T) \
inline constexpr
MIGRAPHX_HIP_DEVICE
T fabs(T v) \
inline constexpr
__device__
T fabs(T v) \
{
\
{ \
/*NOLINTNEXTLINE*/
\
/*NOLINTNEXTLINE*/
\
v.data = v.data & 0x7f;
\
v.data = v.data & 0x7f; \
return v;
\
return v; \
}
}
// NOLINTNEXTLINE
// NOLINTNEXTLINE
...
@@ -394,27 +435,27 @@ class numeric_limits<fp8e4m3fnuz>
...
@@ -394,27 +435,27 @@ class numeric_limits<fp8e4m3fnuz>
{
{
public:
public:
static
constexpr
bool
has_infinity
=
false
;
static
constexpr
bool
has_infinity
=
false
;
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e4m3fnuz
epsilon
()
static
constexpr
__device__
fp8e4m3fnuz
epsilon
()
{
{
return
fp8e4m3fnuz
(
0x28
,
fp8e4m3fnuz
::
from_bits
());
return
fp8e4m3fnuz
(
0x28
,
fp8e4m3fnuz
::
from_bits
());
}
}
// NOLINTNEXTLINE
// NOLINTNEXTLINE
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e4m3fnuz
quiet_NaN
()
static
constexpr
__device__
fp8e4m3fnuz
quiet_NaN
()
{
{
return
fp8e4m3fnuz
(
0x80
,
fp8e4m3fnuz
::
from_bits
());
return
fp8e4m3fnuz
(
0x80
,
fp8e4m3fnuz
::
from_bits
());
}
}
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e4m3fnuz
max
()
static
constexpr
__device__
fp8e4m3fnuz
max
()
{
{
return
fp8e4m3fnuz
(
0x7F
,
fp8e4m3fnuz
::
from_bits
());
return
fp8e4m3fnuz
(
0x7F
,
fp8e4m3fnuz
::
from_bits
());
}
}
// this is min value that is not DeNorm. DeNorm min is 0x01
// this is min value that is not DeNorm. DeNorm min is 0x01
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e4m3fnuz
min
()
static
constexpr
__device__
fp8e4m3fnuz
min
()
{
{
return
fp8e4m3fnuz
(
0x08
,
fp8e4m3fnuz
::
from_bits
());
return
fp8e4m3fnuz
(
0x08
,
fp8e4m3fnuz
::
from_bits
());
}
}
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e4m3fnuz
lowest
()
static
constexpr
__device__
fp8e4m3fnuz
lowest
()
{
{
return
fp8e4m3fnuz
(
0xFF
,
fp8e4m3fnuz
::
from_bits
());
return
fp8e4m3fnuz
(
0xFF
,
fp8e4m3fnuz
::
from_bits
());
}
}
...
@@ -425,27 +466,21 @@ class numeric_limits<fp8e4m3fn>
...
@@ -425,27 +466,21 @@ class numeric_limits<fp8e4m3fn>
{
{
public:
public:
static
constexpr
bool
has_infinity
=
false
;
static
constexpr
bool
has_infinity
=
false
;
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e4m3fn
epsilon
()
static
constexpr
__device__
fp8e4m3fn
epsilon
()
{
{
return
fp8e4m3fn
(
0x20
,
fp8e4m3fn
::
from_bits
());
return
fp8e4m3fn
(
0x20
,
fp8e4m3fn
::
from_bits
());
}
}
// NOLINTNEXTLINE
// NOLINTNEXTLINE
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e4m3fn
quiet_NaN
()
static
constexpr
__device__
fp8e4m3fn
quiet_NaN
()
{
{
return
fp8e4m3fn
(
0x7F
,
fp8e4m3fn
::
from_bits
());
return
fp8e4m3fn
(
0x7F
,
fp8e4m3fn
::
from_bits
());
}
}
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e4m3fn
max
()
static
constexpr
__device__
fp8e4m3fn
max
()
{
return
fp8e4m3fn
(
0x7E
,
fp8e4m3fn
::
from_bits
());
}
{
return
fp8e4m3fn
(
0x7E
,
fp8e4m3fn
::
from_bits
());
}
// this is min value that is not DeNorm. DeNorm min is 0x01
// this is min value that is not DeNorm. DeNorm min is 0x01
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e4m3fn
min
()
static
constexpr
__device__
fp8e4m3fn
min
()
{
return
fp8e4m3fn
(
0x08
,
fp8e4m3fn
::
from_bits
());
}
{
return
fp8e4m3fn
(
0x08
,
fp8e4m3fn
::
from_bits
());
}
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e4m3fn
lowest
()
static
constexpr
__device__
fp8e4m3fn
lowest
()
{
{
return
fp8e4m3fn
(
0xFE
,
fp8e4m3fn
::
from_bits
());
return
fp8e4m3fn
(
0xFE
,
fp8e4m3fn
::
from_bits
());
}
}
...
@@ -456,28 +491,28 @@ class numeric_limits<fp8e5m2fnuz>
...
@@ -456,28 +491,28 @@ class numeric_limits<fp8e5m2fnuz>
{
{
public:
public:
static
constexpr
bool
has_infinity
=
false
;
static
constexpr
bool
has_infinity
=
false
;
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2fnuz
epsilon
()
static
constexpr
__device__
fp8e5m2fnuz
epsilon
()
{
{
return
fp8e5m2fnuz
(
0x34
,
fp8e5m2fnuz
::
from_bits
());
return
fp8e5m2fnuz
(
0x34
,
fp8e5m2fnuz
::
from_bits
());
}
}
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2fnuz
quiet_NaN
()
// NOLINT
static
constexpr
__device__
fp8e5m2fnuz
quiet_NaN
()
// NOLINT
{
{
return
fp8e5m2fnuz
(
0x80
,
fp8e5m2fnuz
::
from_bits
());
return
fp8e5m2fnuz
(
0x80
,
fp8e5m2fnuz
::
from_bits
());
}
}
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2fnuz
max
()
static
constexpr
__device__
fp8e5m2fnuz
max
()
{
{
return
fp8e5m2fnuz
(
0x7F
,
fp8e5m2fnuz
::
from_bits
());
return
fp8e5m2fnuz
(
0x7F
,
fp8e5m2fnuz
::
from_bits
());
}
}
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
// this distinction. For the floating points we would end up using lowest most of the times.
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2fnuz
min
()
static
constexpr
__device__
fp8e5m2fnuz
min
()
{
{
return
fp8e5m2fnuz
(
0x4
,
fp8e5m2fnuz
::
from_bits
());
return
fp8e5m2fnuz
(
0x4
,
fp8e5m2fnuz
::
from_bits
());
}
}
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2fnuz
lowest
()
static
constexpr
__device__
fp8e5m2fnuz
lowest
()
{
{
return
fp8e5m2fnuz
(
0xFF
,
fp8e5m2fnuz
::
from_bits
());
return
fp8e5m2fnuz
(
0xFF
,
fp8e5m2fnuz
::
from_bits
());
}
}
...
@@ -488,36 +523,21 @@ class numeric_limits<fp8e5m2>
...
@@ -488,36 +523,21 @@ class numeric_limits<fp8e5m2>
{
{
public:
public:
static
constexpr
bool
has_infinity
=
true
;
static
constexpr
bool
has_infinity
=
true
;
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2
epsilon
()
static
constexpr
__device__
fp8e5m2
epsilon
()
{
return
fp8e5m2
(
0x34
,
fp8e5m2
::
from_bits
());
}
{
return
fp8e5m2
(
0x34
,
fp8e5m2
::
from_bits
());
}
// 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs
// 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2
quiet_NaN
()
// NOLINT
static
constexpr
__device__
fp8e5m2
quiet_NaN
()
// NOLINT
{
{
return
fp8e5m2
(
0xFF
,
fp8e5m2
::
from_bits
());
return
fp8e5m2
(
0xFF
,
fp8e5m2
::
from_bits
());
}
}
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2
max
()
static
constexpr
__device__
fp8e5m2
max
()
{
return
fp8e5m2
(
0x7B
,
fp8e5m2
::
from_bits
());
}
{
return
fp8e5m2
(
0x7B
,
fp8e5m2
::
from_bits
());
}
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
// this distinction. For the floating points we would end up using lowest most of the times.
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2
min
()
static
constexpr
__device__
fp8e5m2
min
()
{
return
fp8e5m2
(
0x4
,
fp8e5m2
::
from_bits
());
}
{
return
fp8e5m2
(
0x4
,
fp8e5m2
::
from_bits
());
}
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2
lowest
()
static
constexpr
__device__
fp8e5m2
lowest
()
{
return
fp8e5m2
(
0xFB
,
fp8e5m2
::
from_bits
());
}
{
return
fp8e5m2
(
0xFB
,
fp8e5m2
::
from_bits
());
}
// 7C and FC both are infinity
// 7C and FC both are infinity
static
constexpr
MIGRAPHX_HIP_DEVICE
fp8e5m2
infinity
()
static
constexpr
__device__
fp8e5m2
infinity
()
{
return
fp8e5m2
(
0x7C
,
fp8e5m2
::
from_bits
());
}
{
return
fp8e5m2
(
0x7C
,
fp8e5m2
::
from_bits
());
}
};
};
}
// namespace fp8
}
// namespace fp8
...
...
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
View file @
6155c782
...
@@ -52,13 +52,14 @@ __device__ void generic_binary_layernorm(
...
@@ -52,13 +52,14 @@ __device__ void generic_binary_layernorm(
block
::
template
run
<
reduce_output
>([
&
](
auto
,
auto
r
)
{
block
::
template
run
<
reduce_output
>([
&
](
auto
,
auto
r
)
{
auto
input
=
r
.
inner
([
&
](
auto
x1
,
auto
x2
)
{
return
op
(
x1
,
x2
);
})(
input1
,
input2
);
auto
input
=
r
.
inner
([
&
](
auto
x1
,
auto
x2
)
{
return
op
(
x1
,
x2
);
})(
input1
,
input2
);
using
value_type
=
typename
Input1
::
type
;
using
value_type
=
typename
Input1
::
type
;
using
vec_value_type
=
vec_type
<
value_type
>
;
using
vec_value_type
=
vec_type
<
value_type
>
;
constexpr
auto
relements
=
r
.
template
elements
<
Input1
>();
constexpr
auto
relements
=
r
.
template
elements
<
Input1
>();
auto
relements_r
=
vec_value_type
{
1.0
/
relements
}
;
constexpr
auto
relements_r
=
static_cast
<
vec_value_type
>
(
1.0
/
relements
)
;
auto
relements_rsqrt
=
sqrt
(
relements_r
);
auto
relements_rsqrt
=
sqrt
(
relements_r
);
auto
means
=
r
.
reduce
(
op
::
sum
{},
auto
means
=
r
.
reduce
(
op
::
sum
{},
make_array
<
vec_value_type
>
(
vec_value_type
{
0
},
vec_value_type
{
0
}),
make_array
<
vec_value_type
>
(
static_cast
<
vec_value_type
>
(
0
),
static_cast
<
vec_value_type
>
(
0
)),
[
&
](
auto
x
)
{
[
&
](
auto
x
)
{
auto
x_out
=
x
*
relements_r
;
auto
x_out
=
x
*
relements_r
;
// dividing x by sqrt(relements) before squaring allows computing
// dividing x by sqrt(relements) before squaring allows computing
...
@@ -70,7 +71,7 @@ __device__ void generic_binary_layernorm(
...
@@ -70,7 +71,7 @@ __device__ void generic_binary_layernorm(
auto
mean_x
=
means
[
0
];
auto
mean_x
=
means
[
0
];
auto
mean_x2
=
means
[
1
];
auto
mean_x2
=
means
[
1
];
auto
variance
=
mean_x2
-
(
mean_x
*
mean_x
);
auto
variance
=
mean_x2
-
(
mean_x
*
mean_x
);
value_type
eps_val
=
value_type
{
eps
}
;
value_type
eps_val
=
static_cast
<
value_type
>
(
eps
)
;
r
.
inner
([
&
](
auto
&
y
,
auto
x
,
auto
...
xs
)
{
r
.
inner
([
&
](
auto
&
y
,
auto
x
,
auto
...
xs
)
{
auto
m
=
x
-
mean_x
;
auto
m
=
x
-
mean_x
;
...
...
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
View file @
6155c782
...
@@ -44,7 +44,7 @@ __device__ void softmax(Input input1, Output output)
...
@@ -44,7 +44,7 @@ __device__ void softmax(Input input1, Output output)
auto
exp_in
=
r
.
inner
([
&
](
auto
x
)
{
return
migraphx
::
exp
(
x
-
c
);
})(
input
);
auto
exp_in
=
r
.
inner
([
&
](
auto
x
)
{
return
migraphx
::
exp
(
x
-
c
);
})(
input
);
auto
batch_sum
=
auto
batch_sum
=
r
.
reduce
(
op
::
sum
{},
0
,
[](
auto
x
)
{
return
migraphx
::
convert
<
float
>
(
x
);
})(
exp_in
);
r
.
reduce
(
op
::
sum
{},
0
,
[](
auto
x
)
{
return
migraphx
::
convert
<
float
>
(
x
);
})(
exp_in
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
otype
{
x
/
batch_sum
}
;
})(
output
,
exp_in
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
static_cast
<
otype
>
(
x
/
batch_sum
)
;
})(
output
,
exp_in
);
});
});
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment