Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
97d605b8
Commit
97d605b8
authored
Jan 24, 2019
by
Shucai Xiao
Browse files
change a the gather implementation to suppport axis to be negative.
parent
f792097f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
5 deletions
+12
-5
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+9
-2
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+1
-1
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+2
-2
No files found.
src/include/migraphx/operators.hpp
View file @
97d605b8
...
@@ -640,17 +640,24 @@ struct as_shape
...
@@ -640,17 +640,24 @@ struct as_shape
struct
gather
struct
gather
{
{
std
::
size_
t
axis
=
0
;
mutable
in
t
axis
=
0
;
std
::
string
name
()
const
{
return
"gather"
;
}
std
::
string
name
()
const
{
return
"gather"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
auto
lens
=
inputs
[
0
].
lens
();
auto
lens
=
inputs
[
0
].
lens
();
if
(
axis
>=
lens
.
size
())
if
(
axis
>=
lens
.
size
()
||
axis
<
-
lens
.
size
()
)
{
{
MIGRAPHX_THROW
(
"Gather, axis is out of range."
);
MIGRAPHX_THROW
(
"Gather, axis is out of range."
);
}
}
// negative value means counting dimensions from back
if
(
axis
<
0
)
{
axis
+=
lens
.
size
();
}
auto
type
=
inputs
[
0
].
type
();
auto
type
=
inputs
[
0
].
type
();
lens
[
axis
]
=
inputs
[
1
].
elements
();
lens
[
axis
]
=
inputs
[
1
].
elements
();
...
...
src/onnx/onnx.cpp
View file @
97d605b8
...
@@ -377,7 +377,7 @@ struct onnx_parser
...
@@ -377,7 +377,7 @@ struct onnx_parser
instruction_ref
instruction_ref
parse_gather
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_gather
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
{
std
::
size_
t
axis
=
0
;
in
t
axis
=
0
;
if
(
contains
(
attributes
,
"axis"
))
if
(
contains
(
attributes
,
"axis"
))
{
{
axis
=
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
int
>
();
axis
=
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
int
>
();
...
...
test/onnx/onnx_test.cpp
View file @
97d605b8
...
@@ -417,7 +417,7 @@ TEST_CASE(gather_test)
...
@@ -417,7 +417,7 @@ TEST_CASE(gather_test)
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"data"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
5
,
6
}});
auto
l0
=
p
.
add_parameter
(
"data"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
5
,
6
}});
auto
l1
=
p
.
add_parameter
(
"indices"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}});
auto
l1
=
p
.
add_parameter
(
"indices"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}});
std
::
size_
t
axis
=
1
;
in
t
axis
=
1
;
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
l0
,
l1
);
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
l0
,
l1
);
auto
prog
=
migraphx
::
parse_onnx
(
"gather_test.onnx"
);
auto
prog
=
migraphx
::
parse_onnx
(
"gather_test.onnx"
);
...
@@ -432,7 +432,7 @@ TEST_CASE(shape_gather_test)
...
@@ -432,7 +432,7 @@ TEST_CASE(shape_gather_test)
p
.
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
3
}},
l0
->
get_shape
().
lens
());
p
.
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
3
}},
l0
->
get_shape
().
lens
());
migraphx
::
shape
const_shape
{
migraphx
::
shape
::
int32_type
,
{
1
}};
migraphx
::
shape
const_shape
{
migraphx
::
shape
::
int32_type
,
{
1
}};
auto
l2
=
p
.
add_literal
(
migraphx
::
literal
{
const_shape
,
{
1
}});
auto
l2
=
p
.
add_literal
(
migraphx
::
literal
{
const_shape
,
{
1
}});
std
::
size_
t
axis
=
0
;
in
t
axis
=
0
;
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
l1
,
l2
);
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
l1
,
l2
);
auto
prog
=
migraphx
::
parse_onnx
(
"shape_gather.onnx"
);
auto
prog
=
migraphx
::
parse_onnx
(
"shape_gather.onnx"
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment