Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
832f28c6
Unverified
Commit
832f28c6
authored
Mar 02, 2022
by
turneram
Committed by
GitHub
Mar 02, 2022
Browse files
Add ScatterND operator (#1074)
Add onnx parser and ref and gpu implementations of ONNX op ScatterND
parent
bfedcd45
Changes
25
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
397 additions
and
1 deletion
+397
-1
test/py/onnx_backend_test.py
test/py/onnx_backend_test.py
+0
-1
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+309
-0
test/verify/test_scatternd.cpp
test/verify/test_scatternd.cpp
+30
-0
test/verify/test_scatternd_add.cpp
test/verify/test_scatternd_add.cpp
+30
-0
test/verify/test_scatternd_mul.cpp
test/verify/test_scatternd_mul.cpp
+28
-0
No files found.
test/py/onnx_backend_test.py
View file @
832f28c6
...
...
@@ -271,7 +271,6 @@ def create_backend_test(testname=None, target_device=None):
backend_test
.
exclude
(
r
'test_identity_sequence_cpu'
)
backend_test
.
exclude
(
r
'test_maxpool_2d_uint8_cpu'
)
backend_test
.
exclude
(
r
'test_negative_log_likelihood_loss_*'
)
backend_test
.
exclude
(
r
'test_scatternd_*'
)
# all reduce ops have dynamic axes inputs
backend_test
.
exclude
(
r
'test_size_cpu'
)
...
...
test/ref_ops_test.cpp
View file @
832f28c6
...
...
@@ -4255,6 +4255,315 @@ TEST_CASE(scatter_test)
}
}
TEST_CASE(scatternd_shapes_test)
{
{
// broadcasted input
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape is{itype, {4, 1}};
migraphx::shape us{dtype, {4}};
std::vector<int64_t> ind_vec{4, 3, 1, 7};
std::vector<float> upd_vec{9, 10, 11, 12};
auto data = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {8}}}),
mm->add_literal(migraphx::literal{0.0f}));
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0, 11, 0, 10, 9, 0, 0, 12};
EXPECT(migraphx::verify_range(results_vector, gold));
}
{
// non-standard shape input
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 2}};
migraphx::shape is{itype, {2, 2}};
migraphx::shape us{dtype, {2}};
std::vector<float> data_vec{1, 2, 3, 4};
std::vector<int64_t> ind_vec{0, 0, 0, 1};
std::vector<float> upd_vec{5, 6};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto td =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), data);
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_none"), td, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5, 6, 2, 4};
EXPECT(migraphx::verify_range(results_vector, gold));
}
{
// non-standard updates shape
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 2, 2}};
migraphx::shape is{itype, {2, 1, 3}};
migraphx::shape us{dtype, {1, 2}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{0, 0, 0, 1, 1, 1};
std::vector<float> upd_vec{9, 10};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto tu =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), updates);
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, tu);
mm->add_return({scatternd});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{9, 2, 3, 4, 5, 6, 7, 10};
EXPECT(migraphx::verify_range(results_vector, gold));
}
}
TEST_CASE(scatternd_test)
{
{
// r=1, q=2, k=1
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {8}};
migraphx::shape is{itype, {4, 1}};
migraphx::shape us{dtype, {4}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{4, 3, 1, 7};
std::vector<float> upd_vec{9, 10, 11, 12};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 11, 3, 10, 9, 6, 7, 12};
EXPECT(migraphx::verify_range(results_vector, gold));
}
{
// r=2, q=2, k=2
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 2}};
migraphx::shape is{itype, {2, 2}};
migraphx::shape us{dtype, {2}};
std::vector<float> data_vec{1, 2, 3, 4};
std::vector<int64_t> ind_vec{0, 0, 0, 1};
std::vector<float> upd_vec{5, 6};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5, 6, 3, 4};
EXPECT(migraphx::verify_range(results_vector, gold));
}
{
// r=3, q=3, k=3
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 2, 2}};
migraphx::shape is{itype, {2, 1, 3}};
migraphx::shape us{dtype, {2, 1}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{0, 0, 0, 1, 1, 1};
std::vector<float> upd_vec{9, 10};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{9, 2, 3, 4, 5, 6, 7, 10};
EXPECT(migraphx::verify_range(results_vector, gold));
}
{
// r=3, q=2, k=1
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {4, 4, 4}};
migraphx::shape is{itype, {2, 1}};
migraphx::shape us{dtype, {2, 4, 4}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8,
8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{0, 2};
std::vector<float> upd_vec{5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3,
4, 4, 4, 4, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
EXPECT(migraphx::verify_range(results_vector, gold));
}
{
// r=5, q=1, k=1
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 2, 2, 2, 2}};
migraphx::shape is{itype, {1}};
migraphx::shape us{dtype, {2, 2, 2, 2}};
std::vector<float> data_vec(32, 1);
std::vector<int64_t> ind_vec{1};
std::vector<float> upd_vec(16, 0);
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold(32, 0);
std::copy(data_vec.begin(), data_vec.begin() + 16, gold.begin());
EXPECT(migraphx::verify_range(results_vector, gold));
}
}
TEST_CASE(scatternd_reduction_test)
{
{
// reduction = add
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {8}};
migraphx::shape is{itype, {8, 1}};
migraphx::shape us{dtype, {8}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{4, 3, 1, 7, 4, 3, 1, 7};
std::vector<float> upd_vec{9, 10, 11, 12, -8, -9, -10, -11};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_add"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 3, 3, 5, 6, 6, 7, 9};
EXPECT(migraphx::verify_range(results_vector, gold));
}
{
// reduction = mul
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {8}};
migraphx::shape is{itype, {4, 1}};
migraphx::shape us{dtype, {4}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{4, 3, 1, 7};
std::vector<float> upd_vec{9, 10, 11, 12};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_mul"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 22, 3, 40, 45, 6, 7, 96};
EXPECT(migraphx::verify_range(results_vector, gold));
}
}
TEST_CASE(sigmoid_test)
{
migraphx::program p;
...
...
test/verify/test_scatternd.cpp
0 → 100644
View file @
832f28c6
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_scatternd
:
verify_program
<
test_scatternd
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
1
}};
migraphx
::
shape
is
{
itype
,
{
4
,
1
}};
migraphx
::
shape
us
{
dtype
,
{
4
}};
std
::
vector
<
int64_t
>
ind_vec
{
4
,
3
,
1
,
7
};
auto
ld
=
mm
->
add_literal
(
migraphx
::
literal
{
ds
,
{
1
}});
auto
data
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
8
}}}),
ld
);
auto
indices
=
mm
->
add_literal
(
migraphx
::
literal
{
is
,
ind_vec
});
auto
updates
=
mm
->
add_parameter
(
"update"
,
us
);
auto
scatternd
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_none"
),
data
,
indices
,
updates
);
mm
->
add_return
({
scatternd
});
return
p
;
}
};
test/verify/test_scatternd_add.cpp
0 → 100644
View file @
832f28c6
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_scatternd_add
:
verify_program
<
test_scatternd_add
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
8
}};
migraphx
::
shape
is
{
itype
,
{
1
,
4
}};
migraphx
::
shape
us
{
dtype
,
{
4
}};
std
::
vector
<
int64_t
>
ind_vec
{
4
,
3
,
1
,
7
};
auto
data
=
mm
->
add_parameter
(
"data"
,
ds
);
auto
indices
=
mm
->
add_literal
(
migraphx
::
literal
{
is
,
ind_vec
});
auto
t_ind
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
indices
);
auto
updates
=
mm
->
add_parameter
(
"update"
,
us
);
auto
scatternd
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_add"
),
data
,
t_ind
,
updates
);
mm
->
add_return
({
scatternd
});
return
p
;
}
};
test/verify/test_scatternd_mul.cpp
0 → 100644
View file @
832f28c6
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_scatternd_mul
:
verify_program
<
test_scatternd_mul
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
8
}};
migraphx
::
shape
is
{
itype
,
{
4
,
1
}};
migraphx
::
shape
us
{
dtype
,
{
4
}};
std
::
vector
<
int64_t
>
ind_vec
{
4
,
3
,
1
,
7
};
auto
data
=
mm
->
add_parameter
(
"data"
,
ds
);
auto
indices
=
mm
->
add_literal
(
migraphx
::
literal
{
is
,
ind_vec
});
auto
updates
=
mm
->
add_parameter
(
"update"
,
us
);
auto
scatternd
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_mul"
),
data
,
indices
,
updates
);
mm
->
add_return
({
scatternd
});
return
p
;
}
};
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment