Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
90c6a6c5
Commit
90c6a6c5
authored
Nov 03, 2023
by
Umang Yadav
Browse files
implicit_conversion fixed
parent
09aba405
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
63 additions
and
13 deletions
+63
-13
src/include/migraphx/fp8e4m3fnuz.hpp
src/include/migraphx/fp8e4m3fnuz.hpp
+11
-9
src/targets/gpu/compile_hip_code_object.cpp
src/targets/gpu/compile_hip_code_object.cpp
+1
-0
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+49
-2
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+1
-1
test/gpu/jit.cpp
test/gpu/jit.cpp
+1
-1
No files found.
src/include/migraphx/fp8e4m3fnuz.hpp
View file @
90c6a6c5
...
@@ -316,20 +316,22 @@ struct alignas(1) fp8e4m3fnuz
...
@@ -316,20 +316,22 @@ struct alignas(1) fp8e4m3fnuz
{
{
}
}
fp8e4m3fnuz
&
MIGRAPHX_HIP_HOST_DEVICE
operator
=
(
const
fp8e4m3fnuz
&
rhs
)
=
default
;
#if !defined(__HIP_NO_F8_CONVERSIONS__)
// for the device kernels, this needs to be disabled since implicit_conversion op can type cast
fp8e4m3fnuz
&
MIGRAPHX_HIP_HOST_DEVICE
operator
=
(
fp8e4m3fnuz
&&
rhs
)
=
default
;
// any type to any other type and that results in conflicts in candidate overload resolutions.
inline
constexpr
MIGRAPHX_HIP_HOST_DEVICE
operator
float
()
const
{
return
detail
::
fp8e4m3fnuz_to_fp32_value
(
x
);
}
fp8e4m3fnuz
&
MIGRAPHX_HIP_HOST_DEVICE
operator
=
(
float
rhs
)
fp8e4m3fnuz
&
MIGRAPHX_HIP_HOST_DEVICE
operator
=
(
float
rhs
)
{
{
x
=
detail
::
fp8e4m3fnuz_from_fp32_value
(
rhs
);
x
=
detail
::
fp8e4m3fnuz_from_fp32_value
(
rhs
);
return
*
this
;
return
*
this
;
}
}
#endif
inline
constexpr
MIGRAPHX_HIP_HOST_DEVICE
operator
float
()
const
{
return
detail
::
fp8e4m3fnuz_to_fp32_value
(
x
);
}
fp8e4m3fnuz
&
MIGRAPHX_HIP_HOST_DEVICE
operator
=
(
const
fp8e4m3fnuz
&
rhs
)
=
default
;
fp8e4m3fnuz
&
MIGRAPHX_HIP_HOST_DEVICE
operator
=
(
fp8e4m3fnuz
&&
rhs
)
=
default
;
inline
bool
MIGRAPHX_HIP_HOST_DEVICE
isnan
()
const
{
return
x
==
0b10000000
;
}
inline
bool
MIGRAPHX_HIP_HOST_DEVICE
isnan
()
const
{
return
x
==
0b10000000
;
}
...
...
src/targets/gpu/compile_hip_code_object.cpp
View file @
90c6a6c5
...
@@ -197,6 +197,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
...
@@ -197,6 +197,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
options
.
params
+=
" -DMIGRAPHX_NGLOBAL="
+
std
::
to_string
(
options
.
global
);
options
.
params
+=
" -DMIGRAPHX_NGLOBAL="
+
std
::
to_string
(
options
.
global
);
options
.
params
+=
" -DMIGRAPHX_NLOCAL="
+
std
::
to_string
(
options
.
local
);
options
.
params
+=
" -DMIGRAPHX_NLOCAL="
+
std
::
to_string
(
options
.
local
);
options
.
params
+=
" -D__HIP_NO_F8_CONVERSIONS__=1"
;
options
.
params
+=
" "
+
join_strings
(
compiler_warnings
(),
" "
);
options
.
params
+=
" "
+
join_strings
(
compiler_warnings
(),
" "
);
options
.
params
+=
" -ftemplate-backtrace-limit=0"
;
options
.
params
+=
" -ftemplate-backtrace-limit=0"
;
options
.
params
+=
" -Werror"
;
options
.
params
+=
" -Werror"
;
...
...
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
View file @
90c6a6c5
...
@@ -34,6 +34,9 @@ namespace migraphx {
...
@@ -34,6 +34,9 @@ namespace migraphx {
namespace
math
{
namespace
math
{
constexpr
float
as_float
(
migraphx
::
half
x
)
{
return
x
;
}
constexpr
float
as_float
(
migraphx
::
half
x
)
{
return
x
;
}
constexpr
float
as_float
(
migraphx
::
fp8e4m3fnuz
x
)
{
return
x
;
}
template
<
class
T
>
template
<
class
T
>
constexpr
T
as_float
(
T
x
)
constexpr
T
as_float
(
T
x
)
{
{
...
@@ -57,14 +60,14 @@ constexpr T as_float(T x)
...
@@ -57,14 +60,14 @@ constexpr T as_float(T x)
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \
#define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(type x, Ts... xs)->type
\
auto __device__ name(type x, Ts... xs)
->
type \
{ \
{ \
return fname(x, xs...); \
return fname(x, xs...); \
}
}
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR(type, name, fname) \
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR(type, name, fname) \
inline auto __device__ name(type x, type y)->type { return fname(x, y); }
inline auto __device__ name(type x, type y)
->
type { return fname(x, y); }
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \
#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \
...
@@ -72,6 +75,20 @@ constexpr T as_float(T x)
...
@@ -72,6 +75,20 @@ constexpr T as_float(T x)
auto __device__ name(migraphx::half x, Ts... xs) \
auto __device__ name(migraphx::half x, Ts... xs) \
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::fp8e4m3fnuz x, Ts... xs) \
MIGRAPHX_RETURNS(migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(name, fname) \
inline auto __device__ name(migraphx::fp8e4m3fnuz x, migraphx::fp8e4m3fnuz y) \
-> migraphx::fp8e4m3fnuz \
{ \
return migraphx::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y))); \
}
// Template with two overloads for math functions, one for half2 type and one for more generic
// Template with two overloads for math functions, one for half2 type and one for more generic
// <half, N> vectorization where N is 4 or another even number.
// <half, N> vectorization where N is 4 or another even number.
...
@@ -158,6 +175,33 @@ MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
...
@@ -158,6 +175,33 @@ MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
MIGRAPHX_DEVICE_MATH_HALF
(
tanh
,
::
tanh
)
MIGRAPHX_DEVICE_MATH_HALF
(
tanh
,
::
tanh
)
MIGRAPHX_DEVICE_MATH_HALF
(
fmod
,
::
fmod
)
MIGRAPHX_DEVICE_MATH_HALF
(
fmod
,
::
fmod
)
// use float to compute fp8 overload
MIGRAPHX_DEVICE_MATH_FP8
(
abs
,
::
abs
)
MIGRAPHX_DEVICE_MATH_FP8
(
acos
,
::
acos
)
MIGRAPHX_DEVICE_MATH_FP8
(
acosh
,
::
acosh
)
MIGRAPHX_DEVICE_MATH_FP8
(
asin
,
::
asin
)
MIGRAPHX_DEVICE_MATH_FP8
(
asinh
,
::
asinh
)
MIGRAPHX_DEVICE_MATH_FP8
(
atan
,
::
atan
)
MIGRAPHX_DEVICE_MATH_FP8
(
atanh
,
::
atanh
)
MIGRAPHX_DEVICE_MATH_FP8
(
ceil
,
::
ceil
)
MIGRAPHX_DEVICE_MATH_FP8
(
cos
,
::
cos
)
MIGRAPHX_DEVICE_MATH_FP8
(
cosh
,
::
cosh
)
MIGRAPHX_DEVICE_MATH_FP8
(
erf
,
::
erf
)
MIGRAPHX_DEVICE_MATH_FP8
(
exp
,
::
exp
)
MIGRAPHX_DEVICE_MATH_FP8
(
floor
,
::
floor
)
MIGRAPHX_DEVICE_MATH_FP8
(
isnan
,
::
isnan
)
MIGRAPHX_DEVICE_MATH_FP8
(
log
,
::
log
)
MIGRAPHX_DEVICE_MATH_FP8
(
pow
,
::
pow
)
MIGRAPHX_DEVICE_MATH_FP8
(
remainder
,
::
remainder
)
MIGRAPHX_DEVICE_MATH_FP8
(
round
,
::
round
)
MIGRAPHX_DEVICE_MATH_FP8
(
rsqrt
,
::
rsqrt
)
MIGRAPHX_DEVICE_MATH_FP8
(
sin
,
::
sin
)
MIGRAPHX_DEVICE_MATH_FP8
(
sinh
,
::
sinh
)
MIGRAPHX_DEVICE_MATH_FP8
(
sqrt
,
::
sqrt
)
MIGRAPHX_DEVICE_MATH_FP8
(
tan
,
::
tan
)
MIGRAPHX_DEVICE_MATH_FP8
(
tanh
,
::
tanh
)
MIGRAPHX_DEVICE_MATH_FP8
(
fmod
,
::
fmod
)
// Map math functions to hip half2 functions
// Map math functions to hip half2 functions
// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
// packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names
// packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names
...
@@ -191,6 +235,9 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min)
...
@@ -191,6 +235,9 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
migraphx
::
half
,
max
,
::
__hmax
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
migraphx
::
half
,
max
,
::
__hmax
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
migraphx
::
half
,
min
,
::
__hmin
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
migraphx
::
half
,
min
,
::
__hmin
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8
(
max
,
::
max
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8
(
min
,
::
min
)
template
<
class
T
,
MIGRAPHX_REQUIRES
(
not
is_any_vec
<
T
>())
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
not
is_any_vec
<
T
>())
>
constexpr
auto
max
(
const
T
&
a
,
const
T
&
b
)
constexpr
auto
max
(
const
T
&
a
,
const
T
&
b
)
{
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
90c6a6c5
...
@@ -577,7 +577,7 @@ __device__ void fused_reduce(Output output, F f)
...
@@ -577,7 +577,7 @@ __device__ void fused_reduce(Output output, F f)
}
}
else
else
{
{
r
.
outer
([
&
]
{
output
[
out_idx
]
=
result
;
});
r
.
outer
([
&
]
{
output
[
out_idx
]
=
implicit_conversion
(
result
)
;
});
}
}
});
});
}
}
...
...
test/gpu/jit.cpp
View file @
90c6a6c5
...
@@ -144,7 +144,7 @@ extern "C" {
...
@@ -144,7 +144,7 @@ extern "C" {
__global__ void kernel(${type}* p)
__global__ void kernel(${type}* p)
{
{
auto x = *p;
auto x = *p;
*p = migraphx::${invoke};
*p =
implicit_conversion(
migraphx::${invoke}
)
;
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment