Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0a8edad5
Commit
0a8edad5
authored
Nov 08, 2023
by
Umang Yadav
Browse files
works except constexpr
parent
d734871c
Changes
13
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
388 additions
and
384 deletions
+388
-384
src/include/migraphx/half.hpp
src/include/migraphx/half.hpp
+3
-3
src/include/migraphx/migraphx_float8.hpp
src/include/migraphx/migraphx_float8.hpp
+325
-311
src/include/migraphx/migraphx_hip_f8_impl.hpp
src/include/migraphx/migraphx_hip_f8_impl.hpp
+28
-39
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+2
-2
src/include/migraphx/type_traits.hpp
src/include/migraphx/type_traits.hpp
+4
-4
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+2
-2
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+10
-10
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
...gets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
+7
-6
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
...argets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
+2
-2
test/gpu/jit.cpp
test/gpu/jit.cpp
+2
-2
tools/api/migraphx.h
tools/api/migraphx.h
+1
-1
No files found.
src/include/migraphx/half.hpp
View file @
0a8edad5
...
@@ -27,7 +27,7 @@
...
@@ -27,7 +27,7 @@
#include <half/half.hpp>
#include <half/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/
fp8e4m3fnuz
.hpp>
#include <migraphx/
migraphx_float8
.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -69,13 +69,13 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
...
@@ -69,13 +69,13 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
};
};
template
<
>
template
<
>
struct
common_type
<
migraphx
::
fp8e4m3fnuz
,
migraphx
::
half
>
struct
common_type
<
migraphx
_fp8
::
fp8e4m3fnuz
,
migraphx
::
half
>
{
{
using
type
=
float
;
using
type
=
float
;
};
};
template
<
>
template
<
>
struct
common_type
<
migraphx
::
half
,
migraphx
::
fp8e4m3fnuz
>
struct
common_type
<
migraphx
::
half
,
migraphx
_fp8
::
fp8e4m3fnuz
>
{
{
using
type
=
float
;
using
type
=
float
;
};
};
...
...
src/include/migraphx/migraphx_float8.hpp
View file @
0a8edad5
This diff is collapsed.
Click to expand it.
src/include/migraphx/migraphx_hip_f8_impl.hpp
View file @
0a8edad5
...
@@ -25,8 +25,22 @@
...
@@ -25,8 +25,22 @@
#pragma clang diagnostic push
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
#pragma clang diagnostic ignored "-Wreserved-identifier"
namespace
migraphx_hip_f8_impl
{
namespace
migraphx_hip_f8_impl
{
namespace
detail
{
template
<
bool
B
,
class
T
,
class
F
>
struct
conditional
{
using
type
=
T
;
};
template
<
class
T
,
class
F
>
struct
conditional
<
false
,
T
,
F
>
{
using
type
=
F
;
};
}
// namespace detail
// #ifdef __HIP_PLATFORM_HCC__
// #ifdef __HIP_PLATFORM_HCC__
// __device__ inline int clz(uint32_t x) { return __clz(x); }
// __device__ inline int clz(uint32_t x) { return __clz(x); }
...
@@ -35,12 +49,10 @@ namespace migraphx_hip_f8_impl {
...
@@ -35,12 +49,10 @@ namespace migraphx_hip_f8_impl {
// #endif
// #endif
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
,
bool
clip
>
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
,
bool
clip
>
MIGRAPHX_HIP_HOST_DEVICE
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
,
uint32_t
rng
)
MIGRAPHX_HIP_HOST_DEVICE
constexpr
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
,
uint32_t
rng
)
{
{
constexpr
bool
is_half
=
migraphx
::
is_same
<
T
,
migraphx
::
half
>
{};
constexpr
bool
is_float
=
migraphx
::
is_same
<
T
,
float
>
{};
static_assert
(
wm
+
we
==
7
,
"wm+we==7"
);
static_assert
(
wm
+
we
==
7
,
"wm+we==7"
);
static_assert
(
is_half
||
is_float
,
"Only half and float can be cast to f8"
);
const
int
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
const
int
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
uint32_t
x
;
uint32_t
x
;
...
@@ -215,38 +227,20 @@ this case, the fp16 mantissa should be shift left by 1 */
...
@@ -215,38 +227,20 @@ this case, the fp16 mantissa should be shift left by 1 */
}
}
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
>
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
>
MIGRAPHX_HIP_HOST_DEVICE
T
cast_from_f8
(
uint8_t
x
)
MIGRAPHX_HIP_HOST_DEVICE
constexpr
T
cast_from_f8
(
uint8_t
x
)
{
{
constexpr
bool
is_half
=
migraphx
::
is_same
<
T
,
migraphx
::
half
>
{};
constexpr
int
weo
=
8
;
constexpr
bool
is_float
=
migraphx
::
is_same
<
T
,
float
>
{};
constexpr
int
wmo
=
23
;
static_assert
(
is_half
||
is_float
,
"only half and float are supported"
);
constexpr
int
weo
=
is_half
?
5
:
8
;
constexpr
int
wmo
=
is_half
?
10
:
(
is_float
?
23
:
7
);
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
if
(
is_half
)
const
uint32_t
ifInf
=
0x7F800000
;
{
const
uint32_t
ifNegInf
=
0xFF800000
;
const
uint16_t
ihInf
=
0x7C00
;
const
uint32_t
ifNaN
=
0x7F800001
;
const
uint16_t
ihNegInf
=
0xFC00
;
const
uint32_t
ifNeg0
=
0x80000000
;
const
uint16_t
ihNaN
=
0x7C01
;
fInf
=
reinterpret_cast
<
const
float
&>
(
ifInf
);
const
uint16_t
ihNeg0
=
0x8000
;
fNegInf
=
reinterpret_cast
<
const
float
&>
(
ifNegInf
);
fInf
=
reinterpret_cast
<
const
migraphx
::
half
&>
(
ihInf
);
fNaN
=
reinterpret_cast
<
const
float
&>
(
ifNaN
);
fNegInf
=
reinterpret_cast
<
const
migraphx
::
half
&>
(
ihNegInf
);
fNeg0
=
reinterpret_cast
<
const
float
&>
(
ifNeg0
);
fNaN
=
reinterpret_cast
<
const
migraphx
::
half
&>
(
ihNaN
);
fNeg0
=
reinterpret_cast
<
const
migraphx
::
half
&>
(
ihNeg0
);
}
else
if
(
is_float
)
{
const
uint32_t
ifInf
=
0x7F800000
;
const
uint32_t
ifNegInf
=
0xFF800000
;
const
uint32_t
ifNaN
=
0x7F800001
;
const
uint32_t
ifNeg0
=
0x80000000
;
fInf
=
reinterpret_cast
<
const
float
&>
(
ifInf
);
fNegInf
=
reinterpret_cast
<
const
float
&>
(
ifNegInf
);
fNaN
=
reinterpret_cast
<
const
float
&>
(
ifNaN
);
fNeg0
=
reinterpret_cast
<
const
float
&>
(
ifNeg0
);
}
if
(
x
==
0
)
if
(
x
==
0
)
return
0
;
return
0
;
...
@@ -266,12 +260,7 @@ MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x)
...
@@ -266,12 +260,7 @@ MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x)
if
(
exponent
==
((
1
<<
we
)
-
1
))
if
(
exponent
==
((
1
<<
we
)
-
1
))
return
(
mantissa
==
0
)
?
(
sign
?
fNegInf
:
fInf
)
:
fNaN
;
return
(
mantissa
==
0
)
?
(
sign
?
fNegInf
:
fInf
)
:
fNaN
;
}
}
typename
migraphx
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
retval
;
typename
detail
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
retval
;
if
(
we
==
5
&&
is_half
&&
!
negative_zero_nan
)
{
retval
=
x
<<
8
;
return
reinterpret_cast
<
const
T
&>
(
retval
);
}
const
int
exp_low_cutoff
=
(
1
<<
(
weo
-
1
))
-
(
1
<<
(
we
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
const
int
exp_low_cutoff
=
(
1
<<
(
weo
-
1
))
-
(
1
<<
(
we
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
...
...
src/include/migraphx/shape.hpp
View file @
0a8edad5
...
@@ -34,7 +34,7 @@
...
@@ -34,7 +34,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/
fp8e4m3fnuz
.hpp>
#include <migraphx/
migraphx_float8
.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
...
@@ -54,7 +54,7 @@ struct MIGRAPHX_EXPORT shape
...
@@ -54,7 +54,7 @@ struct MIGRAPHX_EXPORT shape
m(half_type, half) \
m(half_type, half) \
m(float_type, float) \
m(float_type, float) \
m(double_type, double) \
m(double_type, double) \
m(float8_type, fp8e4m3fnuz) \
m(float8_type,
migraphx_fp8::
fp8e4m3fnuz) \
m(uint8_type, uint8_t) \
m(uint8_type, uint8_t) \
m(int8_type, int8_t) \
m(int8_type, int8_t) \
m(uint16_type, uint16_t) \
m(uint16_type, uint16_t) \
...
...
src/include/migraphx/type_traits.hpp
View file @
0a8edad5
...
@@ -25,10 +25,10 @@
...
@@ -25,10 +25,10 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#include <migraphx/fp8e4m3fnuz.hpp>
#include <type_traits>
#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/migraphx_float8.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -63,9 +63,9 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
...
@@ -63,9 +63,9 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
migraphx_fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
migraphx_fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
migraphx_fp8
::
fp8e4m3fnuz
)
template
<
class
T
>
template
<
class
T
>
using
accumulator_type
=
using
accumulator_type
=
...
...
src/py/migraphx_py.cpp
View file @
0a8edad5
...
@@ -40,7 +40,7 @@
...
@@ -40,7 +40,7 @@
#include <migraphx/json.hpp>
#include <migraphx/json.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/
fp8e4m3fnuz
.hpp>
#include <migraphx/
migraphx_float8
.hpp>
#ifdef HAVE_GPU
#ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/hip.hpp>
#endif
#endif
...
@@ -145,7 +145,7 @@ struct npy_format_descriptor<half>
...
@@ -145,7 +145,7 @@ struct npy_format_descriptor<half>
};
};
template
<
>
template
<
>
struct
npy_format_descriptor
<
migraphx
::
fp8e4m3fnuz
>
struct
npy_format_descriptor
<
migraphx
_fp8
::
fp8e4m3fnuz
>
{
{
static
std
::
string
format
()
static
std
::
string
format
()
{
{
...
...
src/targets/gpu/CMakeLists.txt
View file @
0a8edad5
...
@@ -60,7 +60,7 @@ endif()
...
@@ -60,7 +60,7 @@ endif()
include
(
Embed
)
include
(
Embed
)
add_embed_library
(
migraphx_kernels
${
KERNEL_FILES
}
RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/ EXTRA_HEADERS
${
CMAKE_SOURCE_DIR
}
/src/include/migraphx/
fp8e4m3fnuz.hpp EXTRA_HEADERS_RELATIVE
${
CMAKE_SOURCE_DIR
}
/src/include
)
add_embed_library
(
migraphx_kernels
${
KERNEL_FILES
}
RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/ EXTRA_HEADERS
${
CMAKE_SOURCE_DIR
}
/src/include/migraphx/
migraphx_float8.hpp
${
CMAKE_SOURCE_DIR
}
/src/include/migraphx/migraphx_hip_f8_impl.hpp EXTRA_HEADERS_RELATIVE
${
CMAKE_SOURCE_DIR
}
/src/include
${
CMAKE_SOURCE_DIR
}
/src/include
)
configure_file
(
device/targets.hpp.in include/migraphx/gpu/device/targets.hpp
)
configure_file
(
device/targets.hpp.in include/migraphx/gpu/device/targets.hpp
)
file
(
GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/device/*.cpp
)
file
(
GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/device/*.cpp
)
...
...
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
View file @
0a8edad5
...
@@ -35,7 +35,7 @@ namespace migraphx {
...
@@ -35,7 +35,7 @@ namespace migraphx {
namespace
math
{
namespace
math
{
constexpr
float
as_float
(
migraphx
::
half
x
)
{
return
x
;
}
constexpr
float
as_float
(
migraphx
::
half
x
)
{
return
x
;
}
constexpr
float
as_float
(
migraphx
::
fp8e4m3fnuz
x
)
{
return
x
;
}
constexpr
float
as_float
(
migraphx
_fp8
::
fp8e4m3fnuz
x
)
{
return
x
;
}
template
<
class
T
>
template
<
class
T
>
constexpr
T
as_float
(
T
x
)
constexpr
T
as_float
(
T
x
)
...
@@ -76,17 +76,17 @@ constexpr T as_float(T x)
...
@@ -76,17 +76,17 @@ constexpr T as_float(T x)
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \
#define MIGRAPHX_DEVICE_MATH_FP8(name, fname)
\
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())>
\
auto __device__ name(migraphx::fp8e4m3fnuz x, Ts... xs)
\
auto __device__ name(migraphx
_fp8
::fp8e4m3fnuz x, Ts... xs)
MIGRAPHX_RETURNS(
\
MIGRAPHX_RETURNS(
migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
migraphx
_fp8
::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(name, fname) \
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(name, fname)
\
inline auto __device__ name(migraphx::fp8e4m3fnuz x, migraphx::fp8e4m3fnuz y)
\
inline auto __device__ name(migraphx
_fp8
::fp8e4m3fnuz x, migraphx
_fp8
::fp8e4m3fnuz y) \
-> migraphx::fp8e4m3fnuz \
-> migraphx
_fp8
::fp8e4m3fnuz
\
{ \
{
\
return migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y))); \
return migraphx
_fp8
::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y)));
\
}
}
// Template with two overloads for math functions, one for half2 type and one for more generic
// Template with two overloads for math functions, one for half2 type and one for more generic
...
...
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
View file @
0a8edad5
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#include <migraphx/
fp8e4m3fnuz
.hpp>
#include <migraphx/
migraphx_float8
.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/integral_constant.hpp>
...
@@ -231,7 +231,8 @@ constexpr unsigned long int_max(unsigned long n)
...
@@ -231,7 +231,8 @@ constexpr unsigned long int_max(unsigned long n)
template
<
class
T
,
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_integral
<
T
>{}
or
is_floating_point
<
T
>
{}
or
MIGRAPHX_REQUIRES
(
is_integral
<
T
>{}
or
is_floating_point
<
T
>
{}
or
is_same
<
T
,
migraphx
::
half
>
{}
or
is_same
<
T
,
migraphx
::
fp8e4m3fnuz
>
{})
>
is_same
<
T
,
migraphx
::
half
>
{}
or
is_same
<
T
,
migraphx_fp8
::
fp8e4m3fnuz
>
{})
>
constexpr
T
numeric_max
()
constexpr
T
numeric_max
()
{
{
if
constexpr
(
is_integral
<
T
>
{})
if
constexpr
(
is_integral
<
T
>
{})
...
@@ -247,8 +248,8 @@ constexpr T numeric_max()
...
@@ -247,8 +248,8 @@ constexpr T numeric_max()
return
__FLT_MAX__
;
return
__FLT_MAX__
;
else
if
constexpr
(
is_same
<
T
,
migraphx
::
half
>
{})
else
if
constexpr
(
is_same
<
T
,
migraphx
::
half
>
{})
return
__FLT16_MAX__
;
return
__FLT16_MAX__
;
else
if
constexpr
(
is_same
<
T
,
migraphx
::
fp8e4m3fnuz
>
{})
else
if
constexpr
(
is_same
<
T
,
migraphx
_fp8
::
fp8e4m3fnuz
>
{})
return
T
{
0x7F
,
migraphx
::
fp8
e4m3fnuz
::
from_bits
()
}
;
return
migraphx
_
fp8
::
F8_Max
<
T
>
();
else
else
return
0
;
return
0
;
}
}
...
@@ -263,8 +264,8 @@ constexpr T numeric_lowest()
...
@@ -263,8 +264,8 @@ constexpr T numeric_lowest()
else
else
return
-
numeric_max
<
T
>
()
-
1
;
return
-
numeric_max
<
T
>
()
-
1
;
}
}
else
if
constexpr
(
is_same
<
T
,
migraphx
::
fp8e4m3fnuz
>
{})
else
if
constexpr
(
is_same
<
T
,
migraphx
_fp8
::
fp8e4m3fnuz
>
{})
return
T
{
0xFF
,
migraphx
::
fp8
e4m3fnuz
::
from_bits
()
}
;
return
migraphx
_
fp8
::
F8_Lowest
<
T
>
();
else
else
{
{
return
-
numeric_max
<
T
>
();
return
-
numeric_max
<
T
>
();
...
...
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
View file @
0a8edad5
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
*/
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#include <migraphx/
fp8e4m3fnuz
.hpp>
#include <migraphx/
migraphx_float8
.hpp>
#include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/hip.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
View file @
0a8edad5
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#define MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#define MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#include
"
migraphx/kernels/type_traits.hpp
"
#include
<
migraphx/kernels/type_traits.hpp
>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/vec.hpp>
...
@@ -237,7 +237,7 @@ template <index_int N, index_int Axis, class T>
...
@@ -237,7 +237,7 @@ template <index_int N, index_int Axis, class T>
__device__
__host__
auto
vectorize_tensor
(
T
x
)
__device__
__host__
auto
vectorize_tensor
(
T
x
)
{
{
constexpr
auto
shape
=
get_shape_c
<
T
>
{};
constexpr
auto
shape
=
get_shape_c
<
T
>
{};
if
constexpr
(
is_same
<
typename
T
::
type
,
migraphx
::
fp8e4m3fnuz
>
{})
if
constexpr
(
is_same
<
typename
T
::
type
,
migraphx
_fp8
::
fp8e4m3fnuz
>
{})
return
x
;
return
x
;
else
if
constexpr
(
shape
.
lens
[
Axis
]
==
1
)
else
if
constexpr
(
shape
.
lens
[
Axis
]
==
1
)
return
x
;
return
x
;
...
...
test/gpu/jit.cpp
View file @
0a8edad5
...
@@ -351,7 +351,7 @@ TEST_CASE(compile_math)
...
@@ -351,7 +351,7 @@ TEST_CASE(compile_math)
if
(
contains
({
migraphx
::
shape
::
bool_type
,
migraphx
::
shape
::
tuple_type
},
t
))
if
(
contains
({
migraphx
::
shape
::
bool_type
,
migraphx
::
shape
::
tuple_type
},
t
))
continue
;
continue
;
auto
name
=
migraphx
::
shape
::
cpp_type
(
t
);
auto
name
=
migraphx
::
shape
::
cpp_type
(
t
);
if
(
t
==
migraphx
::
shape
::
half_type
or
t
==
migraphx
::
shape
::
float8_type
)
if
(
t
==
migraphx
::
shape
::
half_type
)
name
.
insert
(
0
,
"migraphx::"
);
name
.
insert
(
0
,
"migraphx::"
);
data_types
.
push_back
(
name
);
data_types
.
push_back
(
name
);
if
(
t
!=
migraphx
::
shape
::
float8_type
)
if
(
t
!=
migraphx
::
shape
::
float8_type
)
...
@@ -402,7 +402,7 @@ TEST_CASE(assert_type_min_max)
...
@@ -402,7 +402,7 @@ TEST_CASE(assert_type_min_max)
if
(
contains
({
migraphx
::
shape
::
bool_type
,
migraphx
::
shape
::
tuple_type
},
t
))
if
(
contains
({
migraphx
::
shape
::
bool_type
,
migraphx
::
shape
::
tuple_type
},
t
))
continue
;
continue
;
auto
name
=
migraphx
::
shape
::
cpp_type
(
t
);
auto
name
=
migraphx
::
shape
::
cpp_type
(
t
);
if
(
t
==
migraphx
::
shape
::
half_type
or
t
==
migraphx
::
shape
::
float8_type
)
if
(
t
==
migraphx
::
shape
::
half_type
)
name
.
insert
(
0
,
"migraphx::"
);
name
.
insert
(
0
,
"migraphx::"
);
migraphx
::
shape
::
visit
(
t
,
[
&
](
auto
as
)
{
migraphx
::
shape
::
visit
(
t
,
[
&
](
auto
as
)
{
...
...
tools/api/migraphx.h
View file @
0a8edad5
...
@@ -37,7 +37,7 @@
...
@@ -37,7 +37,7 @@
m(half_type, half) \
m(half_type, half) \
m(float_type, float) \
m(float_type, float) \
m(double_type, double) \
m(double_type, double) \
m(float8_type, fp8e4m3fnuz) \
m(float8_type,
migraphx_fp8::
fp8e4m3fnuz) \
m(uint8_type, uint8_t) \
m(uint8_type, uint8_t) \
m(int8_type, int8_t) \
m(int8_type, int8_t) \
m(uint16_type, uint16_t) \
m(uint16_type, uint16_t) \
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment