Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0a8edad5
Commit
0a8edad5
authored
Nov 08, 2023
by
Umang Yadav
Browse files
works except constexpr
parent
d734871c
Changes
13
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
388 additions
and
384 deletions
+388
-384
src/include/migraphx/half.hpp
src/include/migraphx/half.hpp
+3
-3
src/include/migraphx/migraphx_float8.hpp
src/include/migraphx/migraphx_float8.hpp
+325
-311
src/include/migraphx/migraphx_hip_f8_impl.hpp
src/include/migraphx/migraphx_hip_f8_impl.hpp
+28
-39
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+2
-2
src/include/migraphx/type_traits.hpp
src/include/migraphx/type_traits.hpp
+4
-4
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+2
-2
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+10
-10
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
...gets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
+7
-6
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
...argets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
+2
-2
test/gpu/jit.cpp
test/gpu/jit.cpp
+2
-2
tools/api/migraphx.h
tools/api/migraphx.h
+1
-1
No files found.
src/include/migraphx/half.hpp
View file @
0a8edad5
...
...
@@ -27,7 +27,7 @@
#include <half/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/
fp8e4m3fnuz
.hpp>
#include <migraphx/
migraphx_float8
.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -69,13 +69,13 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
};
template
<
>
struct
common_type
<
migraphx
::
fp8e4m3fnuz
,
migraphx
::
half
>
struct
common_type
<
migraphx
_fp8
::
fp8e4m3fnuz
,
migraphx
::
half
>
{
using
type
=
float
;
};
template
<
>
struct
common_type
<
migraphx
::
half
,
migraphx
::
fp8e4m3fnuz
>
struct
common_type
<
migraphx
::
half
,
migraphx
_fp8
::
fp8e4m3fnuz
>
{
using
type
=
float
;
};
...
...
src/include/migraphx/migraphx_float8.hpp
View file @
0a8edad5
This diff is collapsed.
Click to expand it.
src/include/migraphx/migraphx_hip_f8_impl.hpp
View file @
0a8edad5
...
...
@@ -25,8 +25,22 @@
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
#pragma clang diagnostic ignored "-Wreserved-identifier"
namespace
migraphx_hip_f8_impl
{
namespace
detail
{
template
<
bool
B
,
class
T
,
class
F
>
struct
conditional
{
using
type
=
T
;
};
template
<
class
T
,
class
F
>
struct
conditional
<
false
,
T
,
F
>
{
using
type
=
F
;
};
}
// namespace detail
// #ifdef __HIP_PLATFORM_HCC__
// __device__ inline int clz(uint32_t x) { return __clz(x); }
...
...
@@ -35,12 +49,10 @@ namespace migraphx_hip_f8_impl {
// #endif
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
,
bool
clip
>
MIGRAPHX_HIP_HOST_DEVICE
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
,
uint32_t
rng
)
MIGRAPHX_HIP_HOST_DEVICE
constexpr
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
,
uint32_t
rng
)
{
constexpr
bool
is_half
=
migraphx
::
is_same
<
T
,
migraphx
::
half
>
{};
constexpr
bool
is_float
=
migraphx
::
is_same
<
T
,
float
>
{};
static_assert
(
wm
+
we
==
7
,
"wm+we==7"
);
static_assert
(
is_half
||
is_float
,
"Only half and float can be cast to f8"
);
const
int
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
uint32_t
x
;
...
...
@@ -215,29 +227,12 @@ this case, the fp16 mantissa should be shift left by 1 */
}
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
>
MIGRAPHX_HIP_HOST_DEVICE
T
cast_from_f8
(
uint8_t
x
)
MIGRAPHX_HIP_HOST_DEVICE
constexpr
T
cast_from_f8
(
uint8_t
x
)
{
constexpr
bool
is_half
=
migraphx
::
is_same
<
T
,
migraphx
::
half
>
{};
constexpr
bool
is_float
=
migraphx
::
is_same
<
T
,
float
>
{};
static_assert
(
is_half
||
is_float
,
"only half and float are supported"
);
constexpr
int
weo
=
is_half
?
5
:
8
;
constexpr
int
wmo
=
is_half
?
10
:
(
is_float
?
23
:
7
);
constexpr
int
weo
=
8
;
constexpr
int
wmo
=
23
;
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
if
(
is_half
)
{
const
uint16_t
ihInf
=
0x7C00
;
const
uint16_t
ihNegInf
=
0xFC00
;
const
uint16_t
ihNaN
=
0x7C01
;
const
uint16_t
ihNeg0
=
0x8000
;
fInf
=
reinterpret_cast
<
const
migraphx
::
half
&>
(
ihInf
);
fNegInf
=
reinterpret_cast
<
const
migraphx
::
half
&>
(
ihNegInf
);
fNaN
=
reinterpret_cast
<
const
migraphx
::
half
&>
(
ihNaN
);
fNeg0
=
reinterpret_cast
<
const
migraphx
::
half
&>
(
ihNeg0
);
}
else
if
(
is_float
)
{
const
uint32_t
ifInf
=
0x7F800000
;
const
uint32_t
ifNegInf
=
0xFF800000
;
const
uint32_t
ifNaN
=
0x7F800001
;
...
...
@@ -246,7 +241,6 @@ MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x)
fNegInf
=
reinterpret_cast
<
const
float
&>
(
ifNegInf
);
fNaN
=
reinterpret_cast
<
const
float
&>
(
ifNaN
);
fNeg0
=
reinterpret_cast
<
const
float
&>
(
ifNeg0
);
}
if
(
x
==
0
)
return
0
;
...
...
@@ -266,12 +260,7 @@ MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x)
if
(
exponent
==
((
1
<<
we
)
-
1
))
return
(
mantissa
==
0
)
?
(
sign
?
fNegInf
:
fInf
)
:
fNaN
;
}
typename
migraphx
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
retval
;
if
(
we
==
5
&&
is_half
&&
!
negative_zero_nan
)
{
retval
=
x
<<
8
;
return
reinterpret_cast
<
const
T
&>
(
retval
);
}
typename
detail
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
retval
;
const
int
exp_low_cutoff
=
(
1
<<
(
weo
-
1
))
-
(
1
<<
(
we
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
...
...
src/include/migraphx/shape.hpp
View file @
0a8edad5
...
...
@@ -34,7 +34,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/
fp8e4m3fnuz
.hpp>
#include <migraphx/
migraphx_float8
.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>
...
...
@@ -54,7 +54,7 @@ struct MIGRAPHX_EXPORT shape
m(half_type, half) \
m(float_type, float) \
m(double_type, double) \
m(float8_type, fp8e4m3fnuz) \
m(float8_type,
migraphx_fp8::
fp8e4m3fnuz) \
m(uint8_type, uint8_t) \
m(int8_type, int8_t) \
m(uint16_type, uint16_t) \
...
...
src/include/migraphx/type_traits.hpp
View file @
0a8edad5
...
...
@@ -25,10 +25,10 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#include <migraphx/fp8e4m3fnuz.hpp>
#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/migraphx_float8.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -63,9 +63,9 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
migraphx_fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
migraphx_fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
migraphx_fp8
::
fp8e4m3fnuz
)
template
<
class
T
>
using
accumulator_type
=
...
...
src/py/migraphx_py.cpp
View file @
0a8edad5
...
...
@@ -40,7 +40,7 @@
#include <migraphx/json.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/
fp8e4m3fnuz
.hpp>
#include <migraphx/
migraphx_float8
.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp>
#endif
...
...
@@ -145,7 +145,7 @@ struct npy_format_descriptor<half>
};
template
<
>
struct
npy_format_descriptor
<
migraphx
::
fp8e4m3fnuz
>
struct
npy_format_descriptor
<
migraphx
_fp8
::
fp8e4m3fnuz
>
{
static
std
::
string
format
()
{
...
...
src/targets/gpu/CMakeLists.txt
View file @
0a8edad5
...
...
@@ -60,7 +60,7 @@ endif()
include
(
Embed
)
add_embed_library
(
migraphx_kernels
${
KERNEL_FILES
}
RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/ EXTRA_HEADERS
${
CMAKE_SOURCE_DIR
}
/src/include/migraphx/
fp8e4m3fnuz.hpp EXTRA_HEADERS_RELATIVE
${
CMAKE_SOURCE_DIR
}
/src/include
)
add_embed_library
(
migraphx_kernels
${
KERNEL_FILES
}
RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/ EXTRA_HEADERS
${
CMAKE_SOURCE_DIR
}
/src/include/migraphx/
migraphx_float8.hpp
${
CMAKE_SOURCE_DIR
}
/src/include/migraphx/migraphx_hip_f8_impl.hpp EXTRA_HEADERS_RELATIVE
${
CMAKE_SOURCE_DIR
}
/src/include
${
CMAKE_SOURCE_DIR
}
/src/include
)
configure_file
(
device/targets.hpp.in include/migraphx/gpu/device/targets.hpp
)
file
(
GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/device/*.cpp
)
...
...
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
View file @
0a8edad5
...
...
@@ -35,7 +35,7 @@ namespace migraphx {
namespace
math
{
constexpr
float
as_float
(
migraphx
::
half
x
)
{
return
x
;
}
constexpr
float
as_float
(
migraphx
::
fp8e4m3fnuz
x
)
{
return
x
;
}
constexpr
float
as_float
(
migraphx
_fp8
::
fp8e4m3fnuz
x
)
{
return
x
;
}
template
<
class
T
>
constexpr
T
as_float
(
T
x
)
...
...
@@ -78,15 +78,15 @@ constexpr T as_float(T x)
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::fp8e4m3fnuz x, Ts... xs)
\
MIGRAPHX_RETURNS(
migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
auto __device__ name(migraphx
_fp8
::fp8e4m3fnuz x, Ts... xs)
MIGRAPHX_RETURNS(
\
migraphx
_fp8
::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(name, fname) \
inline auto __device__ name(migraphx::fp8e4m3fnuz x, migraphx::fp8e4m3fnuz y)
\
-> migraphx::fp8e4m3fnuz \
inline auto __device__ name(migraphx
_fp8
::fp8e4m3fnuz x, migraphx
_fp8
::fp8e4m3fnuz y) \
-> migraphx
_fp8
::fp8e4m3fnuz
\
{ \
return migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y))); \
return migraphx
_fp8
::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y)));
\
}
// Template with two overloads for math functions, one for half2 type and one for more generic
...
...
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
View file @
0a8edad5
...
...
@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#include <migraphx/
fp8e4m3fnuz
.hpp>
#include <migraphx/
migraphx_float8
.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
...
...
@@ -231,7 +231,8 @@ constexpr unsigned long int_max(unsigned long n)
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_integral
<
T
>{}
or
is_floating_point
<
T
>
{}
or
is_same
<
T
,
migraphx
::
half
>
{}
or
is_same
<
T
,
migraphx
::
fp8e4m3fnuz
>
{})
>
is_same
<
T
,
migraphx
::
half
>
{}
or
is_same
<
T
,
migraphx_fp8
::
fp8e4m3fnuz
>
{})
>
constexpr
T
numeric_max
()
{
if
constexpr
(
is_integral
<
T
>
{})
...
...
@@ -247,8 +248,8 @@ constexpr T numeric_max()
return
__FLT_MAX__
;
else
if
constexpr
(
is_same
<
T
,
migraphx
::
half
>
{})
return
__FLT16_MAX__
;
else
if
constexpr
(
is_same
<
T
,
migraphx
::
fp8e4m3fnuz
>
{})
return
T
{
0x7F
,
migraphx
::
fp8
e4m3fnuz
::
from_bits
()
}
;
else
if
constexpr
(
is_same
<
T
,
migraphx
_fp8
::
fp8e4m3fnuz
>
{})
return
migraphx
_
fp8
::
F8_Max
<
T
>
();
else
return
0
;
}
...
...
@@ -263,8 +264,8 @@ constexpr T numeric_lowest()
else
return
-
numeric_max
<
T
>
()
-
1
;
}
else
if
constexpr
(
is_same
<
T
,
migraphx
::
fp8e4m3fnuz
>
{})
return
T
{
0xFF
,
migraphx
::
fp8
e4m3fnuz
::
from_bits
()
}
;
else
if
constexpr
(
is_same
<
T
,
migraphx
_fp8
::
fp8e4m3fnuz
>
{})
return
migraphx
_
fp8
::
F8_Lowest
<
T
>
();
else
{
return
-
numeric_max
<
T
>
();
...
...
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
View file @
0a8edad5
...
...
@@ -23,7 +23,7 @@
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#include <migraphx/
fp8e4m3fnuz
.hpp>
#include <migraphx/
migraphx_float8
.hpp>
#include <migraphx/kernels/hip.hpp>
namespace
migraphx
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
View file @
0a8edad5
...
...
@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#define MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#include
"
migraphx/kernels/type_traits.hpp
"
#include
<
migraphx/kernels/type_traits.hpp
>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/vec.hpp>
...
...
@@ -237,7 +237,7 @@ template <index_int N, index_int Axis, class T>
__device__
__host__
auto
vectorize_tensor
(
T
x
)
{
constexpr
auto
shape
=
get_shape_c
<
T
>
{};
if
constexpr
(
is_same
<
typename
T
::
type
,
migraphx
::
fp8e4m3fnuz
>
{})
if
constexpr
(
is_same
<
typename
T
::
type
,
migraphx
_fp8
::
fp8e4m3fnuz
>
{})
return
x
;
else
if
constexpr
(
shape
.
lens
[
Axis
]
==
1
)
return
x
;
...
...
test/gpu/jit.cpp
View file @
0a8edad5
...
...
@@ -351,7 +351,7 @@ TEST_CASE(compile_math)
if
(
contains
({
migraphx
::
shape
::
bool_type
,
migraphx
::
shape
::
tuple_type
},
t
))
continue
;
auto
name
=
migraphx
::
shape
::
cpp_type
(
t
);
if
(
t
==
migraphx
::
shape
::
half_type
or
t
==
migraphx
::
shape
::
float8_type
)
if
(
t
==
migraphx
::
shape
::
half_type
)
name
.
insert
(
0
,
"migraphx::"
);
data_types
.
push_back
(
name
);
if
(
t
!=
migraphx
::
shape
::
float8_type
)
...
...
@@ -402,7 +402,7 @@ TEST_CASE(assert_type_min_max)
if
(
contains
({
migraphx
::
shape
::
bool_type
,
migraphx
::
shape
::
tuple_type
},
t
))
continue
;
auto
name
=
migraphx
::
shape
::
cpp_type
(
t
);
if
(
t
==
migraphx
::
shape
::
half_type
or
t
==
migraphx
::
shape
::
float8_type
)
if
(
t
==
migraphx
::
shape
::
half_type
)
name
.
insert
(
0
,
"migraphx::"
);
migraphx
::
shape
::
visit
(
t
,
[
&
](
auto
as
)
{
...
...
tools/api/migraphx.h
View file @
0a8edad5
...
...
@@ -37,7 +37,7 @@
m(half_type, half) \
m(float_type, float) \
m(double_type, double) \
m(float8_type, fp8e4m3fnuz) \
m(float8_type,
migraphx_fp8::
fp8e4m3fnuz) \
m(uint8_type, uint8_t) \
m(int8_type, int8_t) \
m(uint16_type, uint16_t) \
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment