Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
04e01f74
Commit
04e01f74
authored
Jan 17, 2019
by
Shucai Xiao
Browse files
fixed a comments and add two more tests for gather.
parent
1bdd55e8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
15 deletions
+33
-15
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+7
-7
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+8
-8
test/op_shape_test.cpp
test/op_shape_test.cpp
+18
-0
No files found.
src/include/migraphx/operators.hpp
View file @
04e01f74
...
...
@@ -653,17 +653,13 @@ struct gather
}
template
<
class
T
>
void
compute_index
(
const
T
&
out_idx
,
const
std
::
vector
<
argument
>&
args
,
T
&
in_idx
)
const
void
compute_index
(
const
T
&
out_idx
,
const
std
::
vector
<
std
::
size_t
>&
vec_indices
,
const
std
::
size_t
max_dim
,
T
&
in_idx
)
const
{
in_idx
=
out_idx
;
// max dimension in axis
std
::
size_t
max_dim
=
args
[
0
].
get_shape
().
lens
()[
axis
];
std
::
vector
<
std
::
size_t
>
vec_indices
(
args
[
1
].
get_shape
().
lens
().
size
());
args
[
1
].
visit
([
&
](
auto
indices
)
{
vec_indices
.
assign
(
indices
.
begin
(),
indices
.
end
());
});
std
::
size_t
idx
=
vec_indices
.
at
(
out_idx
[
axis
]);
if
(
idx
>=
max_dim
)
{
MIGRAPHX_THROW
(
"Gather
,
indices are out of range in input tensor"
);
MIGRAPHX_THROW
(
"Gather
:
indices are out of range in input tensor"
);
}
in_idx
[
axis
]
=
idx
;
}
...
...
@@ -671,10 +667,14 @@ struct gather
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
// max dimension in axis
std
::
size_t
max_dim
=
args
[
0
].
get_shape
().
lens
()[
axis
];
std
::
vector
<
std
::
size_t
>
vec_indices
;
args
[
1
].
visit
([
&
](
auto
indices
)
{
vec_indices
.
assign
(
indices
.
begin
(),
indices
.
end
());
});
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
std
::
vector
<
std
::
size_t
>
in_idx
;
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
this
->
compute_index
(
idx
,
args
,
in_idx
);
this
->
compute_index
(
idx
,
vec_indices
,
max_dim
,
in_idx
);
output
(
idx
.
begin
(),
idx
.
end
())
=
input
(
in_idx
.
begin
(),
in_idx
.
end
());
});
});
...
...
src/onnx/onnx.cpp
View file @
04e01f74
...
...
@@ -546,7 +546,7 @@ struct onnx_parser
parse_shape
(
const
std
::
string
&
,
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
args
.
size
()
!=
1
)
MIGRAPHX_THROW
(
"Shape
,
operator should have 1 operand"
);
MIGRAPHX_THROW
(
"Shape
:
operator should have 1 operand"
);
std
::
vector
<
std
::
size_t
>
arg_shape
=
args
[
0
]
->
get_shape
().
lens
();
std
::
vector
<
int64_t
>
vec_shape
(
arg_shape
.
size
());
migraphx
::
shape
s
(
migraphx
::
shape
::
int64_type
,
{
arg_shape
.
size
()});
...
...
@@ -585,26 +585,26 @@ struct onnx_parser
if
(
contains
(
attributes
,
"extra_shape"
))
{
MIGRAPHX_THROW
(
"ConstantFill
,
cannot handle extra shape attribute"
);
MIGRAPHX_THROW
(
"ConstantFill
:
cannot handle extra shape attribute"
);
}
if
(
input_as_shape
==
1
)
{
if
(
args
.
size
()
!=
1
)
{
MIGRAPHX_THROW
(
"ConstantFill
,
need an input argument as output shape"
);
MIGRAPHX_THROW
(
"ConstantFill
:
need an input argument as output shape"
);
}
if
(
contains
(
attributes
,
"shape"
))
{
MIGRAPHX_THROW
(
"ConstantFill
,
cannot set the shape argument and pass in an input "
MIGRAPHX_THROW
(
"ConstantFill
:
cannot set the shape argument and pass in an input "
"at the same time"
);
}
migraphx
::
argument
in
=
args
[
0
]
->
eval
();
if
(
in
.
empty
())
{
MIGRAPHX_THROW
(
"ConstantFill
,
cannot handle dynamic shape as input"
);
MIGRAPHX_THROW
(
"ConstantFill
:
cannot handle dynamic shape as input"
);
}
std
::
vector
<
std
::
size_t
>
dims
;
...
...
@@ -617,11 +617,11 @@ struct onnx_parser
{
if
(
!
contains
(
attributes
,
"shape"
))
{
MIGRAPHX_THROW
(
"ConstantFill
,
attribute output shape is needed"
);
MIGRAPHX_THROW
(
"ConstantFill
:
attribute output shape is needed"
);
}
literal
ls
=
parse_value
(
attributes
.
at
(
"shape"
));
std
::
vector
<
std
::
size_t
>
dims
(
ls
.
get_shape
().
elements
())
;
std
::
vector
<
std
::
size_t
>
dims
;
ls
.
visit
([
&
](
auto
s
)
{
dims
.
assign
(
s
.
begin
(),
s
.
end
());
});
migraphx
::
shape
s
{
type
,
dims
};
std
::
vector
<
float
>
values
(
s
.
elements
(),
value
);
...
...
@@ -629,7 +629,7 @@ struct onnx_parser
}
else
{
MIGRAPHX_THROW
(
"ConstantFill
,
wrong value of attribute input_as_shape"
);
MIGRAPHX_THROW
(
"ConstantFill
:
wrong value of attribute input_as_shape"
);
}
}
...
...
test/op_shape_test.cpp
View file @
04e01f74
...
...
@@ -212,4 +212,22 @@ TEST_CASE(multibroadcast)
}
}
TEST_CASE
(
gather
)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}};
std
::
size_t
axis
=
1
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
6
,
4
,
5
}},
migraphx
::
op
::
gather
{
axis
},
input
,
indices
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}};
std
::
size_t
axis
=
4
;
throws_shape
(
migraphx
::
op
::
gather
{
axis
},
input
,
indices
);
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment