Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0a8edad5
Commit
0a8edad5
authored
Nov 08, 2023
by
Umang Yadav
Browse files
works except constexpr
parent
d734871c
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
388 additions
and
384 deletions
+388
-384
src/include/migraphx/half.hpp
src/include/migraphx/half.hpp
+3
-3
src/include/migraphx/migraphx_float8.hpp
src/include/migraphx/migraphx_float8.hpp
+325
-311
src/include/migraphx/migraphx_hip_f8_impl.hpp
src/include/migraphx/migraphx_hip_f8_impl.hpp
+28
-39
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+2
-2
src/include/migraphx/type_traits.hpp
src/include/migraphx/type_traits.hpp
+4
-4
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+2
-2
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+10
-10
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
...gets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
+7
-6
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
...argets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
+2
-2
test/gpu/jit.cpp
test/gpu/jit.cpp
+2
-2
tools/api/migraphx.h
tools/api/migraphx.h
+1
-1
No files found.
src/include/migraphx/half.hpp
View file @
0a8edad5
...
...
@@ -27,7 +27,7 @@
#include <half/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/
fp8e4m3fnuz
.hpp>
#include <migraphx/
migraphx_float8
.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -69,13 +69,13 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
};
template
<
>
struct
common_type
<
migraphx
::
fp8e4m3fnuz
,
migraphx
::
half
>
struct
common_type
<
migraphx
_fp8
::
fp8e4m3fnuz
,
migraphx
::
half
>
{
using
type
=
float
;
};
template
<
>
struct
common_type
<
migraphx
::
half
,
migraphx
::
fp8e4m3fnuz
>
struct
common_type
<
migraphx
::
half
,
migraphx
_fp8
::
fp8e4m3fnuz
>
{
using
type
=
float
;
};
...
...
src/include/migraphx/migraphx_float8.hpp
View file @
0a8edad5
...
...
@@ -20,19 +20,38 @@
*
* ************************************************************************ */
#ifndef MIGRAPHX_FLOAT8_HPP
#define MIGRAPHX_FLOAT8_HPP
#ifndef MIGRAPHX_
GUARD_RTGLIB_
FLOAT8_HPP
#define MIGRAPHX_
GUARD_RTGLIB_
FLOAT8_HPP
#ifdef __HIP_PLATFORM_HCC__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wmacro-redefined"
#pragma clang diagnostic ignored "-Wc++20-extensions"
#if(defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__))
// need to include hip_runtime.h otherwise it complains about __host__ and __device__
#ifndef __HIPCC_RTC__
#include <hip/hip_runtime.h>
#else
#include <migraphx/kernels/hip.hpp>
#endif
#define MIGRAPHX_HIP_HOST_DEVICE __host__ __device__
#define MIGRAPHX_HIP_HOST __host__
#else
#define MIGRAPHX_HIP_HOST_DEVICE
#define MIGRAPHX_HIP_HOST
#endif
#define MIGRAPHX_HIP_HOST __host__
#define MIGRAPHX_HIP_DEVICE __device__
#ifndef MIGRAPHX_FP8_FNUZ
#define MIGRAPHX_FP8_FNUZ true
#endif
// We are clipping in down conversion by default
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
#ifndef __HIPCC_RTC__
#include <cmath>
#include <cstdint>
...
...
@@ -44,28 +63,25 @@
#include <iostream>
#include <string>
#include <utility>
#include <migraphx/type_traits.hpp>
#else
#include <migraphx/kernels/type_traits.hpp>
#endif
namespace
migraphx_hip_f8_impl
{
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
,
bool
clip
>
MIGRAPHX_HIP_HOST_DEVICE
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
=
false
,
uint32_t
rng
=
0
);
MIGRAPHX_HIP_HOST_DEVICE
constexpr
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
=
false
,
uint32_t
rng
=
0
);
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
>
MIGRAPHX_HIP_HOST_DEVICE
T
cast_from_f8
(
uint8_t
x
);
MIGRAPHX_HIP_HOST_DEVICE
constexpr
T
cast_from_f8
(
uint8_t
x
);
}
// namespace migraphx_hip_f8_impl
#include
"
migraphx_hip_f8_impl.hpp
"
#include
<migraphx/
migraphx_hip_f8_impl.hpp
>
namespace
migraphx_fp8
{
enum
class
migraphx_hip_f8_rounding_mode
{
standard
,
standard
,
// standard rounding is doing RNE -- round to nearest even
stochastic
};
...
...
@@ -76,11 +92,19 @@ enum class hip_f8_type
};
template
<
migraphx_fp8
::
hip_f8_type
T
=
migraphx_fp8
::
hip_f8_type
::
fp8
>
struct
MIGRAPHX_EXPORT
migraphx
_f8
struct
hip
_f8
{
uint8_t
data
;
// default constructor
MIGRAPHX_HIP_HOST_DEVICE
migraphx_f8
()
=
default
;
MIGRAPHX_HIP_HOST_DEVICE
constexpr
hip_f8
()
=
default
;
// default copy constructor
MIGRAPHX_HIP_HOST_DEVICE
constexpr
hip_f8
(
const
hip_f8
&
y
)
=
default
;
struct
from_bits_t
{
};
static
constexpr
MIGRAPHX_HIP_HOST_DEVICE
from_bits_t
from_bits
()
{
return
from_bits_t
();
}
MIGRAPHX_HIP_HOST_DEVICE
constexpr
hip_f8
(
uint8_t
bits
,
from_bits_t
)
:
data
(
bits
)
{}
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// device specific optimized F8 down-conversion code
...
...
@@ -121,10 +145,8 @@ struct MIGRAPHX_EXPORT migraphx_f8
{
ival
=
__builtin_amdgcn_cvt_sr_bf8_f32
(
val
.
fval
,
rng
,
ival
,
0
);
// 0 pos
}
val
.
i32val
=
ival
;
i8data
=
val
.
i8val
[
0
];
// little endian
}
else
// RNE CVT
else
// RNE CVT
{
if
constexpr
(
T
==
migraphx_fp8
::
hip_f8_type
::
fp8
)
{
...
...
@@ -135,11 +157,12 @@ struct MIGRAPHX_EXPORT migraphx_f8
{
ival
=
__builtin_amdgcn_cvt_pk_bf8_f32
(
val
.
fval
,
val
.
fval
,
ival
,
false
);
// false -> WORD0}
val
.
i32val
=
ival
;
i8data
=
val
.
i8val
[
0
];
}
return
i8data
;
}
val
.
i32val
=
ival
;
i8data
=
val
.
i8val
[
0
];
// little endian
return
i8data
;
}
#endif // __gfx940__
...
...
@@ -147,11 +170,10 @@ struct MIGRAPHX_EXPORT migraphx_f8
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// NOTE: ON-DEVICE... always optimal bias
explicit
MIGRAPHX_HIP_DEVICE
migraphx_f8
(
float
v
,
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
rm
=
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
standard
,
uint32_t
rng
=
0
)
explicit
MIGRAPHX_HIP_DEVICE
hip_f8
(
float
v
,
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
rm
=
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
standard
,
uint32_t
rng
=
0
)
{
// runtime branch, use cast_to_f8_from_f32 if want to avoid it
if
(
rm
==
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
stochastic
)
...
...
@@ -164,22 +186,22 @@ struct MIGRAPHX_EXPORT migraphx_f8
explicit
MIGRAPHX_HIP_HOST
#else
// both Host and DEVICE for non-gfx940 using s/w simulation
explicit
MIGRAPHX_HIP_HOST_DEVICE
explicit
constexpr
MIGRAPHX_HIP_HOST_DEVICE
#endif
migraphx
_f8
(
float
v
,
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
rm
=
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
standard
,
uint32_t
rng
=
0
)
hip
_f8
(
float
v
,
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
rm
=
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
standard
,
uint32_t
rng
=
0
)
{
if
constexpr
(
T
==
migraphx_fp8
::
hip_f8_type
::
fp8
)
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx_hip_f8_impl
::
cast_to_f8
<
3
,
4
,
float
,
true
/*negative_zero_nan*/
,
true
/*clip*/
>
(
cast_to_f8
<
3
,
4
,
float
,
MIGRAPHX_FP8_FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
,
(
rm
==
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
stochastic
),
rng
);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx_hip_f8_impl
::
cast_to_f8
<
3
,
4
,
float
,
true
/*negative_zero_nan*/
,
false
/*clip*/
>
(
cast_to_f8
<
3
,
4
,
float
,
MIGRAPHX_FP8_FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
v
,
(
rm
==
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
stochastic
),
rng
);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
}
...
...
@@ -187,49 +209,53 @@ struct MIGRAPHX_EXPORT migraphx_f8
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx_hip_f8_impl
::
cast_to_f8
<
2
,
5
,
float
,
true
/*negative_zero_nan*/
,
true
/*clip*/
>
(
cast_to_f8
<
2
,
5
,
float
,
MIGRAPHX_FP8_FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
,
(
rm
==
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
stochastic
),
rng
);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx_hip_f8_impl
::
cast_to_f8
<
2
,
5
,
float
,
true
/*negative_zero_nan*/
,
false
/*clip*/
>
(
cast_to_f8
<
2
,
5
,
float
,
MIGRAPHX_FP8_FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
v
,
(
rm
==
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
stochastic
),
rng
);
#endif // rocblas_F8_downcast_clipping}
}
}
// Constructor from half
explicit
MIGRAPHX_HIP_HOST_DEVICE
migraphx_f8
(
migraphx
::
half
v
,
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
rm
=
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
standard
,
uint32_t
rng
=
0
)
:
migraphx_f8
((
float
)
v
,
rm
,
rng
)
{
}
/*
// Constructor from half
explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
hip_f8(migraphx::half v,
migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
uint32_t rng = 0)
: hip_f8((float)v, rm, rng)
{
}
// constructor from int
explicit
MIGRAPHX_HIP_HOST_DEVICE
migraphx
_f8
(
int
v
,
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
rm
=
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
standard
,
uint32_t
rng
=
0
)
:
migraphx
_f8
((
float
)
v
,
rm
,
rng
)
explicit
constexpr
MIGRAPHX_HIP_HOST_DEVICE
hip
_f8(int v,
migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
uint32_t rng = 0)
:
hip
_f8((float)v, rm, rng)
{
}
// constructor from double
explicit
MIGRAPHX_HIP_HOST_DEVICE
migraphx
_f8
(
double
v
,
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
rm
=
migraphx_fp8
::
migraphx_hip_f8_rounding_mode
::
standard
,
uint32_t
rng
=
0
)
:
migraphx
_f8
((
float
)
v
,
rm
,
rng
)
explicit
constexpr
MIGRAPHX_HIP_HOST_DEVICE
hip
_f8(double v,
migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
uint32_t rng = 0)
:
hip
_f8((float)v, rm, rng)
{
}
*/
/**/
// convert to float
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if 0 // need constexpr operator(). This version can't be constexpr
// upcast using device specific intrinsic
explicit
inline
MIGRAPHX_HIP_DEVICE
operator
float
()
const
inline MIGRAPHX_HIP_DEVICE operator float() const
{
float fval;
uint32_t i32val = static_cast<uint32_t>(data);
...
...
@@ -247,291 +273,195 @@ struct MIGRAPHX_EXPORT migraphx_f8
return fval;
}
explicit
inline
MIGRAPHX_HIP_HOST
operator
float
()
const
inline constexpr
MIGRAPHX_HIP_HOST operator float() const
#else
// non gfx940
explicit
inline
MIGRAPHX_HIP_HOST_DEVICE
operator
float
()
const
inline
constexpr
MIGRAPHX_HIP_HOST_DEVICE
operator
float
()
const
#endif
{
if
constexpr
(
T
==
migraphx_fp8
::
hip_f8_type
::
fp8
)
{
return
migraphx_hip_f8_impl
::
cast_from_f8
<
3
,
4
,
float
,
true
/*negative_zero_nan*/
>
(
data
);
return
migraphx_hip_f8_impl
::
cast_from_f8
<
3
,
4
,
float
,
MIGRAPHX_FP8_FNUZ
/*negative_zero_nan*/
>
(
data
);
}
// else
return
migraphx_hip_f8_impl
::
cast_from_f8
<
2
,
5
,
float
,
true
/*negative_zero_nan*/
>
(
data
);
return
migraphx_hip_f8_impl
::
cast_from_f8
<
2
,
5
,
float
,
MIGRAPHX_FP8_FNUZ
/*negative_zero_nan*/
>
(
data
);
}
// convert to half
explicit
inline
MIGRAPHX_HIP_HOST_DEVICE
operator
migraphx
::
half
()
const
{
return
migraphx
::
half
(
float
(
*
this
));
// convert to float, then convert to f16
}
/*
// convert to half
explicit inline MIGRAPHX_HIP_HOST_DEVICE operator migraphx::half() const
{
return migraphx::half(float(*this)); // convert to float, then convert to f16
}
*/
// check for zero
inline
MIGRAPHX_HIP_HOST_DEVICE
bool
is_zero
()
const
{
return
data
==
0x00
;
}
inline
MIGRAPHX_HIP_HOST_DEVICE
constexpr
bool
is_zero
()
const
{
if
constexpr
(
MIGRAPHX_FP8_FNUZ
)
{
return
data
==
0x00
;
}
else
{
return
(
data
==
0x00
)
||
(
data
==
0x80
);
}
}
// check for nan
inline
MIGRAPHX_HIP_HOST_DEVICE
bool
is_nan
()
const
{
return
data
==
0x80
;
}
inline
MIGRAPHX_HIP_HOST_DEVICE
constexpr
bool
is_nan
()
const
{
if
constexpr
(
MIGRAPHX_FP8_FNUZ
)
{
return
data
==
0x80
;
}
else
{
if
(
T
==
migraphx_fp8
::
hip_f8_type
::
bf8
)
{
return
(
data
==
0x7d
)
||
(
data
==
0x7e
)
||
(
data
==
0x7f
)
||
(
data
==
0xfd
)
||
(
data
==
0xfe
)
||
(
data
==
0xff
);
}
else
{
return
(
data
==
0x79
)
||
(
data
==
0x7a
)
||
(
data
==
0x7b
)
||
(
data
==
0x7c
)
||
(
data
==
0x7d
)
||
(
data
==
0x7e
)
||
(
data
==
0x7f
)
||
(
data
==
0xf9
)
||
(
data
==
0xfa
)
||
(
data
==
0xfb
)
||
(
data
==
0xfc
)
||
(
data
==
0xfd
)
||
(
data
==
0xfe
)
||
(
data
==
0xff
);
}
}
}
// check for inf
inline
MIGRAPHX_HIP_HOST_DEVICE
bool
is_inf
()
const
{
return
data
==
0x80
;
}
// assignment overloading only from the same F8 types
inline
__host__
__device__
migraphx_f8
&
operator
=
(
const
migraphx_f8
&
a
)
inline
MIGRAPHX_HIP_HOST_DEVICE
constexpr
bool
is_inf
()
const
{
data
=
a
.
data
;
return
*
this
;
if
constexpr
(
MIGRAPHX_FP8_FNUZ
)
{
return
data
==
0x80
;
}
else
{
if
(
T
==
migraphx_fp8
::
hip_f8_type
::
bf8
)
{
return
(
data
==
0x7c
)
||
(
data
==
0xfc
);
}
else
{
return
(
data
==
0x78
)
||
(
data
==
0xf8
);
}
}
}
};
/*
// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const migraphx_f8& f8) { return os << float(f8); }
inline std::ostream& operator<<(std::ostream& os, const migraphx_bf8& bf8)
{
return os << float(bf8);
}
// all + operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns float
inline __host__ __device__ float operator+(const float fa, migraphx_f8 b)
{
return (fa + float(b));
}
inline __host__ __device__ float operator+(const float fa, migraphx_bf8 b)
{
return (fa + float(b));
}
inline __host__ __device__ float operator+(migraphx_f8 a, const float fb)
{
return (float(a) + fb);
}
inline __host__ __device__ float operator+(migraphx_bf8 a, const float fb)
{
return (float(a) + fb);
}
inline __host__ __device__ float operator+(migraphx_f8 a, migraphx_bf8 b)
{
return (float(a) + float(b));
}
inline __host__ __device__ float operator+(migraphx_bf8 a, migraphx_f8 b)
{
return (float(a) + float(b));
}
inline __host__ __device__ migraphx_f8 operator+(migraphx_f8 a, migraphx_f8 b)
{
return migraphx_f8(float(a) + float(b));
}
inline __host__ __device__ migraphx_bf8 operator+(migraphx_bf8 a, migraphx_bf8 b)
{
return migraphx_bf8(float(a) + float(b));
}
inline __host__ __device__ migraphx_f8& operator+=(migraphx_f8& a, migraphx_f8 b)
{
return a = migraphx_f8(float(a) + float(b));
}
inline __host__ __device__ migraphx_bf8& operator+=(migraphx_bf8& a, migraphx_bf8 b)
{
return a = migraphx_bf8(float(a) + float(b));
}
// overloading multiplication, always returns float,
inline __host__ __device__ float operator*(migraphx_f8 a, migraphx_f8 b)
{
return float(a) * float(b);
}
inline __host__ __device__ float operator*(float a, migraphx_f8 b) { return (a * float(b)); }
inline __host__ __device__ float operator*(migraphx_f8 a, float b) { return (float(a) * b); }
inline __host__ __device__ float operator*(int32_t a, migraphx_f8 b)
{
return ((float)a * float(b));
}
inline __host__ __device__ float operator*(double a, migraphx_f8 b)
{
return ((float)a * float(b));
}
inline __host__ __device__ float operator*(migraphx_bf8 a, migraphx_bf8 b)
{
return float(a) * float(b);
}
inline __host__ __device__ float operator*(float a, migraphx_bf8 b) { return (a * float(b)); }
inline __host__ __device__ float operator*(migraphx_bf8 a, float b) { return (float(a) * b); }
inline __host__ __device__ float operator*(int32_t a, migraphx_bf8 b)
{
return ((float)a * float(b));
}
inline __host__ __device__ float operator*(double a, migraphx_bf8 b)
{
return ((float)a * float(b));
}
// overloading for mixed f8 and bf8 types
inline __host__ __device__ float operator*(migraphx_f8 a, migraphx_bf8 b)
{
return float(a) * float(b);
}
inline __host__ __device__ float operator*(migraphx_bf8 a, migraphx_f8 b)
{
return float(a) * float(b);
}
// all - operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns float
inline __host__ __device__ float operator-(const float fa, migraphx_f8 b)
{
return (fa - float(b));
}
inline __host__ __device__ float operator-(const float fa, migraphx_bf8 b)
{
return (fa - float(b));
}
inline __host__ __device__ float operator-(migraphx_f8 a, const float fb)
{
return (float(a) - fb);
}
inline __host__ __device__ float operator-(migraphx_bf8 a, const float fb)
{
return (float(a) - fb);
}
inline __host__ __device__ float operator-(migraphx_f8 a, migraphx_bf8 b)
{
return (float(a) - float(b));
}
inline __host__ __device__ float operator-(migraphx_bf8 a, migraphx_f8 b)
{
return (float(a) - float(b));
}
inline __host__ __device__ migraphx_f8 operator-(migraphx_f8 a, migraphx_f8 b)
{
return migraphx_f8(float(a) - float(b));
}
inline __host__ __device__ migraphx_bf8 operator-(migraphx_bf8 a, migraphx_bf8 b)
{
return migraphx_bf8(float(a) - float(b));
}
inline __host__ __device__ migraphx_f8& operator-=(migraphx_f8& a, migraphx_f8 b)
{
return a = migraphx_f8(float(a) - float(b));
}
inline __host__ __device__ migraphx_bf8& operator-=(migraphx_bf8& a, migraphx_bf8 b)
{
return a = migraphx_bf8(float(a) - float(b));
}
// overloading division, always returns float,
inline __host__ __device__ float operator/(migraphx_f8 a, migraphx_f8 b)
{
return float(a) / float(b);
}
inline __host__ __device__ float operator/(float a, migraphx_f8 b) { return (a / float(b)); }
inline __host__ __device__ float operator/(migraphx_f8 a, float b) { return (float(a) / b); }
#define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op) \
constexpr hip_f8& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const hip_f8& rhs) \
{ \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<hip_f8>(tmp); \
return *this; \
} \
constexpr hip_f8& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const float& rhs) \
{ \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<hip_f8>(tmp); \
return *this; \
}
inline __host__ __device__ float operator/(int32_t a, migraphx_f8 b
)
{
return ((float)a / float(b));
}
MIGRAPHX_FP8_UNARY_OP
(
*=
,
*
)
MIGRAPHX_FP8_UNARY_OP
(
-=
,
-
)
MIGRAPHX_FP8_UNARY_OP
(
+=
,
+
)
MIGRAPHX_FP8_UNARY_OP
(
/=
,
/
)
inline __host__ __device__ float operator/(double a, migraphx_f8 b)
{
return ((float)a / float(b));
}
inline
MIGRAPHX_HIP_HOST_DEVICE
constexpr
hip_f8
&
operator
=
(
const
hip_f8
&
rhs
)
=
default
;
inline
MIGRAPHX_HIP_HOST_DEVICE
constexpr
hip_f8
&
operator
=
(
hip_f8
&&
rhs
)
=
default
;
inline __host__ __device__ float operator/(migraphx_bf8 a, migraphx_bf8 b)
{
return float(a) / float(b);
}
#if !defined(__HIP_NO_F8_CONVERSIONS__)
// for the device kernels, this needs to be disabled since implicit_conversion op can type cast
// any type to any other type and that results in conflicts in candidate overload resolutions.
inline
constexpr
hip_f8
&
MIGRAPHX_HIP_HOST_DEVICE
operator
=
(
float
rhs
)
{
*
this
=
static_cast
<
hip_f8
>
(
rhs
);
return
*
this
;
}
#endif
inline __host__ __device__ float operator/(float a, migraphx_bf8 b) { return (a / float(b)); }
inline
MIGRAPHX_HIP_HOST_DEVICE
constexpr
bool
operator
==
(
const
hip_f8
&
rhs
)
const
{
if
((
rhs
.
is_zero
()
&&
this
->
is_zero
())
||
(
fabs
(
rhs
-
*
this
)
<
std
::
numeric_limits
<
hip_f8
<
T
>>::
epsilon
()))
return
true
;
else
if
(
rhs
.
is_nan
()
||
rhs
.
is_inf
()
||
this
->
is_nan
()
||
this
->
is_inf
())
return
false
;
inline __host__ __device__ float operator/(migraphx_bf8 a, float b) { return (float(a) / b); }
return
false
;
}
inline __host__ __device__ float operator/(int32_t a, migraphx_bf8 b)
{
return ((float)a / float(b));
}
inline
MIGRAPHX_HIP_HOST_DEVICE
constexpr
bool
operator
<
(
const
hip_f8
&
rhs
)
const
{
const
auto
we
=
static_cast
<
float
>
(
*
this
);
const
auto
them
=
static_cast
<
float
>
(
rhs
);
return
we
<
them
;
}
inline __host__ __device__ float operator/(double a, migraphx_bf8 b)
{
return ((float)a / float(b));
}
inline
MIGRAPHX_HIP_HOST_DEVICE
constexpr
bool
operator
>
(
const
hip_f8
&
rhs
)
const
{
const
auto
we
=
static_cast
<
float
>
(
*
this
);
const
auto
them
=
static_cast
<
float
>
(
rhs
);
return
we
>
them
;
}
};
// overloading for mixed f8 and bf8 types
inline __host__ __device__ float operator/(migraphx_f8 a, migraphx_bf8 b)
#ifndef __HIPCC_RTC__
// Special operator overloading
template
<
migraphx_fp8
::
hip_f8_type
T
>
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
migraphx_fp8
::
hip_f8
<
T
>&
rhs
)
{
return
float(a) / float(b
);
return
os
<<
static_cast
<
float
>
(
rhs
);
}
#endif
inline __host__ __device__ float operator/(migraphx_bf8 a, migraphx_f8 b)
{
return float(a) / float(b);
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \
template <migraphx_fp8::hip_f8_type T> \
inline constexpr U MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \
const migraphx_fp8::hip_f8<T>& lhs, const migraphx_fp8::hip_f8<T>& rhs) \
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
// overloading for compare
inline __host__ __device__ bool operator==(migraphx_f8 a, migraphx_f8 b)
{
return (a.data == b.data);
}
MIGRAPHX_FP8_BINARY_OP
(
*
,
migraphx_fp8
::
hip_f8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
-
,
migraphx_fp8
::
hip_f8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
/
,
migraphx_fp8
::
hip_f8
<
T
>
)
MIGRAPHX_FP8_BINARY_OP
(
+
,
migraphx_fp8
::
hip_f8
<
T
>
)
// TODO: Comparison ops shouldn't convert to float, maybe need to take care of rounding effects.
MIGRAPHX_FP8_BINARY_OP
(
==
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
>=
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
<=
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
>
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
<
,
bool
)
MIGRAPHX_FP8_BINARY_OP
(
!=
,
bool
)
inline __host__ __device__ bool operator==(migraphx_bf8 a, migraphx_bf8 b)
template
<
migraphx_fp8
::
hip_f8_type
T
>
inline
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
T
>
fabs
(
migraphx_fp8
::
hip_f8
<
T
>
v
)
{
return (a.data == b.data);
v
.
data
=
v
.
data
&
0x7f
;
return
v
;
}
inline __host__ __device__ bool operator!=(migraphx_f8 a, migraphx_f8 b)
template
<
class
T
>
MIGRAPHX_HIP_HOST_DEVICE
constexpr
T
F8_Max
()
{
return
(a.data != b.data)
;
return
T
{
0x7F
,
T
::
from_bits
()}
;
}
inline __host__ __device__ bool operator!=(migraphx_bf8 a, migraphx_bf8 b)
template
<
class
T
>
MIGRAPHX_HIP_HOST_DEVICE
constexpr
T
F8_Lowest
()
{
return
(a.data != b.data)
;
return
T
{
0xFF
,
T
::
from_bits
()}
;
}
// ================ Explicit downcasting to support different rounding (RNE, SR)
// =============== NOTE: we going to remove all assignment operator overloading from other
// types and enforce this explicit_downcast function to make any roudning behavior default
// We have to explicitly call this function with SR flag
template <typename T,
typename Ta,
bool stochastic_rounding,
typename std::enable_if<migraphx::is_same<T, Ta>{}, int>::type = 0>
inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng = 0)
{
// same type, no conversion
return a;
}
using
fp8e4m3fnuz
=
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
;
/*
// Use h/w intrinsic and optimized version when __gfx940__
template <typename T,
typename Ta,
...
...
@@ -578,15 +508,99 @@ inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng)
}
*/
}
// namespace migraphx_fp8
/
*
/
/ define numeric limits for the new data type
namespace
std
{
inline migraphx_f8 sin(migraphx_f8 a) { return migraphx_f8(sinf(float(a))); }
inline migraphx_f8 cos(migraphx_f8 a) { return migraphx_f8(cosf(float(a))); }
inline migraphx_bf8 sin(migraphx_bf8 a) { return migraphx_bf8(sinf(float(a))); }
inline migraphx_bf8 cos(migraphx_bf8 a) { return migraphx_bf8(cosf(float(a))); }
__device__ __host__ constexpr migraphx_f8 real(const migraphx_f8& a) { return a; }
__device__ __host__ constexpr migraphx_bf8 real(const migraphx_bf8& a) { return a; }
inline
bool
isfinite
(
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
x
)
// NOLINT
{
return
x
.
is_inf
();
}
inline
bool
isfinite
(
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
x
)
// NOLINT
{
return
x
.
is_inf
();
}
template
<
>
class
numeric_limits
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
{
public:
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
epsilon
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
(
float
(
0.0625
));
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
quiet_NaN
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
(
static_cast
<
uint8_t
>
(
MIGRAPHX_FP8_FNUZ
?
0X80
:
0x79
));
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
max
()
{
return
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
();
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
min
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
(
-
1.0
f
)
*
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
();
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>
lowest
()
{
return
migraphx_fp8
::
F8_Lowest
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
fp8
>>
();
}
};
template
<
>
class
numeric_limits
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
{
public:
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
epsilon
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
(
float
(
0.125
));
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
quiet_NaN
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
(
static_cast
<
uint8_t
>
(
MIGRAPHX_FP8_FNUZ
?
0X80
:
0x7d
));
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
max
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
(
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
());
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
min
()
{
return
static_cast
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
(
float
(
-
1.0
f
))
*
migraphx_fp8
::
F8_Max
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
();
}
static
MIGRAPHX_HIP_HOST_DEVICE
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>
lowest
()
{
return
migraphx_fp8
::
F8_Lowest
<
migraphx_fp8
::
hip_f8
<
migraphx_fp8
::
hip_f8_type
::
bf8
>>
();
}
};
template
<
class
T
>
struct
common_type
<
migraphx_fp8
::
fp8e4m3fnuz
,
T
>
:
std
::
common_type
<
float
,
T
>
// NOLINT
{
};
template
<
class
T
>
struct
common_type
<
T
,
migraphx_fp8
::
fp8e4m3fnuz
>
:
std
::
common_type
<
float
,
T
>
// NOLINT
{
};
template
<
>
struct
common_type
<
migraphx_fp8
::
fp8e4m3fnuz
,
migraphx_fp8
::
fp8e4m3fnuz
>
{
using
type
=
float
;
};
}
// namespace std
*/
// =================================================================================================
#endif // MIGRAPHX_FLOAT8_HPP
#pragma clang diagnostic pop
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
src/include/migraphx/migraphx_hip_f8_impl.hpp
View file @
0a8edad5
...
...
@@ -25,8 +25,22 @@
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
#pragma clang diagnostic ignored "-Wreserved-identifier"
namespace
migraphx_hip_f8_impl
{
namespace
detail
{
template
<
bool
B
,
class
T
,
class
F
>
struct
conditional
{
using
type
=
T
;
};
template
<
class
T
,
class
F
>
struct
conditional
<
false
,
T
,
F
>
{
using
type
=
F
;
};
}
// namespace detail
// #ifdef __HIP_PLATFORM_HCC__
// __device__ inline int clz(uint32_t x) { return __clz(x); }
...
...
@@ -35,12 +49,10 @@ namespace migraphx_hip_f8_impl {
// #endif
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
,
bool
clip
>
MIGRAPHX_HIP_HOST_DEVICE
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
,
uint32_t
rng
)
MIGRAPHX_HIP_HOST_DEVICE
constexpr
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
,
uint32_t
rng
)
{
constexpr
bool
is_half
=
migraphx
::
is_same
<
T
,
migraphx
::
half
>
{};
constexpr
bool
is_float
=
migraphx
::
is_same
<
T
,
float
>
{};
static_assert
(
wm
+
we
==
7
,
"wm+we==7"
);
static_assert
(
is_half
||
is_float
,
"Only half and float can be cast to f8"
);
const
int
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
uint32_t
x
;
...
...
@@ -215,38 +227,20 @@ this case, the fp16 mantissa should be shift left by 1 */
}
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
>
MIGRAPHX_HIP_HOST_DEVICE
T
cast_from_f8
(
uint8_t
x
)
MIGRAPHX_HIP_HOST_DEVICE
constexpr
T
cast_from_f8
(
uint8_t
x
)
{
constexpr
bool
is_half
=
migraphx
::
is_same
<
T
,
migraphx
::
half
>
{};
constexpr
bool
is_float
=
migraphx
::
is_same
<
T
,
float
>
{};
static_assert
(
is_half
||
is_float
,
"only half and float are supported"
);
constexpr
int
weo
=
is_half
?
5
:
8
;
constexpr
int
wmo
=
is_half
?
10
:
(
is_float
?
23
:
7
);
constexpr
int
weo
=
8
;
constexpr
int
wmo
=
23
;
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
if
(
is_half
)
{
const
uint16_t
ihInf
=
0x7C00
;
const
uint16_t
ihNegInf
=
0xFC00
;
const
uint16_t
ihNaN
=
0x7C01
;
const
uint16_t
ihNeg0
=
0x8000
;
fInf
=
reinterpret_cast
<
const
migraphx
::
half
&>
(
ihInf
);
fNegInf
=
reinterpret_cast
<
const
migraphx
::
half
&>
(
ihNegInf
);
fNaN
=
reinterpret_cast
<
const
migraphx
::
half
&>
(
ihNaN
);
fNeg0
=
reinterpret_cast
<
const
migraphx
::
half
&>
(
ihNeg0
);
}
else
if
(
is_float
)
{
const
uint32_t
ifInf
=
0x7F800000
;
const
uint32_t
ifNegInf
=
0xFF800000
;
const
uint32_t
ifNaN
=
0x7F800001
;
const
uint32_t
ifNeg0
=
0x80000000
;
fInf
=
reinterpret_cast
<
const
float
&>
(
ifInf
);
fNegInf
=
reinterpret_cast
<
const
float
&>
(
ifNegInf
);
fNaN
=
reinterpret_cast
<
const
float
&>
(
ifNaN
);
fNeg0
=
reinterpret_cast
<
const
float
&>
(
ifNeg0
);
}
const
uint32_t
ifInf
=
0x7F800000
;
const
uint32_t
ifNegInf
=
0xFF800000
;
const
uint32_t
ifNaN
=
0x7F800001
;
const
uint32_t
ifNeg0
=
0x80000000
;
fInf
=
reinterpret_cast
<
const
float
&>
(
ifInf
);
fNegInf
=
reinterpret_cast
<
const
float
&>
(
ifNegInf
);
fNaN
=
reinterpret_cast
<
const
float
&>
(
ifNaN
);
fNeg0
=
reinterpret_cast
<
const
float
&>
(
ifNeg0
);
if
(
x
==
0
)
return
0
;
...
...
@@ -266,12 +260,7 @@ MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x)
if
(
exponent
==
((
1
<<
we
)
-
1
))
return
(
mantissa
==
0
)
?
(
sign
?
fNegInf
:
fInf
)
:
fNaN
;
}
typename
migraphx
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
retval
;
if
(
we
==
5
&&
is_half
&&
!
negative_zero_nan
)
{
retval
=
x
<<
8
;
return
reinterpret_cast
<
const
T
&>
(
retval
);
}
typename
detail
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
retval
;
const
int
exp_low_cutoff
=
(
1
<<
(
weo
-
1
))
-
(
1
<<
(
we
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
...
...
src/include/migraphx/shape.hpp
View file @
0a8edad5
...
...
@@ -34,7 +34,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/
fp8e4m3fnuz
.hpp>
#include <migraphx/
migraphx_float8
.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>
...
...
@@ -54,7 +54,7 @@ struct MIGRAPHX_EXPORT shape
m(half_type, half) \
m(float_type, float) \
m(double_type, double) \
m(float8_type, fp8e4m3fnuz) \
m(float8_type,
migraphx_fp8::
fp8e4m3fnuz) \
m(uint8_type, uint8_t) \
m(int8_type, int8_t) \
m(uint16_type, uint16_t) \
...
...
src/include/migraphx/type_traits.hpp
View file @
0a8edad5
...
...
@@ -25,10 +25,10 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#include <migraphx/fp8e4m3fnuz.hpp>
#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/migraphx_float8.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -63,9 +63,9 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
migraphx_fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
migraphx_fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
migraphx_fp8
::
fp8e4m3fnuz
)
template
<
class
T
>
using
accumulator_type
=
...
...
src/py/migraphx_py.cpp
View file @
0a8edad5
...
...
@@ -40,7 +40,7 @@
#include <migraphx/json.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/
fp8e4m3fnuz
.hpp>
#include <migraphx/
migraphx_float8
.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp>
#endif
...
...
@@ -145,7 +145,7 @@ struct npy_format_descriptor<half>
};
template
<
>
struct
npy_format_descriptor
<
migraphx
::
fp8e4m3fnuz
>
struct
npy_format_descriptor
<
migraphx
_fp8
::
fp8e4m3fnuz
>
{
static
std
::
string
format
()
{
...
...
src/targets/gpu/CMakeLists.txt
View file @
0a8edad5
...
...
@@ -60,7 +60,7 @@ endif()
include
(
Embed
)
add_embed_library
(
migraphx_kernels
${
KERNEL_FILES
}
RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/ EXTRA_HEADERS
${
CMAKE_SOURCE_DIR
}
/src/include/migraphx/
fp8e4m3fnuz.hpp EXTRA_HEADERS_RELATIVE
${
CMAKE_SOURCE_DIR
}
/src/include
)
add_embed_library
(
migraphx_kernels
${
KERNEL_FILES
}
RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/ EXTRA_HEADERS
${
CMAKE_SOURCE_DIR
}
/src/include/migraphx/
migraphx_float8.hpp
${
CMAKE_SOURCE_DIR
}
/src/include/migraphx/migraphx_hip_f8_impl.hpp EXTRA_HEADERS_RELATIVE
${
CMAKE_SOURCE_DIR
}
/src/include
${
CMAKE_SOURCE_DIR
}
/src/include
)
configure_file
(
device/targets.hpp.in include/migraphx/gpu/device/targets.hpp
)
file
(
GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/device/*.cpp
)
...
...
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
View file @
0a8edad5
...
...
@@ -35,7 +35,7 @@ namespace migraphx {
namespace
math
{
constexpr
float
as_float
(
migraphx
::
half
x
)
{
return
x
;
}
constexpr
float
as_float
(
migraphx
::
fp8e4m3fnuz
x
)
{
return
x
;
}
constexpr
float
as_float
(
migraphx
_fp8
::
fp8e4m3fnuz
x
)
{
return
x
;
}
template
<
class
T
>
constexpr
T
as_float
(
T
x
)
...
...
@@ -76,17 +76,17 @@ constexpr T as_float(T x)
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::fp8e4m3fnuz x, Ts... xs)
\
MIGRAPHX_RETURNS(
migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
#define MIGRAPHX_DEVICE_MATH_FP8(name, fname)
\
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())>
\
auto __device__ name(migraphx
_fp8
::fp8e4m3fnuz x, Ts... xs)
MIGRAPHX_RETURNS(
\
migraphx
_fp8
::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(name, fname) \
inline auto __device__ name(migraphx::fp8e4m3fnuz x, migraphx::fp8e4m3fnuz y)
\
-> migraphx::fp8e4m3fnuz \
{ \
return migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y))); \
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(name, fname)
\
inline auto __device__ name(migraphx
_fp8
::fp8e4m3fnuz x, migraphx
_fp8
::fp8e4m3fnuz y) \
-> migraphx
_fp8
::fp8e4m3fnuz
\
{
\
return migraphx
_fp8
::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y)));
\
}
// Template with two overloads for math functions, one for half2 type and one for more generic
...
...
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
View file @
0a8edad5
...
...
@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#include <migraphx/
fp8e4m3fnuz
.hpp>
#include <migraphx/
migraphx_float8
.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
...
...
@@ -231,7 +231,8 @@ constexpr unsigned long int_max(unsigned long n)
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_integral
<
T
>{}
or
is_floating_point
<
T
>
{}
or
is_same
<
T
,
migraphx
::
half
>
{}
or
is_same
<
T
,
migraphx
::
fp8e4m3fnuz
>
{})
>
is_same
<
T
,
migraphx
::
half
>
{}
or
is_same
<
T
,
migraphx_fp8
::
fp8e4m3fnuz
>
{})
>
constexpr
T
numeric_max
()
{
if
constexpr
(
is_integral
<
T
>
{})
...
...
@@ -247,8 +248,8 @@ constexpr T numeric_max()
return
__FLT_MAX__
;
else
if
constexpr
(
is_same
<
T
,
migraphx
::
half
>
{})
return
__FLT16_MAX__
;
else
if
constexpr
(
is_same
<
T
,
migraphx
::
fp8e4m3fnuz
>
{})
return
T
{
0x7F
,
migraphx
::
fp8
e4m3fnuz
::
from_bits
()
}
;
else
if
constexpr
(
is_same
<
T
,
migraphx
_fp8
::
fp8e4m3fnuz
>
{})
return
migraphx
_
fp8
::
F8_Max
<
T
>
();
else
return
0
;
}
...
...
@@ -263,8 +264,8 @@ constexpr T numeric_lowest()
else
return
-
numeric_max
<
T
>
()
-
1
;
}
else
if
constexpr
(
is_same
<
T
,
migraphx
::
fp8e4m3fnuz
>
{})
return
T
{
0xFF
,
migraphx
::
fp8
e4m3fnuz
::
from_bits
()
}
;
else
if
constexpr
(
is_same
<
T
,
migraphx
_fp8
::
fp8e4m3fnuz
>
{})
return
migraphx
_
fp8
::
F8_Lowest
<
T
>
();
else
{
return
-
numeric_max
<
T
>
();
...
...
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
View file @
0a8edad5
...
...
@@ -23,7 +23,7 @@
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#include <migraphx/
fp8e4m3fnuz
.hpp>
#include <migraphx/
migraphx_float8
.hpp>
#include <migraphx/kernels/hip.hpp>
namespace
migraphx
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
View file @
0a8edad5
...
...
@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#define MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#include
"
migraphx/kernels/type_traits.hpp
"
#include
<
migraphx/kernels/type_traits.hpp
>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/vec.hpp>
...
...
@@ -237,7 +237,7 @@ template <index_int N, index_int Axis, class T>
__device__
__host__
auto
vectorize_tensor
(
T
x
)
{
constexpr
auto
shape
=
get_shape_c
<
T
>
{};
if
constexpr
(
is_same
<
typename
T
::
type
,
migraphx
::
fp8e4m3fnuz
>
{})
if
constexpr
(
is_same
<
typename
T
::
type
,
migraphx
_fp8
::
fp8e4m3fnuz
>
{})
return
x
;
else
if
constexpr
(
shape
.
lens
[
Axis
]
==
1
)
return
x
;
...
...
test/gpu/jit.cpp
View file @
0a8edad5
...
...
@@ -351,7 +351,7 @@ TEST_CASE(compile_math)
if
(
contains
({
migraphx
::
shape
::
bool_type
,
migraphx
::
shape
::
tuple_type
},
t
))
continue
;
auto
name
=
migraphx
::
shape
::
cpp_type
(
t
);
if
(
t
==
migraphx
::
shape
::
half_type
or
t
==
migraphx
::
shape
::
float8_type
)
if
(
t
==
migraphx
::
shape
::
half_type
)
name
.
insert
(
0
,
"migraphx::"
);
data_types
.
push_back
(
name
);
if
(
t
!=
migraphx
::
shape
::
float8_type
)
...
...
@@ -402,7 +402,7 @@ TEST_CASE(assert_type_min_max)
if
(
contains
({
migraphx
::
shape
::
bool_type
,
migraphx
::
shape
::
tuple_type
},
t
))
continue
;
auto
name
=
migraphx
::
shape
::
cpp_type
(
t
);
if
(
t
==
migraphx
::
shape
::
half_type
or
t
==
migraphx
::
shape
::
float8_type
)
if
(
t
==
migraphx
::
shape
::
half_type
)
name
.
insert
(
0
,
"migraphx::"
);
migraphx
::
shape
::
visit
(
t
,
[
&
](
auto
as
)
{
...
...
tools/api/migraphx.h
View file @
0a8edad5
...
...
@@ -37,7 +37,7 @@
m(half_type, half) \
m(float_type, float) \
m(double_type, double) \
m(float8_type, fp8e4m3fnuz) \
m(float8_type,
migraphx_fp8::
fp8e4m3fnuz) \
m(uint8_type, uint8_t) \
m(int8_type, int8_t) \
m(uint16_type, uint16_t) \
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment