Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
973cafd4
Commit
973cafd4
authored
Jun 25, 2021
by
Shucai Xiao
Browse files
gpu implementation of the scatter operator
parent
a884f396
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
127 additions
and
1 deletion
+127
-1
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+4
-1
src/targets/gpu/device/scatter.cpp
src/targets/gpu/device/scatter.cpp
+41
-0
src/targets/gpu/include/migraphx/gpu/device/scatter.hpp
src/targets/gpu/include/migraphx/gpu/device/scatter.hpp
+20
-0
src/targets/gpu/include/migraphx/gpu/scatter.hpp
src/targets/gpu/include/migraphx/gpu/scatter.hpp
+39
-0
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+1
-0
src/targets/gpu/scatter.cpp
src/targets/gpu/scatter.cpp
+22
-0
No files found.
src/targets/gpu/CMakeLists.txt
View file @
973cafd4
...
@@ -72,6 +72,7 @@ add_library(migraphx_device
...
@@ -72,6 +72,7 @@ add_library(migraphx_device
device/rnn_variable_seq_lens.cpp
device/rnn_variable_seq_lens.cpp
device/round.cpp
device/round.cpp
device/rsqrt.cpp
device/rsqrt.cpp
device/scatter.cpp
device/sigmoid.cpp
device/sigmoid.cpp
device/sign.cpp
device/sign.cpp
device/sin.cpp
device/sin.cpp
...
@@ -144,8 +145,9 @@ add_library(migraphx_gpu
...
@@ -144,8 +145,9 @@ add_library(migraphx_gpu
reverse.cpp
reverse.cpp
rnn_variable_seq_lens.cpp
rnn_variable_seq_lens.cpp
rocblas.cpp
rocblas.cpp
s
oftmax
.cpp
s
catter
.cpp
schedule_model.cpp
schedule_model.cpp
softmax.cpp
sync_device.cpp
sync_device.cpp
target.cpp
target.cpp
write_literals.cpp
write_literals.cpp
...
@@ -202,6 +204,7 @@ register_migraphx_gpu_ops(hip_
...
@@ -202,6 +204,7 @@ register_migraphx_gpu_ops(hip_
reverse
reverse
round
round
rsqrt
rsqrt
scatter
sigmoid
sigmoid
sign
sign
sinh
sinh
...
...
src/targets/gpu/device/scatter.cpp
0 → 100644
View file @
973cafd4
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/scatter.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
device
{
argument
scatter
(
hipStream_t
stream
,
argument
result
,
argument
arg0
,
argument
arg1
,
argument
arg2
,
int64_t
axis
)
{
auto
ds
=
arg0
.
get_shape
();
auto
inds
=
arg1
.
get_shape
();
hip_visit_all
(
result
,
arg0
,
inds
)([
&
](
auto
output
,
auto
data
,
auto
s1
)
{
// hip_visit_all(result, arg0, arg2, ds)([&](auto output, auto data, auto update, auto s) {
auto
*
output_ptr
=
device_cast
(
output
.
data
());
const
auto
*
data_ptr
=
device_cast
(
data
.
data
());
gs_launch
(
stream
,
ds
.
elements
(),
256
)([
=
](
auto
i
)
__device__
{
output_ptr
[
i
]
=
data_ptr
[
i
];
});
hip_visit_all
(
arg1
,
arg2
)([
&
](
auto
indices
,
auto
update
)
{
const
auto
*
upd_ptr
=
device_cast
(
update
.
data
());
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
gs_launch
(
stream
,
inds
.
elements
(),
256
)([
=
](
auto
i
)
__device__
{
auto
out_idx
=
s1
.
multi
(
i
);
out_idx
[
axis
]
=
indices_ptr
[
i
];
output
[
out_idx
]
=
upd_ptr
[
i
];
});
});
});
return
result
;
}
}
// namespace device
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/include/migraphx/gpu/device/scatter.hpp
0 → 100644
View file @
973cafd4
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_SCATTER_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SCATTER_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
device
{
argument
scatter
(
hipStream_t
stream
,
argument
result
,
argument
arg0
,
argument
arg1
,
argument
arg2
,
int64_t
axis
);
}
// namespace device
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/targets/gpu/include/migraphx/gpu/scatter.hpp
0 → 100644
View file @
973cafd4
#ifndef MIGRAPHX_GUARD_RTGLIB_SCATTER_HPP
#define MIGRAPHX_GUARD_RTGLIB_SCATTER_HPP
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/scatter.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
struct
context
;
struct
hip_scatter
{
op
::
scatter
op
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
migraphx
::
reflect
(
self
.
op
,
f
);
}
std
::
string
name
()
const
{
return
"gpu::scatter"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
;
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
return
shapes
.
size
()
-
1
;
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/targets/gpu/lowering.cpp
View file @
973cafd4
...
@@ -172,6 +172,7 @@ struct miopen_apply
...
@@ -172,6 +172,7 @@ struct miopen_apply
add_extend_op
(
"rnn_var_sl_last_output"
);
add_extend_op
(
"rnn_var_sl_last_output"
);
add_extend_op
(
"rnn_var_sl_shift_output"
);
add_extend_op
(
"rnn_var_sl_shift_output"
);
add_extend_op
(
"rnn_var_sl_shift_sequence"
);
add_extend_op
(
"rnn_var_sl_shift_sequence"
);
add_extend_op
(
"scatter"
);
add_extend_op
(
"softmax"
);
add_extend_op
(
"softmax"
);
add_gemm_op
<
op
::
dot
>
(
"dot"
);
add_gemm_op
<
op
::
dot
>
(
"dot"
);
...
...
src/targets/gpu/scatter.cpp
0 → 100644
View file @
973cafd4
#include <migraphx/gpu/scatter.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/scatter.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
shape
hip_scatter
::
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
inputs
.
pop_back
();
return
op
.
normalize_compute_shape
(
inputs
);
}
argument
hip_scatter
::
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
device
::
scatter
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
[
0
],
args
[
1
],
args
[
2
],
op
.
axis
);
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment