Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8d9d5d1c
Unverified
Commit
8d9d5d1c
authored
May 19, 2023
by
Zhuoran Yin
Committed by
GitHub
May 19, 2023
Browse files
Enabling native int32 type support (#1721)
Co-authored-by:
Paul Fultz II
<
pfultz2@yahoo.com
>
parent
3557ce90
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
104 additions
and
21 deletions
+104
-21
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
+4
-0
src/targets/gpu/device/scatter.cpp
src/targets/gpu/device/scatter.cpp
+16
-12
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+7
-6
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
...gets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
+11
-3
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+1
-0
test/gpu/jit.cpp
test/gpu/jit.cpp
+65
-0
No files found.
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
View file @
8d9d5d1c
...
@@ -94,6 +94,10 @@ template <>
...
@@ -94,6 +94,10 @@ template <>
struct
is_hip_type
<
std
::
uint8_t
>
:
std
::
true_type
struct
is_hip_type
<
std
::
uint8_t
>
:
std
::
true_type
{
{
};
};
template
<
>
struct
is_hip_type
<
std
::
int32_t
>
:
std
::
true_type
{
};
template
<
class
T
,
class
V
,
MIGRAPHX_REQUIRES
(
is_hip_type
<
typename
T
::
type
>{})
>
template
<
class
T
,
class
V
,
MIGRAPHX_REQUIRES
(
is_hip_type
<
typename
T
::
type
>{})
>
void
hip_visitor_invoke
(
T
as
,
V
&&
v
)
void
hip_visitor_invoke
(
T
as
,
V
&&
v
)
...
...
src/targets/gpu/device/scatter.cpp
View file @
8d9d5d1c
...
@@ -37,22 +37,26 @@ argument scatter(
...
@@ -37,22 +37,26 @@ argument scatter(
hipStream_t
stream
,
argument
result
,
argument
arg0
,
argument
arg1
,
argument
arg2
,
int64_t
axis
)
hipStream_t
stream
,
argument
result
,
argument
arg0
,
argument
arg1
,
argument
arg2
,
int64_t
axis
)
{
{
auto
ds
=
arg0
.
get_shape
();
auto
ds
=
arg0
.
get_shape
();
auto
inds
=
arg1
.
get_shape
();
auto
s1
=
arg1
.
get_shape
();
auto
axis_dim_size
=
ds
.
lens
()[
axis
];
auto
axis_dim_size
=
ds
.
lens
()[
axis
];
hip_visit_all
(
result
,
arg0
,
inds
)([
&
](
auto
output
,
auto
data
,
auto
s1
)
{
hip_visit_all
(
result
,
arg0
,
arg2
)([
&
](
auto
output
,
auto
data
,
auto
update
)
{
auto
*
output_ptr
=
device_cast
(
output
.
data
());
auto
*
output_ptr
=
device_cast
(
output
.
data
());
const
auto
*
data_ptr
=
device_cast
(
data
.
data
());
const
auto
*
data_ptr
=
device_cast
(
data
.
data
());
gs_launch
(
stream
,
ds
.
elements
())([
=
](
auto
i
)
__device__
{
output_ptr
[
i
]
=
data_ptr
[
i
];
});
gs_launch
(
stream
,
ds
.
elements
())([
=
](
auto
i
)
__device__
{
output_ptr
[
i
]
=
data_ptr
[
i
];
});
hip_visit_all
(
arg1
,
arg2
)([
&
](
auto
indices
,
auto
update
)
{
const
auto
*
upd_ptr
=
device_cast
(
update
.
data
());
hip_visit_all
(
arg1
)([
&
](
auto
indices
)
{
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
if
constexpr
(
indices
.
get_shape
().
lens
.
size
()
==
output
.
get_shape
().
lens
.
size
())
gs_launch
(
stream
,
inds
.
elements
())([
=
](
auto
i
)
__device__
{
{
auto
out_idx
=
s1
.
multi
(
i
);
const
auto
*
upd_ptr
=
device_cast
(
update
.
data
());
auto
index
=
indices_ptr
[
i
];
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
index
=
index
<
0
?
index
+
axis_dim_size
:
index
;
gs_launch
(
stream
,
s1
.
elements
())([
=
](
auto
i
)
__device__
{
out_idx
[
axis
]
=
index
;
auto
out_idx
=
indices
.
get_shape
().
multi
(
i
);
output
[
out_idx
]
=
upd_ptr
[
i
];
auto
index
=
indices_ptr
[
i
];
});
index
=
index
<
0
?
index
+
axis_dim_size
:
index
;
out_idx
[
axis
]
=
index
;
output
[
out_idx
]
=
upd_ptr
[
i
];
});
}
});
});
});
});
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
8d9d5d1c
...
@@ -79,20 +79,21 @@ __device__ void dpp_reduce(T& in, Op op)
...
@@ -79,20 +79,21 @@ __device__ void dpp_reduce(T& in, Op op)
#endif
#endif
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE(op, prefix
)
\
#define MIGRAPHX_DPP_REDUCE(op, prefix
, sign)
\
__device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \
__device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \
__device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \
__device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \
__device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \
__device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \
__device__ inline void dpp_reduce(int32_t& x, op) \
__device__ inline void dpp_reduce(int32_t& x, op) \
{ \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##
_u32);
\
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##
sign##32);
\
} \
} \
__device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); }
__device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); }
MIGRAPHX_DPP_REDUCE
(
op
::
sum
,
v_add
)
// Note: when max and min are in int32_t, signed version of instruction needs to be used.
MIGRAPHX_DPP_REDUCE
(
op
::
max
,
v_max
)
MIGRAPHX_DPP_REDUCE
(
op
::
sum
,
v_add
,
_u
)
MIGRAPHX_DPP_REDUCE
(
op
::
min
,
v_min
)
MIGRAPHX_DPP_REDUCE
(
op
::
product
,
v_mul
,
_u
)
MIGRAPHX_DPP_REDUCE
(
op
::
product
,
v_mul
)
MIGRAPHX_DPP_REDUCE
(
op
::
max
,
v_max
,
_i
)
MIGRAPHX_DPP_REDUCE
(
op
::
min
,
v_min
,
_i
)
template
<
class
Op
,
class
T
,
class
Index
,
class
F
>
template
<
class
Op
,
class
T
,
class
Index
,
class
F
>
__device__
auto
block_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
__device__
auto
block_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
...
...
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
View file @
8d9d5d1c
...
@@ -218,7 +218,15 @@ using common_type_t = typename common_type<Ts...>::type;
...
@@ -218,7 +218,15 @@ using common_type_t = typename common_type<Ts...>::type;
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
constexpr
unsigned
long
int_max
(
unsigned
long
n
)
{
return
(
1u
<<
(
n
*
8
))
-
1
;
}
constexpr
unsigned
long
int_max
(
unsigned
long
n
)
{
// Note, left shift cannot be used to get the maximum value of int64_type or
// uint64_type because it is undefined behavior to left shift 64 bits for
// these types
if
(
n
==
sizeof
(
int64_t
))
return
-
1
;
return
(
1ul
<<
(
n
*
8
))
-
1
;
}
template
<
class
T
,
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_integral
<
T
>{}
or
is_floating_point
<
T
>
{}
or
MIGRAPHX_REQUIRES
(
is_integral
<
T
>{}
or
is_floating_point
<
T
>
{}
or
...
@@ -228,9 +236,9 @@ constexpr T numeric_max()
...
@@ -228,9 +236,9 @@ constexpr T numeric_max()
if
constexpr
(
is_integral
<
T
>
{})
if
constexpr
(
is_integral
<
T
>
{})
{
{
if
constexpr
(
is_unsigned
<
T
>
{})
if
constexpr
(
is_unsigned
<
T
>
{})
return
int_max
(
sizeof
(
T
))
*
2
;
else
return
int_max
(
sizeof
(
T
));
return
int_max
(
sizeof
(
T
));
else
return
int_max
(
sizeof
(
T
))
/
2
;
}
}
else
if
constexpr
(
is_same
<
T
,
double
>
{})
else
if
constexpr
(
is_same
<
T
,
double
>
{})
return
__DBL_MAX__
;
return
__DBL_MAX__
;
...
...
src/targets/gpu/target.cpp
View file @
8d9d5d1c
...
@@ -97,6 +97,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -97,6 +97,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types
.
erase
(
shape
::
type_t
::
bool_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
bool_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
uint8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
uint8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int32_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
tuple_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
tuple_type
);
// clang-format off
// clang-format off
return
return
...
...
test/gpu/jit.cpp
View file @
8d9d5d1c
...
@@ -364,4 +364,69 @@ TEST_CASE(compile_math)
...
@@ -364,4 +364,69 @@ TEST_CASE(compile_math)
});
});
}
}
// NOLINTNEXTLINE
const
std
::
string
assert_template
=
R"__migraphx__(
#include <migraphx/kernels/math.hpp>
#include <migraphx/kernels/types.hpp>
using namespace migraphx;
extern "C" {
__global__ void kernel(void*)
{
static_assert(numeric_max<${type}>() == ${max}, "");
static_assert(numeric_lowest<${type}>() == ${min}, "");
}
}
int main() {}
)__migraphx__"
;
TEST_CASE
(
assert_type_min_max
)
{
std
::
vector
<
std
::
string
>
data_types
;
migraphx
::
gpu
::
hip_compile_options
options
;
for
(
auto
&&
t
:
migraphx
::
shape
::
types
())
{
if
(
contains
({
migraphx
::
shape
::
bool_type
,
migraphx
::
shape
::
tuple_type
},
t
))
continue
;
auto
name
=
migraphx
::
shape
::
cpp_type
(
t
);
if
(
t
==
migraphx
::
shape
::
half_type
)
name
.
insert
(
0
,
"migraphx::"
);
migraphx
::
shape
::
visit
(
t
,
[
&
](
auto
as
)
{
std
::
string
min
=
""
;
std
::
string
max
=
""
;
// Note 9223372036854775808 is a constant literal that is outside the range of long
// long type For the same reason, 18446744073709551616 needs postfix ULL to be parsed
// correctly
if
(
t
==
migraphx
::
shape
::
int64_type
)
{
min
=
"("
+
std
::
to_string
(
as
.
min
()
+
1
)
+
"LL - 1)"
;
max
=
std
::
to_string
(
as
.
max
());
}
else
if
(
t
==
migraphx
::
shape
::
uint64_type
)
{
min
=
std
::
to_string
(
as
.
min
());
max
=
std
::
to_string
(
as
.
max
())
+
"ULL"
;
}
else
{
min
=
std
::
to_string
(
as
.
min
());
max
=
std
::
to_string
(
as
.
max
());
}
auto
src
=
migraphx
::
interpolate_string
(
assert_template
,
{{
"type"
,
name
},
{
"max"
,
max
},
{
"min"
,
min
}});
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
5
,
2
}};
options
.
global
=
1024
;
options
.
local
=
1024
;
options
.
inputs
=
{
input
};
options
.
output
=
input
;
options
.
params
=
"-Wno-float-equal"
;
auto
co
=
migraphx
::
gpu
::
compile_hip_code_object
(
src
,
options
);
});
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment