Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0a8edad5
"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "0ca5e1ce8b2409a2a20ad0ea829d8b02f0cf321d"
Commit
0a8edad5
authored
Nov 08, 2023
by
Umang Yadav
Browse files
works except constexpr
parent
d734871c
Changes
13
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
388 additions
and
384 deletions
+388
-384
src/include/migraphx/half.hpp
src/include/migraphx/half.hpp
+3
-3
src/include/migraphx/migraphx_float8.hpp
src/include/migraphx/migraphx_float8.hpp
+325
-311
src/include/migraphx/migraphx_hip_f8_impl.hpp
src/include/migraphx/migraphx_hip_f8_impl.hpp
+28
-39
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+2
-2
src/include/migraphx/type_traits.hpp
src/include/migraphx/type_traits.hpp
+4
-4
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+2
-2
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+10
-10
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
...gets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
+7
-6
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
...argets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
+2
-2
test/gpu/jit.cpp
test/gpu/jit.cpp
+2
-2
tools/api/migraphx.h
tools/api/migraphx.h
+1
-1
No files found.
src/include/migraphx/half.hpp
View file @
0a8edad5
...
@@ -27,7 +27,7 @@
...
@@ -27,7 +27,7 @@
#include <half/half.hpp>
#include <half/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/
fp8e4m3fnuz
.hpp>
#include <migraphx/
migraphx_float8
.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -69,13 +69,13 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
...
@@ -69,13 +69,13 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
};
};
template
<
>
template
<
>
struct
common_type
<
migraphx
::
fp8e4m3fnuz
,
migraphx
::
half
>
struct
common_type
<
migraphx
_fp8
::
fp8e4m3fnuz
,
migraphx
::
half
>
{
{
using
type
=
float
;
using
type
=
float
;
};
};
template
<
>
template
<
>
struct
common_type
<
migraphx
::
half
,
migraphx
::
fp8e4m3fnuz
>
struct
common_type
<
migraphx
::
half
,
migraphx
_fp8
::
fp8e4m3fnuz
>
{
{
using
type
=
float
;
using
type
=
float
;
};
};
...
...
src/include/migraphx/migraphx_float8.hpp
View file @
0a8edad5
This diff is collapsed.
Click to expand it.
src/include/migraphx/migraphx_hip_f8_impl.hpp
View file @
0a8edad5
...
@@ -25,8 +25,22 @@
...
@@ -25,8 +25,22 @@
#pragma clang diagnostic push
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
#pragma clang diagnostic ignored "-Wreserved-identifier"
namespace
migraphx_hip_f8_impl
{
namespace
migraphx_hip_f8_impl
{
namespace
detail
{
template
<
bool
B
,
class
T
,
class
F
>
struct
conditional
{
using
type
=
T
;
};
template
<
class
T
,
class
F
>
struct
conditional
<
false
,
T
,
F
>
{
using
type
=
F
;
};
}
// namespace detail
// #ifdef __HIP_PLATFORM_HCC__
// #ifdef __HIP_PLATFORM_HCC__
// __device__ inline int clz(uint32_t x) { return __clz(x); }
// __device__ inline int clz(uint32_t x) { return __clz(x); }
...
@@ -35,12 +49,10 @@ namespace migraphx_hip_f8_impl {
...
@@ -35,12 +49,10 @@ namespace migraphx_hip_f8_impl {
// #endif
// #endif
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
,
bool
clip
>
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
,
bool
clip
>
MIGRAPHX_HIP_HOST_DEVICE
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
,
uint32_t
rng
)
MIGRAPHX_HIP_HOST_DEVICE
constexpr
uint8_t
cast_to_f8
(
T
_x
,
bool
stoch
,
uint32_t
rng
)
{
{
constexpr
bool
is_half
=
migraphx
::
is_same
<
T
,
migraphx
::
half
>
{};
constexpr
bool
is_float
=
migraphx
::
is_same
<
T
,
float
>
{};
static_assert
(
wm
+
we
==
7
,
"wm+we==7"
);
static_assert
(
wm
+
we
==
7
,
"wm+we==7"
);
static_assert
(
is_half
||
is_float
,
"Only half and float can be cast to f8"
);
const
int
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
const
int
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
uint32_t
x
;
uint32_t
x
;
...
@@ -215,29 +227,12 @@ this case, the fp16 mantissa should be shift left by 1 */
...
@@ -215,29 +227,12 @@ this case, the fp16 mantissa should be shift left by 1 */
}
}
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
>
template
<
int
wm
,
int
we
,
typename
T
,
bool
negative_zero_nan
>
MIGRAPHX_HIP_HOST_DEVICE
T
cast_from_f8
(
uint8_t
x
)
MIGRAPHX_HIP_HOST_DEVICE
constexpr
T
cast_from_f8
(
uint8_t
x
)
{
{
constexpr
bool
is_half
=
migraphx
::
is_same
<
T
,
migraphx
::
half
>
{};
constexpr
int
weo
=
8
;
constexpr
bool
is_float
=
migraphx
::
is_same
<
T
,
float
>
{};
constexpr
int
wmo
=
23
;
static_assert
(
is_half
||
is_float
,
"only half and float are supported"
);
constexpr
int
weo
=
is_half
?
5
:
8
;
constexpr
int
wmo
=
is_half
?
10
:
(
is_float
?
23
:
7
);
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
if
(
is_half
)
{
const
uint16_t
ihInf
=
0x7C00
;
const
uint16_t
ihNegInf
=
0xFC00
;
const
uint16_t
ihNaN
=
0x7C01
;
const
uint16_t
ihNeg0
=
0x8000
;
fInf
=
reinterpret_cast
<
const
migraphx
::
half
&>
(
ihInf
);
fNegInf
=
reinterpret_cast
<
const
migraphx
::
half
&>
(
ihNegInf
);
fNaN
=
reinterpret_cast
<
const
migraphx
::
half
&>
(
ihNaN
);
fNeg0
=
reinterpret_cast
<
const
migraphx
::
half
&>
(
ihNeg0
);
}
else
if
(
is_float
)
{
const
uint32_t
ifInf
=
0x7F800000
;
const
uint32_t
ifInf
=
0x7F800000
;
const
uint32_t
ifNegInf
=
0xFF800000
;
const
uint32_t
ifNegInf
=
0xFF800000
;
const
uint32_t
ifNaN
=
0x7F800001
;
const
uint32_t
ifNaN
=
0x7F800001
;
...
@@ -246,7 +241,6 @@ MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x)
...
@@ -246,7 +241,6 @@ MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x)
fNegInf
=
reinterpret_cast
<
const
float
&>
(
ifNegInf
);
fNegInf
=
reinterpret_cast
<
const
float
&>
(
ifNegInf
);
fNaN
=
reinterpret_cast
<
const
float
&>
(
ifNaN
);
fNaN
=
reinterpret_cast
<
const
float
&>
(
ifNaN
);
fNeg0
=
reinterpret_cast
<
const
float
&>
(
ifNeg0
);
fNeg0
=
reinterpret_cast
<
const
float
&>
(
ifNeg0
);
}
if
(
x
==
0
)
if
(
x
==
0
)
return
0
;
return
0
;
...
@@ -266,12 +260,7 @@ MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x)
...
@@ -266,12 +260,7 @@ MIGRAPHX_HIP_HOST_DEVICE T cast_from_f8(uint8_t x)
if
(
exponent
==
((
1
<<
we
)
-
1
))
if
(
exponent
==
((
1
<<
we
)
-
1
))
return
(
mantissa
==
0
)
?
(
sign
?
fNegInf
:
fInf
)
:
fNaN
;
return
(
mantissa
==
0
)
?
(
sign
?
fNegInf
:
fInf
)
:
fNaN
;
}
}
typename
migraphx
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
retval
;
typename
detail
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
retval
;
if
(
we
==
5
&&
is_half
&&
!
negative_zero_nan
)
{
retval
=
x
<<
8
;
return
reinterpret_cast
<
const
T
&>
(
retval
);
}
const
int
exp_low_cutoff
=
(
1
<<
(
weo
-
1
))
-
(
1
<<
(
we
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
const
int
exp_low_cutoff
=
(
1
<<
(
weo
-
1
))
-
(
1
<<
(
we
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
...
...
src/include/migraphx/shape.hpp
View file @
0a8edad5
...
@@ -34,7 +34,7 @@
...
@@ -34,7 +34,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/
fp8e4m3fnuz
.hpp>
#include <migraphx/
migraphx_float8
.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
...
@@ -54,7 +54,7 @@ struct MIGRAPHX_EXPORT shape
...
@@ -54,7 +54,7 @@ struct MIGRAPHX_EXPORT shape
m(half_type, half) \
m(half_type, half) \
m(float_type, float) \
m(float_type, float) \
m(double_type, double) \
m(double_type, double) \
m(float8_type, fp8e4m3fnuz) \
m(float8_type,
migraphx_fp8::
fp8e4m3fnuz) \
m(uint8_type, uint8_t) \
m(uint8_type, uint8_t) \
m(int8_type, int8_t) \
m(int8_type, int8_t) \
m(uint16_type, uint16_t) \
m(uint16_type, uint16_t) \
...
...
src/include/migraphx/type_traits.hpp
View file @
0a8edad5
...
@@ -25,10 +25,10 @@
...
@@ -25,10 +25,10 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#include <migraphx/fp8e4m3fnuz.hpp>
#include <type_traits>
#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/migraphx_float8.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -63,9 +63,9 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
...
@@ -63,9 +63,9 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
migraphx_fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
migraphx_fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
migraphx_fp8
::
fp8e4m3fnuz
)
template
<
class
T
>
template
<
class
T
>
using
accumulator_type
=
using
accumulator_type
=
...
...
src/py/migraphx_py.cpp
View file @
0a8edad5
...
@@ -40,7 +40,7 @@
...
@@ -40,7 +40,7 @@
#include <migraphx/json.hpp>
#include <migraphx/json.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/
fp8e4m3fnuz
.hpp>
#include <migraphx/
migraphx_float8
.hpp>
#ifdef HAVE_GPU
#ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/hip.hpp>
#endif
#endif
...
@@ -145,7 +145,7 @@ struct npy_format_descriptor<half>
...
@@ -145,7 +145,7 @@ struct npy_format_descriptor<half>
};
};
template
<
>
template
<
>
struct
npy_format_descriptor
<
migraphx
::
fp8e4m3fnuz
>
struct
npy_format_descriptor
<
migraphx
_fp8
::
fp8e4m3fnuz
>
{
{
static
std
::
string
format
()
static
std
::
string
format
()
{
{
...
...
src/targets/gpu/CMakeLists.txt
View file @
0a8edad5
...
@@ -60,7 +60,7 @@ endif()
...
@@ -60,7 +60,7 @@ endif()
include
(
Embed
)
include
(
Embed
)
add_embed_library
(
migraphx_kernels
${
KERNEL_FILES
}
RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/ EXTRA_HEADERS
${
CMAKE_SOURCE_DIR
}
/src/include/migraphx/
fp8e4m3fnuz.hpp EXTRA_HEADERS_RELATIVE
${
CMAKE_SOURCE_DIR
}
/src/include
)
add_embed_library
(
migraphx_kernels
${
KERNEL_FILES
}
RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/ EXTRA_HEADERS
${
CMAKE_SOURCE_DIR
}
/src/include/migraphx/
migraphx_float8.hpp
${
CMAKE_SOURCE_DIR
}
/src/include/migraphx/migraphx_hip_f8_impl.hpp EXTRA_HEADERS_RELATIVE
${
CMAKE_SOURCE_DIR
}
/src/include
${
CMAKE_SOURCE_DIR
}
/src/include
)
configure_file
(
device/targets.hpp.in include/migraphx/gpu/device/targets.hpp
)
configure_file
(
device/targets.hpp.in include/migraphx/gpu/device/targets.hpp
)
file
(
GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/device/*.cpp
)
file
(
GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/device/*.cpp
)
...
...
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
View file @
0a8edad5
...
@@ -35,7 +35,7 @@ namespace migraphx {
...
@@ -35,7 +35,7 @@ namespace migraphx {
namespace
math
{
namespace
math
{
constexpr
float
as_float
(
migraphx
::
half
x
)
{
return
x
;
}
constexpr
float
as_float
(
migraphx
::
half
x
)
{
return
x
;
}
constexpr
float
as_float
(
migraphx
::
fp8e4m3fnuz
x
)
{
return
x
;
}
constexpr
float
as_float
(
migraphx
_fp8
::
fp8e4m3fnuz
x
)
{
return
x
;
}
template
<
class
T
>
template
<
class
T
>
constexpr
T
as_float
(
T
x
)
constexpr
T
as_float
(
T
x
)
...
@@ -78,15 +78,15 @@ constexpr T as_float(T x)
...
@@ -78,15 +78,15 @@ constexpr T as_float(T x)
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \
#define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::fp8e4m3fnuz x, Ts... xs)
\
auto __device__ name(migraphx
_fp8
::fp8e4m3fnuz x, Ts... xs)
MIGRAPHX_RETURNS(
\
MIGRAPHX_RETURNS(
migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
migraphx
_fp8
::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(name, fname) \
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(name, fname) \
inline auto __device__ name(migraphx::fp8e4m3fnuz x, migraphx::fp8e4m3fnuz y)
\
inline auto __device__ name(migraphx
_fp8
::fp8e4m3fnuz x, migraphx
_fp8
::fp8e4m3fnuz y) \
-> migraphx::fp8e4m3fnuz \
-> migraphx
_fp8
::fp8e4m3fnuz
\
{ \
{ \
return migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y))); \
return migraphx
_fp8
::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y)));
\
}
}
// Template with two overloads for math functions, one for half2 type and one for more generic
// Template with two overloads for math functions, one for half2 type and one for more generic
...
...
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
View file @
0a8edad5
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#include <migraphx/
fp8e4m3fnuz
.hpp>
#include <migraphx/
migraphx_float8
.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/integral_constant.hpp>
...
@@ -231,7 +231,8 @@ constexpr unsigned long int_max(unsigned long n)
...
@@ -231,7 +231,8 @@ constexpr unsigned long int_max(unsigned long n)
template
<
class
T
,
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_integral
<
T
>{}
or
is_floating_point
<
T
>
{}
or
MIGRAPHX_REQUIRES
(
is_integral
<
T
>{}
or
is_floating_point
<
T
>
{}
or
is_same
<
T
,
migraphx
::
half
>
{}
or
is_same
<
T
,
migraphx
::
fp8e4m3fnuz
>
{})
>
is_same
<
T
,
migraphx
::
half
>
{}
or
is_same
<
T
,
migraphx_fp8
::
fp8e4m3fnuz
>
{})
>
constexpr
T
numeric_max
()
constexpr
T
numeric_max
()
{
{
if
constexpr
(
is_integral
<
T
>
{})
if
constexpr
(
is_integral
<
T
>
{})
...
@@ -247,8 +248,8 @@ constexpr T numeric_max()
...
@@ -247,8 +248,8 @@ constexpr T numeric_max()
return
__FLT_MAX__
;
return
__FLT_MAX__
;
else
if
constexpr
(
is_same
<
T
,
migraphx
::
half
>
{})
else
if
constexpr
(
is_same
<
T
,
migraphx
::
half
>
{})
return
__FLT16_MAX__
;
return
__FLT16_MAX__
;
else
if
constexpr
(
is_same
<
T
,
migraphx
::
fp8e4m3fnuz
>
{})
else
if
constexpr
(
is_same
<
T
,
migraphx
_fp8
::
fp8e4m3fnuz
>
{})
return
T
{
0x7F
,
migraphx
::
fp8
e4m3fnuz
::
from_bits
()
}
;
return
migraphx
_
fp8
::
F8_Max
<
T
>
();
else
else
return
0
;
return
0
;
}
}
...
@@ -263,8 +264,8 @@ constexpr T numeric_lowest()
...
@@ -263,8 +264,8 @@ constexpr T numeric_lowest()
else
else
return
-
numeric_max
<
T
>
()
-
1
;
return
-
numeric_max
<
T
>
()
-
1
;
}
}
else
if
constexpr
(
is_same
<
T
,
migraphx
::
fp8e4m3fnuz
>
{})
else
if
constexpr
(
is_same
<
T
,
migraphx
_fp8
::
fp8e4m3fnuz
>
{})
return
T
{
0xFF
,
migraphx
::
fp8
e4m3fnuz
::
from_bits
()
}
;
return
migraphx
_
fp8
::
F8_Lowest
<
T
>
();
else
else
{
{
return
-
numeric_max
<
T
>
();
return
-
numeric_max
<
T
>
();
...
...
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
View file @
0a8edad5
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
*/
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#include <migraphx/
fp8e4m3fnuz
.hpp>
#include <migraphx/
migraphx_float8
.hpp>
#include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/hip.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
View file @
0a8edad5
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#define MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#define MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#include
"
migraphx/kernels/type_traits.hpp
"
#include
<
migraphx/kernels/type_traits.hpp
>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/vec.hpp>
...
@@ -237,7 +237,7 @@ template <index_int N, index_int Axis, class T>
...
@@ -237,7 +237,7 @@ template <index_int N, index_int Axis, class T>
__device__
__host__
auto
vectorize_tensor
(
T
x
)
__device__
__host__
auto
vectorize_tensor
(
T
x
)
{
{
constexpr
auto
shape
=
get_shape_c
<
T
>
{};
constexpr
auto
shape
=
get_shape_c
<
T
>
{};
if
constexpr
(
is_same
<
typename
T
::
type
,
migraphx
::
fp8e4m3fnuz
>
{})
if
constexpr
(
is_same
<
typename
T
::
type
,
migraphx
_fp8
::
fp8e4m3fnuz
>
{})
return
x
;
return
x
;
else
if
constexpr
(
shape
.
lens
[
Axis
]
==
1
)
else
if
constexpr
(
shape
.
lens
[
Axis
]
==
1
)
return
x
;
return
x
;
...
...
test/gpu/jit.cpp
View file @
0a8edad5
...
@@ -351,7 +351,7 @@ TEST_CASE(compile_math)
...
@@ -351,7 +351,7 @@ TEST_CASE(compile_math)
if
(
contains
({
migraphx
::
shape
::
bool_type
,
migraphx
::
shape
::
tuple_type
},
t
))
if
(
contains
({
migraphx
::
shape
::
bool_type
,
migraphx
::
shape
::
tuple_type
},
t
))
continue
;
continue
;
auto
name
=
migraphx
::
shape
::
cpp_type
(
t
);
auto
name
=
migraphx
::
shape
::
cpp_type
(
t
);
if
(
t
==
migraphx
::
shape
::
half_type
or
t
==
migraphx
::
shape
::
float8_type
)
if
(
t
==
migraphx
::
shape
::
half_type
)
name
.
insert
(
0
,
"migraphx::"
);
name
.
insert
(
0
,
"migraphx::"
);
data_types
.
push_back
(
name
);
data_types
.
push_back
(
name
);
if
(
t
!=
migraphx
::
shape
::
float8_type
)
if
(
t
!=
migraphx
::
shape
::
float8_type
)
...
@@ -402,7 +402,7 @@ TEST_CASE(assert_type_min_max)
...
@@ -402,7 +402,7 @@ TEST_CASE(assert_type_min_max)
if
(
contains
({
migraphx
::
shape
::
bool_type
,
migraphx
::
shape
::
tuple_type
},
t
))
if
(
contains
({
migraphx
::
shape
::
bool_type
,
migraphx
::
shape
::
tuple_type
},
t
))
continue
;
continue
;
auto
name
=
migraphx
::
shape
::
cpp_type
(
t
);
auto
name
=
migraphx
::
shape
::
cpp_type
(
t
);
if
(
t
==
migraphx
::
shape
::
half_type
or
t
==
migraphx
::
shape
::
float8_type
)
if
(
t
==
migraphx
::
shape
::
half_type
)
name
.
insert
(
0
,
"migraphx::"
);
name
.
insert
(
0
,
"migraphx::"
);
migraphx
::
shape
::
visit
(
t
,
[
&
](
auto
as
)
{
migraphx
::
shape
::
visit
(
t
,
[
&
](
auto
as
)
{
...
...
tools/api/migraphx.h
View file @
0a8edad5
...
@@ -37,7 +37,7 @@
...
@@ -37,7 +37,7 @@
m(half_type, half) \
m(half_type, half) \
m(float_type, float) \
m(float_type, float) \
m(double_type, double) \
m(double_type, double) \
m(float8_type, fp8e4m3fnuz) \
m(float8_type,
migraphx_fp8::
fp8e4m3fnuz) \
m(uint8_type, uint8_t) \
m(uint8_type, uint8_t) \
m(int8_type, int8_t) \
m(int8_type, int8_t) \
m(uint16_type, uint16_t) \
m(uint16_type, uint16_t) \
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment