Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e29e613f
Commit
e29e613f
authored
Jun 19, 2019
by
Paul
Browse files
Fix gather args
parent
aa9863b6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
14 deletions
+16
-14
src/targets/gpu/device/gather.cpp
src/targets/gpu/device/gather.cpp
+12
-11
src/targets/gpu/gather.cpp
src/targets/gpu/gather.cpp
+1
-1
src/targets/gpu/include/migraphx/gpu/device/gather.hpp
src/targets/gpu/include/migraphx/gpu/device/gather.hpp
+3
-2
No files found.
src/targets/gpu/device/gather.cpp
View file @
e29e613f
...
...
@@ -12,21 +12,22 @@ namespace gpu {
namespace
device
{
argument
gather
(
hipStream_t
stream
,
const
migraphx
::
shape
&
output_shape
,
std
::
vector
<
migraphx
::
argument
>
args
,
argument
result
,
argument
arg1
,
argument
arg2
,
int
axis
)
{
auto
axis_index
=
(
axis
<
0
)
?
(
axis
+
args
[
0
].
get_shape
().
lens
().
size
())
:
axis
;
visit_all
(
args
.
back
(),
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
std
::
size_t
nelements
=
output_shape
.
elements
();
args
[
1
].
visit
([
&
](
auto
indices
)
{
auto
axis_index
=
(
axis
<
0
)
?
(
axis
+
arg1
.
get_shape
().
lens
().
size
())
:
axis
;
auto
&
input_shape
=
arg1
.
get_shape
();
auto
lens
=
input_shape
.
lens
();
lens
[
axis_index
]
=
arg2
.
get_shape
().
elements
();
std
::
size_t
nelements
=
result
.
get_shape
().
elements
();
visit_all
(
result
,
arg1
)([
&
](
auto
output
,
auto
input
)
{
arg2
.
visit
([
&
](
auto
indices
)
{
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
auto
*
out_ptr
=
device_cast
(
output
.
data
());
const
auto
*
in_ptr
=
device_cast
(
input
.
data
());
auto
&
input_shape
=
args
[
0
].
get_shape
();
auto
lens
=
input_shape
.
lens
();
lens
[
axis_index
]
=
args
[
1
].
get_shape
().
elements
();
migraphx
::
shape
out_comp_shape
{
output_shape
.
type
(),
lens
};
migraphx
::
shape
out_comp_shape
{
result
.
get_shape
().
type
(),
lens
};
visit_tensor_size
(
out_comp_shape
.
lens
().
size
(),
[
&
](
auto
n_out_dim
)
{
hip_tensor_descriptor
<
n_out_dim
>
desc_input
(
input_shape
);
hip_tensor_descriptor
<
n_out_dim
>
desc_output
(
out_comp_shape
);
...
...
@@ -39,7 +40,7 @@ argument gather(hipStream_t stream,
});
});
return
args
.
back
()
;
return
result
;
}
}
// namespace device
...
...
src/targets/gpu/gather.cpp
View file @
e29e613f
...
...
@@ -16,7 +16,7 @@ argument hip_gather::compute(context& ctx,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
device
::
gather
(
ctx
.
get_stream
().
get
(),
output_shape
,
args
,
op
.
axis
);
return
device
::
gather
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
[
0
]
,
args
[
1
]
,
op
.
axis
);
}
}
// namespace gpu
...
...
src/targets/gpu/include/migraphx/gpu/device/gather.hpp
View file @
e29e613f
...
...
@@ -11,8 +11,9 @@ namespace gpu {
namespace
device
{
argument
gather
(
hipStream_t
stream
,
const
migraphx
::
shape
&
output_shape
,
std
::
vector
<
migraphx
::
argument
>
args
,
argument
result
,
argument
arg1
,
argument
arg2
,
int
axis
);
}
// namespace device
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment