Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
832f28c6
Unverified
Commit
832f28c6
authored
Mar 02, 2022
by
turneram
Committed by
GitHub
Mar 02, 2022
Browse files
Add ScatterND operator (#1074)
Add onnx parser and ref and gpu implementations of ONNX op ScatterND
parent
bfedcd45
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
397 additions
and
1 deletion
+397
-1
test/py/onnx_backend_test.py
test/py/onnx_backend_test.py
+0
-1
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+309
-0
test/verify/test_scatternd.cpp
test/verify/test_scatternd.cpp
+30
-0
test/verify/test_scatternd_add.cpp
test/verify/test_scatternd_add.cpp
+30
-0
test/verify/test_scatternd_mul.cpp
test/verify/test_scatternd_mul.cpp
+28
-0
No files found.
test/py/onnx_backend_test.py
View file @
832f28c6
...
...
@@ -271,7 +271,6 @@ def create_backend_test(testname=None, target_device=None):
backend_test
.
exclude
(
r
'test_identity_sequence_cpu'
)
backend_test
.
exclude
(
r
'test_maxpool_2d_uint8_cpu'
)
backend_test
.
exclude
(
r
'test_negative_log_likelihood_loss_*'
)
backend_test
.
exclude
(
r
'test_scatternd_*'
)
# all reduce ops have dynamic axes inputs
backend_test
.
exclude
(
r
'test_size_cpu'
)
...
...
test/ref_ops_test.cpp
View file @
832f28c6
...
...
@@ -4255,6 +4255,315 @@ TEST_CASE(scatter_test)
}
}
TEST_CASE
(
scatternd_shapes_test
)
{
{
// broadcasted input
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
is
{
itype
,
{
4
,
1
}};
migraphx
::
shape
us
{
dtype
,
{
4
}};
std
::
vector
<
int64_t
>
ind_vec
{
4
,
3
,
1
,
7
};
std
::
vector
<
float
>
upd_vec
{
9
,
10
,
11
,
12
};
auto
data
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
8
}}}),
mm
->
add_literal
(
migraphx
::
literal
{
0.0
f
}));
auto
indices
=
mm
->
add_literal
(
migraphx
::
literal
{
is
,
ind_vec
});
auto
updates
=
mm
->
add_literal
(
migraphx
::
literal
{
us
,
upd_vec
});
auto
scatternd
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_none"
),
data
,
indices
,
updates
);
mm
->
add_return
({
scatternd
});
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
0
,
11
,
0
,
10
,
9
,
0
,
0
,
12
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
{
// non-standard shape input
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
2
,
2
}};
migraphx
::
shape
is
{
itype
,
{
2
,
2
}};
migraphx
::
shape
us
{
dtype
,
{
2
}};
std
::
vector
<
float
>
data_vec
{
1
,
2
,
3
,
4
};
std
::
vector
<
int64_t
>
ind_vec
{
0
,
0
,
0
,
1
};
std
::
vector
<
float
>
upd_vec
{
5
,
6
};
auto
data
=
mm
->
add_literal
(
migraphx
::
literal
{
ds
,
data_vec
});
auto
td
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
data
);
auto
indices
=
mm
->
add_literal
(
migraphx
::
literal
{
is
,
ind_vec
});
auto
updates
=
mm
->
add_literal
(
migraphx
::
literal
{
us
,
upd_vec
});
auto
scatternd
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_none"
),
td
,
indices
,
updates
);
mm
->
add_return
({
scatternd
});
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
5
,
6
,
2
,
4
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
{
// non-standard updates shape
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
2
,
2
,
2
}};
migraphx
::
shape
is
{
itype
,
{
2
,
1
,
3
}};
migraphx
::
shape
us
{
dtype
,
{
1
,
2
}};
std
::
vector
<
float
>
data_vec
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
};
std
::
vector
<
int64_t
>
ind_vec
{
0
,
0
,
0
,
1
,
1
,
1
};
std
::
vector
<
float
>
upd_vec
{
9
,
10
};
auto
data
=
mm
->
add_literal
(
migraphx
::
literal
{
ds
,
data_vec
});
auto
indices
=
mm
->
add_literal
(
migraphx
::
literal
{
is
,
ind_vec
});
auto
updates
=
mm
->
add_literal
(
migraphx
::
literal
{
us
,
upd_vec
});
auto
tu
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
updates
);
auto
scatternd
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_none"
),
data
,
indices
,
tu
);
mm
->
add_return
({
scatternd
});
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
9
,
2
,
3
,
4
,
5
,
6
,
7
,
10
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE
(
scatternd_test
)
{
{
// r=1, q=2, k=1
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
8
}};
migraphx
::
shape
is
{
itype
,
{
4
,
1
}};
migraphx
::
shape
us
{
dtype
,
{
4
}};
std
::
vector
<
float
>
data_vec
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
};
std
::
vector
<
int64_t
>
ind_vec
{
4
,
3
,
1
,
7
};
std
::
vector
<
float
>
upd_vec
{
9
,
10
,
11
,
12
};
auto
data
=
mm
->
add_literal
(
migraphx
::
literal
{
ds
,
data_vec
});
auto
indices
=
mm
->
add_literal
(
migraphx
::
literal
{
is
,
ind_vec
});
auto
updates
=
mm
->
add_literal
(
migraphx
::
literal
{
us
,
upd_vec
});
auto
scatternd
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_none"
),
data
,
indices
,
updates
);
mm
->
add_return
({
scatternd
});
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
1
,
11
,
3
,
10
,
9
,
6
,
7
,
12
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
{
// r=2, q=2, k=2
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
2
,
2
}};
migraphx
::
shape
is
{
itype
,
{
2
,
2
}};
migraphx
::
shape
us
{
dtype
,
{
2
}};
std
::
vector
<
float
>
data_vec
{
1
,
2
,
3
,
4
};
std
::
vector
<
int64_t
>
ind_vec
{
0
,
0
,
0
,
1
};
std
::
vector
<
float
>
upd_vec
{
5
,
6
};
auto
data
=
mm
->
add_literal
(
migraphx
::
literal
{
ds
,
data_vec
});
auto
indices
=
mm
->
add_literal
(
migraphx
::
literal
{
is
,
ind_vec
});
auto
updates
=
mm
->
add_literal
(
migraphx
::
literal
{
us
,
upd_vec
});
auto
scatternd
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_none"
),
data
,
indices
,
updates
);
mm
->
add_return
({
scatternd
});
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
5
,
6
,
3
,
4
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
{
// r=3, q=3, k=3
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
2
,
2
,
2
}};
migraphx
::
shape
is
{
itype
,
{
2
,
1
,
3
}};
migraphx
::
shape
us
{
dtype
,
{
2
,
1
}};
std
::
vector
<
float
>
data_vec
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
};
std
::
vector
<
int64_t
>
ind_vec
{
0
,
0
,
0
,
1
,
1
,
1
};
std
::
vector
<
float
>
upd_vec
{
9
,
10
};
auto
data
=
mm
->
add_literal
(
migraphx
::
literal
{
ds
,
data_vec
});
auto
indices
=
mm
->
add_literal
(
migraphx
::
literal
{
is
,
ind_vec
});
auto
updates
=
mm
->
add_literal
(
migraphx
::
literal
{
us
,
upd_vec
});
auto
scatternd
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_none"
),
data
,
indices
,
updates
);
mm
->
add_return
({
scatternd
});
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
9
,
2
,
3
,
4
,
5
,
6
,
7
,
10
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
{
// r=3, q=2, k=1
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
4
,
4
,
4
}};
migraphx
::
shape
is
{
itype
,
{
2
,
1
}};
migraphx
::
shape
us
{
dtype
,
{
2
,
4
,
4
}};
std
::
vector
<
float
>
data_vec
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
};
std
::
vector
<
int64_t
>
ind_vec
{
0
,
2
};
std
::
vector
<
float
>
upd_vec
{
5
,
5
,
5
,
5
,
6
,
6
,
6
,
6
,
7
,
7
,
7
,
7
,
8
,
8
,
8
,
8
,
1
,
1
,
1
,
1
,
2
,
2
,
2
,
2
,
3
,
3
,
3
,
3
,
4
,
4
,
4
,
4
};
auto
data
=
mm
->
add_literal
(
migraphx
::
literal
{
ds
,
data_vec
});
auto
indices
=
mm
->
add_literal
(
migraphx
::
literal
{
is
,
ind_vec
});
auto
updates
=
mm
->
add_literal
(
migraphx
::
literal
{
us
,
upd_vec
});
auto
scatternd
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_none"
),
data
,
indices
,
updates
);
mm
->
add_return
({
scatternd
});
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
5
,
5
,
5
,
5
,
6
,
6
,
6
,
6
,
7
,
7
,
7
,
7
,
8
,
8
,
8
,
8
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
1
,
1
,
1
,
1
,
2
,
2
,
2
,
2
,
3
,
3
,
3
,
3
,
4
,
4
,
4
,
4
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
{
// r=5, q=1, k=1
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
2
,
2
,
2
,
2
,
2
}};
migraphx
::
shape
is
{
itype
,
{
1
}};
migraphx
::
shape
us
{
dtype
,
{
2
,
2
,
2
,
2
}};
std
::
vector
<
float
>
data_vec
(
32
,
1
);
std
::
vector
<
int64_t
>
ind_vec
{
1
};
std
::
vector
<
float
>
upd_vec
(
16
,
0
);
auto
data
=
mm
->
add_literal
(
migraphx
::
literal
{
ds
,
data_vec
});
auto
indices
=
mm
->
add_literal
(
migraphx
::
literal
{
is
,
ind_vec
});
auto
updates
=
mm
->
add_literal
(
migraphx
::
literal
{
us
,
upd_vec
});
auto
scatternd
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_none"
),
data
,
indices
,
updates
);
mm
->
add_return
({
scatternd
});
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
(
32
,
0
);
std
::
copy
(
data_vec
.
begin
(),
data_vec
.
begin
()
+
16
,
gold
.
begin
());
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE
(
scatternd_reduction_test
)
{
{
// reduction = add
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
8
}};
migraphx
::
shape
is
{
itype
,
{
8
,
1
}};
migraphx
::
shape
us
{
dtype
,
{
8
}};
std
::
vector
<
float
>
data_vec
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
};
std
::
vector
<
int64_t
>
ind_vec
{
4
,
3
,
1
,
7
,
4
,
3
,
1
,
7
};
std
::
vector
<
float
>
upd_vec
{
9
,
10
,
11
,
12
,
-
8
,
-
9
,
-
10
,
-
11
};
auto
data
=
mm
->
add_literal
(
migraphx
::
literal
{
ds
,
data_vec
});
auto
indices
=
mm
->
add_literal
(
migraphx
::
literal
{
is
,
ind_vec
});
auto
updates
=
mm
->
add_literal
(
migraphx
::
literal
{
us
,
upd_vec
});
auto
scatternd
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_add"
),
data
,
indices
,
updates
);
mm
->
add_return
({
scatternd
});
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
1
,
3
,
3
,
5
,
6
,
6
,
7
,
9
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
{
// reduction = mul
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
8
}};
migraphx
::
shape
is
{
itype
,
{
4
,
1
}};
migraphx
::
shape
us
{
dtype
,
{
4
}};
std
::
vector
<
float
>
data_vec
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
};
std
::
vector
<
int64_t
>
ind_vec
{
4
,
3
,
1
,
7
};
std
::
vector
<
float
>
upd_vec
{
9
,
10
,
11
,
12
};
auto
data
=
mm
->
add_literal
(
migraphx
::
literal
{
ds
,
data_vec
});
auto
indices
=
mm
->
add_literal
(
migraphx
::
literal
{
is
,
ind_vec
});
auto
updates
=
mm
->
add_literal
(
migraphx
::
literal
{
us
,
upd_vec
});
auto
scatternd
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_mul"
),
data
,
indices
,
updates
);
mm
->
add_return
({
scatternd
});
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
1
,
22
,
3
,
40
,
45
,
6
,
7
,
96
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE
(
sigmoid_test
)
{
migraphx
::
program
p
;
...
...
test/verify/test_scatternd.cpp
0 → 100644
View file @
832f28c6
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_scatternd
:
verify_program
<
test_scatternd
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
1
}};
migraphx
::
shape
is
{
itype
,
{
4
,
1
}};
migraphx
::
shape
us
{
dtype
,
{
4
}};
std
::
vector
<
int64_t
>
ind_vec
{
4
,
3
,
1
,
7
};
auto
ld
=
mm
->
add_literal
(
migraphx
::
literal
{
ds
,
{
1
}});
auto
data
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
8
}}}),
ld
);
auto
indices
=
mm
->
add_literal
(
migraphx
::
literal
{
is
,
ind_vec
});
auto
updates
=
mm
->
add_parameter
(
"update"
,
us
);
auto
scatternd
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_none"
),
data
,
indices
,
updates
);
mm
->
add_return
({
scatternd
});
return
p
;
}
};
test/verify/test_scatternd_add.cpp
0 → 100644
View file @
832f28c6
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_scatternd_add
:
verify_program
<
test_scatternd_add
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
8
}};
migraphx
::
shape
is
{
itype
,
{
1
,
4
}};
migraphx
::
shape
us
{
dtype
,
{
4
}};
std
::
vector
<
int64_t
>
ind_vec
{
4
,
3
,
1
,
7
};
auto
data
=
mm
->
add_parameter
(
"data"
,
ds
);
auto
indices
=
mm
->
add_literal
(
migraphx
::
literal
{
is
,
ind_vec
});
auto
t_ind
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
indices
);
auto
updates
=
mm
->
add_parameter
(
"update"
,
us
);
auto
scatternd
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_add"
),
data
,
t_ind
,
updates
);
mm
->
add_return
({
scatternd
});
return
p
;
}
};
test/verify/test_scatternd_mul.cpp
0 → 100644
View file @
832f28c6
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_scatternd_mul
:
verify_program
<
test_scatternd_mul
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
8
}};
migraphx
::
shape
is
{
itype
,
{
4
,
1
}};
migraphx
::
shape
us
{
dtype
,
{
4
}};
std
::
vector
<
int64_t
>
ind_vec
{
4
,
3
,
1
,
7
};
auto
data
=
mm
->
add_parameter
(
"data"
,
ds
);
auto
indices
=
mm
->
add_literal
(
migraphx
::
literal
{
is
,
ind_vec
});
auto
updates
=
mm
->
add_parameter
(
"update"
,
us
);
auto
scatternd
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_mul"
),
data
,
indices
,
updates
);
mm
->
add_return
({
scatternd
});
return
p
;
}
};
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment