Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e29e613f
Commit
e29e613f
authored
Jun 19, 2019
by
Paul
Browse files
Fix gather args
parent
aa9863b6
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
14 deletions
+16
-14
src/targets/gpu/device/gather.cpp
src/targets/gpu/device/gather.cpp
+12
-11
src/targets/gpu/gather.cpp
src/targets/gpu/gather.cpp
+1
-1
src/targets/gpu/include/migraphx/gpu/device/gather.hpp
src/targets/gpu/include/migraphx/gpu/device/gather.hpp
+3
-2
No files found.
src/targets/gpu/device/gather.cpp
View file @
e29e613f
...
...
@@ -12,21 +12,22 @@ namespace gpu {
namespace
device
{
argument
gather
(
hipStream_t
stream
,
const
migraphx
::
shape
&
output_shape
,
std
::
vector
<
migraphx
::
argument
>
args
,
argument
result
,
argument
arg1
,
argument
arg2
,
int
axis
)
{
auto
axis_index
=
(
axis
<
0
)
?
(
axis
+
args
[
0
].
get_shape
().
lens
().
size
())
:
axis
;
visit_all
(
args
.
back
(),
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
std
::
size_t
nelements
=
output_shape
.
elements
();
args
[
1
].
visit
([
&
](
auto
indices
)
{
auto
axis_index
=
(
axis
<
0
)
?
(
axis
+
arg1
.
get_shape
().
lens
().
size
())
:
axis
;
auto
&
input_shape
=
arg1
.
get_shape
();
auto
lens
=
input_shape
.
lens
();
lens
[
axis_index
]
=
arg2
.
get_shape
().
elements
();
std
::
size_t
nelements
=
result
.
get_shape
().
elements
();
visit_all
(
result
,
arg1
)([
&
](
auto
output
,
auto
input
)
{
arg2
.
visit
([
&
](
auto
indices
)
{
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
auto
*
out_ptr
=
device_cast
(
output
.
data
());
const
auto
*
in_ptr
=
device_cast
(
input
.
data
());
auto
&
input_shape
=
args
[
0
].
get_shape
();
auto
lens
=
input_shape
.
lens
();
lens
[
axis_index
]
=
args
[
1
].
get_shape
().
elements
();
migraphx
::
shape
out_comp_shape
{
output_shape
.
type
(),
lens
};
migraphx
::
shape
out_comp_shape
{
result
.
get_shape
().
type
(),
lens
};
visit_tensor_size
(
out_comp_shape
.
lens
().
size
(),
[
&
](
auto
n_out_dim
)
{
hip_tensor_descriptor
<
n_out_dim
>
desc_input
(
input_shape
);
hip_tensor_descriptor
<
n_out_dim
>
desc_output
(
out_comp_shape
);
...
...
@@ -39,7 +40,7 @@ argument gather(hipStream_t stream,
});
});
return
args
.
back
()
;
return
result
;
}
}
// namespace device
...
...
src/targets/gpu/gather.cpp
View file @
e29e613f
...
...
@@ -16,7 +16,7 @@ argument hip_gather::compute(context& ctx,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
device
::
gather
(
ctx
.
get_stream
().
get
(),
output_shape
,
args
,
op
.
axis
);
return
device
::
gather
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
[
0
]
,
args
[
1
]
,
op
.
axis
);
}
}
// namespace gpu
...
...
src/targets/gpu/include/migraphx/gpu/device/gather.hpp
View file @
e29e613f
...
...
@@ -11,8 +11,9 @@ namespace gpu {
namespace
device
{
argument
gather
(
hipStream_t
stream
,
const
migraphx
::
shape
&
output_shape
,
std
::
vector
<
migraphx
::
argument
>
args
,
argument
result
,
argument
arg1
,
argument
arg2
,
int
axis
);
}
// namespace device
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment