Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5d745540
Unverified
Commit
5d745540
authored
Feb 02, 2023
by
Brian Pickrell
Committed by
GitHub
Feb 03, 2023
Browse files
Dynamic shape support in scatterND ops (#1455)
* Implement dynamic shapes for scatterND operators.
parent
d478675c
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
322 additions
and
82 deletions
+322
-82
src/include/migraphx/op/scatternd_op.hpp
src/include/migraphx/op/scatternd_op.hpp
+66
-21
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+18
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+56
-42
test/onnx/scatternd_dyn_test.onnx
test/onnx/scatternd_dyn_test.onnx
+0
-0
test/op_shape_test.cpp
test/op_shape_test.cpp
+137
-19
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+45
-0
No files found.
src/include/migraphx/op/scatternd_op.hpp
View file @
5d745540
...
@@ -28,44 +28,89 @@
...
@@ -28,44 +28,89 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
/**
* @brief
* N-dimensional Scatter operations. This struct is parent class to ops which differ in what formula
* is used to reduce (combine old and new values of) the scattered value. It was originally based
* on Onnx ScatterND operation (see
* https://github.com/onnx/onnx/blob/main/docs/Operators.md#ScatterND) and is also similar to Numpy
* numpy.add.at().
*
* @tparam Derived a template parameter in the CRTP inheritance idiom, represents one of the child
* operations.
*/
template
<
class
Derived
>
template
<
class
Derived
>
struct
scatternd_op
:
op_name
<
Derived
>
struct
scatternd_op
:
op_name
<
Derived
>
{
{
/** Validate input shapes and return the correct output shape. For Scatter ops, the output
* is the same shape as the data tensor (first input), but cast to a standard shape.
*
*/
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
3
);
auto
r
=
inputs
.
front
().
lens
().
size
();
auto
data_shape
=
inputs
.
front
();
auto
q
=
inputs
.
at
(
1
).
lens
().
size
();
auto
index_shape
=
inputs
.
at
(
1
);
auto
k
=
inputs
.
at
(
1
).
lens
().
back
();
auto
upd_shape
=
inputs
.
back
();
auto
ind_lens
=
inputs
.
at
(
1
).
lens
();
auto
upd_lens
=
inputs
.
back
().
lens
();
auto
r
=
data_shape
.
ndim
();
auto
data_lens
=
inputs
.
front
().
lens
();
auto
q
=
index_shape
.
ndim
();
size_t
k
;
if
(
index_shape
.
dynamic
())
{
// the rank of the output is a function of k, so k must be fixed.
if
(
not
index_shape
.
dyn_dims
().
back
().
is_fixed
())
{
MIGRAPHX_THROW
(
"GATHERND: last dimension of indices tensor must be fixed (min=max)"
);
}
k
=
index_shape
.
dyn_dims
().
back
().
min
;
}
else
k
=
index_shape
.
lens
().
back
();
// Checks on the sizes of input tensors
if
(
q
+
r
!=
upd_shape
.
ndim
()
+
k
+
1
)
MIGRAPHX_THROW
(
"ScatterND: ranks of inputs don't match. "
+
std
::
to_string
(
q
)
+
" + "
+
std
::
to_string
(
r
)
+
" - "
+
std
::
to_string
(
k
)
+
" - 1 != "
+
std
::
to_string
(
upd_shape
.
ndim
()));
if
(
k
>
r
)
if
(
k
>
r
)
MIGRAPHX_THROW
(
"ScatterND: index of size "
+
std
::
to_string
(
k
)
+
MIGRAPHX_THROW
(
"ScatterND: index of size "
+
std
::
to_string
(
k
)
+
" is too large for tensor of rank "
+
std
::
to_string
(
r
));
" is too large for tensor of rank "
+
std
::
to_string
(
r
));
if
(
not
(
std
::
equal
(
ind_lens
.
begin
(),
ind_lens
.
begin
()
+
q
-
1
,
upd_lens
.
begin
())
and
std
::
equal
(
data_lens
.
begin
()
+
k
,
data_lens
.
end
(),
upd_lens
.
begin
()
+
q
-
1
)))
// Convert all static shape dimensions to dynamic so they can be compared.
MIGRAPHX_THROW
(
"ScatterND: incorrect update shape. update.lens != indices.lens[0:q-1] "
// It's possible for some of the 3 inputs to be dynamic shapes and some static,
"++ data.lens[k:r-1]"
);
// but any dynamic dimension that's compared to a static dimension must be fixed.
auto
s
=
inputs
.
front
();
auto
ind_dims
=
index_shape
.
to_dynamic
().
dyn_dims
();
if
(
s
.
broadcasted
())
auto
upd_dims
=
upd_shape
.
to_dynamic
().
dyn_dims
();
auto
data_dims
=
data_shape
.
to_dynamic
().
dyn_dims
();
// Check that corresponding portions of tensor shapes match.
if
(
not
(
std
::
equal
(
ind_dims
.
begin
(),
ind_dims
.
begin
()
+
q
-
1
,
upd_dims
.
begin
())
and
std
::
equal
(
data_dims
.
begin
()
+
k
,
data_dims
.
end
(),
upd_dims
.
begin
()
+
q
-
1
)))
MIGRAPHX_THROW
(
"ScatterND: incorrect update shape. Update dimensions must match "
"indices and data."
);
if
(
data_shape
.
dynamic
())
return
data_shape
;
else
if
(
data_shape
.
broadcasted
())
{
{
return
{
s
.
type
(),
s
.
lens
()};
return
{
data_shape
.
type
(),
data_shape
.
lens
()};
}
}
else
else
{
{
return
s
.
with_lens
(
s
.
lens
());
return
data_shape
.
with_lens
(
data_shape
.
lens
());
}
}
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
auto
&
self
=
static_cast
<
const
Derived
&>
(
*
this
);
auto
&
self
=
static_cast
<
const
Derived
&>
(
*
this
);
visit_all
(
result
,
args
[
0
],
args
[
2
])([
&
](
auto
output
,
auto
data
,
auto
updates
)
{
visit_all
(
result
,
args
[
0
],
args
[
2
])([
&
](
auto
output
,
auto
data
,
auto
updates
)
{
std
::
copy
(
data
.
begin
(),
data
.
end
(),
output
.
begin
());
std
::
copy
(
data
.
begin
(),
data
.
end
(),
output
.
begin
());
...
@@ -74,8 +119,8 @@ struct scatternd_op : op_name<Derived>
...
@@ -74,8 +119,8 @@ struct scatternd_op : op_name<Derived>
auto
updates_std
=
shape
{
updates_shape
.
type
(),
updates_shape
.
lens
()};
auto
updates_std
=
shape
{
updates_shape
.
type
(),
updates_shape
.
lens
()};
auto
indices_shape
=
indices
.
get_shape
();
auto
indices_shape
=
indices
.
get_shape
();
auto
k
=
indices_shape
.
lens
().
back
();
auto
k
=
indices_shape
.
lens
().
back
();
auto
q
=
indices_shape
.
lens
().
size
();
auto
q
=
indices_shape
.
ndim
();
auto
r
=
out
put_shape
.
lens
().
size
();
auto
r
=
dyn_out
.
com
put
ed
_shape
.
ndim
();
par_for
(
updates_shape
.
elements
(),
[
&
](
const
auto
i
)
{
par_for
(
updates_shape
.
elements
(),
[
&
](
const
auto
i
)
{
auto
updates_idx
=
updates_std
.
multi
(
i
);
auto
updates_idx
=
updates_std
.
multi
(
i
);
std
::
vector
<
std
::
size_t
>
indices_idx
(
q
,
0
);
std
::
vector
<
std
::
size_t
>
indices_idx
(
q
,
0
);
...
@@ -89,7 +134,7 @@ struct scatternd_op : op_name<Derived>
...
@@ -89,7 +134,7 @@ struct scatternd_op : op_name<Derived>
std
::
copy
(
index_start
,
index_end
,
out_idx
.
begin
());
std
::
copy
(
index_start
,
index_end
,
out_idx
.
begin
());
std
::
copy
(
updates_idx
.
begin
()
+
q
-
1
,
updates_idx
.
end
(),
out_idx
.
begin
()
+
k
);
std
::
copy
(
updates_idx
.
begin
()
+
q
-
1
,
updates_idx
.
end
(),
out_idx
.
begin
()
+
k
);
self
.
reduction
()(
output
[
out
put_shape
.
index
(
out_idx
)],
updates
[
i
]);
self
.
reduction
()(
output
[
dyn_out
.
com
put
ed
_shape
.
index
(
out_idx
)],
updates
[
i
]);
});
});
});
});
});
});
...
...
test/onnx/gen_onnx.py
View file @
5d745540
...
@@ -5968,6 +5968,24 @@ def scatternd_test():
...
@@ -5968,6 +5968,24 @@ def scatternd_test():
return
([
node
],
[
data
,
indices
,
updates
],
[
output
])
return
([
node
],
[
data
,
indices
,
updates
],
[
output
])
@
onnx_test
()
def
scatternd_dyn_test
():
data
=
helper
.
make_tensor_value_info
(
'data'
,
TensorProto
.
FLOAT
,
[
None
,
2
,
2
])
indices
=
helper
.
make_tensor_value_info
(
'indices'
,
TensorProto
.
INT64
,
[
None
,
1
,
2
])
updates
=
helper
.
make_tensor_value_info
(
'updates'
,
TensorProto
.
FLOAT
,
[
None
,
1
,
2
])
output
=
helper
.
make_tensor_value_info
(
'output'
,
TensorProto
.
FLOAT
,
[
None
,
2
,
2
])
node
=
onnx
.
helper
.
make_node
(
'ScatterND'
,
inputs
=
[
'data'
,
'indices'
,
'updates'
],
outputs
=
[
'output'
])
return
([
node
],
[
data
,
indices
,
updates
],
[
output
])
@
onnx_test
()
@
onnx_test
()
def
selu_test
():
def
selu_test
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
DOUBLE
,
[
2
,
3
])
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
DOUBLE
,
[
2
,
3
])
...
...
test/onnx/onnx_test.cpp
View file @
5d745540
...
@@ -5768,53 +5768,67 @@ TEST_CASE(scatter_none_test)
...
@@ -5768,53 +5768,67 @@ TEST_CASE(scatter_none_test)
TEST_CASE
(
scatternd_test
)
TEST_CASE
(
scatternd_test
)
{
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
l0
=
mm
->
add_parameter
(
"data"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
2
}});
auto
l0
=
auto
l1
=
mm
->
add_parameter
(
"indices"
,
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
2
,
1
,
2
}});
mm
->
add_parameter
(
"data"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
2
}});
auto
l2
=
mm
->
add_parameter
(
"updates"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
1
,
2
}});
auto
l1
=
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_none"
),
l0
,
l1
,
l2
);
mm
->
add_parameter
(
"indices"
,
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
2
,
1
,
2
}});
mm
->
add_return
({
r
});
auto
l2
=
auto
prog
=
migraphx
::
parse_onnx
(
"scatternd_test.onnx"
);
mm
->
add_parameter
(
"updates"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
1
,
2
}});
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_none"
),
l0
,
l1
,
l2
);
mm
->
add_return
({
r
});
auto
prog
=
migraphx
::
parse_onnx
(
"scatternd_test.onnx"
);
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
}
}
{
TEST_CASE
(
scatternd_dyn_test
)
migraphx
::
program
p
;
{
auto
*
mm
=
p
.
get_main_module
();
// dynamic input.
auto
l0
=
migraphx
::
program
p
;
mm
->
add_parameter
(
"data"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
2
}});
auto
*
mm
=
p
.
get_main_module
();
auto
l1
=
// parameters with dynamic dimensions
mm
->
add_parameter
(
"indices"
,
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
2
,
1
,
2
}});
auto
l0
=
mm
->
add_parameter
(
auto
l2
=
"data"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
3
,
2
},
{
2
,
2
},
{
2
,
2
}}});
mm
->
add_parameter
(
"updates"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
1
,
2
}});
auto
l1
=
mm
->
add_parameter
(
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_add"
),
l0
,
l1
,
l2
);
"indices"
,
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{{
2
,
1
,
2
},
{
1
,
1
},
{
2
,
2
}}});
mm
->
add_return
({
r
});
auto
l2
=
mm
->
add_parameter
(
auto
prog
=
migraphx
::
parse_onnx
(
"scatternd_add_test.onnx"
);
"updates"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
2
,
1
,
2
},
{
1
,
1
},
{
2
,
2
}}});
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_none"
),
l0
,
l1
,
l2
);
mm
->
add_return
({
r
});
migraphx
::
onnx_options
options
;
options
.
map_dyn_input_dims
[
"data"
]
=
{{
1
,
3
,
2
},
{
2
,
2
},
{
2
,
2
}};
options
.
map_dyn_input_dims
[
"indices"
]
=
{{
2
,
1
,
2
},
{
1
,
1
},
{
2
,
2
}};
options
.
map_dyn_input_dims
[
"updates"
]
=
{{
2
,
1
,
2
},
{
1
,
1
},
{
2
,
2
}};
auto
prog
=
migraphx
::
parse_onnx
(
"scatternd_dyn_test.onnx"
,
options
);
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
}
}
{
TEST_CASE
(
scatternd_add_test
)
migraphx
::
program
p
;
{
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
program
p
;
auto
l0
=
auto
*
mm
=
p
.
get_main_module
();
mm
->
add_parameter
(
"data"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
2
}});
auto
l0
=
mm
->
add_parameter
(
"data"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
2
}});
auto
l1
=
auto
l1
=
mm
->
add_parameter
(
"indices"
,
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
2
,
1
,
2
}});
mm
->
add_parameter
(
"indices"
,
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
2
,
1
,
2
}});
auto
l2
=
mm
->
add_parameter
(
"updates"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
1
,
2
}});
auto
l2
=
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_add"
),
l0
,
l1
,
l2
);
mm
->
add_parameter
(
"updates"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
1
,
2
}});
mm
->
add_return
({
r
});
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_mul"
),
l0
,
l1
,
l2
);
auto
prog
=
migraphx
::
parse_onnx
(
"scatternd_add_test.onnx"
);
mm
->
add_return
({
r
});
auto
prog
=
migraphx
::
parse_onnx
(
"scatternd_mul_test.onnx"
);
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
scatternd_mul_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
l0
=
mm
->
add_parameter
(
"data"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
2
}});
auto
l1
=
mm
->
add_parameter
(
"indices"
,
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
2
,
1
,
2
}});
auto
l2
=
mm
->
add_parameter
(
"updates"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
1
,
2
}});
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_mul"
),
l0
,
l1
,
l2
);
mm
->
add_return
({
r
});
auto
prog
=
migraphx
::
parse_onnx
(
"scatternd_mul_test.onnx"
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
selu_test
)
TEST_CASE
(
selu_test
)
...
...
test/onnx/scatternd_dyn_test.onnx
0 → 100644
View file @
5d745540
File added
test/op_shape_test.cpp
View file @
5d745540
...
@@ -2691,27 +2691,145 @@ TEST_CASE(test_gathernd_dynamic8)
...
@@ -2691,27 +2691,145 @@ TEST_CASE(test_gathernd_dynamic8)
expect_shape
(
s0
,
migraphx
::
make_op
(
"gathernd"
,
{{
"batch_dims"
,
batch_dims
}}),
ds
,
is
);
expect_shape
(
s0
,
migraphx
::
make_op
(
"gathernd"
,
{{
"batch_dims"
,
batch_dims
}}),
ds
,
is
);
}
}
TEST_CASE
(
test_scatternd
)
TEST_CASE
(
test_scatternd
0
)
{
{
{
// good
// k > r
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
8
}};
migraphx
::
shape
ds
{
dtype
,
{
8
}};
migraphx
::
shape
is
{
itype
,
{
4
,
1
}};
migraphx
::
shape
is
{
itype
,
{
4
,
2
}};
migraphx
::
shape
us
{
dtype
,
{
4
}};
migraphx
::
shape
us
{
dtype
,
{
4
}};
expect_shape
(
ds
,
migraphx
::
make_op
(
"scatternd_none"
),
ds
,
is
,
us
);
throws_shape
(
migraphx
::
make_op
(
"scatternd_none"
),
ds
,
is
,
us
);
}
}
{
TEST_CASE
(
test_scatternd1
)
// update.lens != indices.lens[0:q-1] ++ data.lens[k:r-1]
{
auto
dtype
=
migraphx
::
shape
::
float_type
;
// good, broadcasted
auto
itype
=
migraphx
::
shape
::
int64_type
;
auto
dtype
=
migraphx
::
shape
::
float_type
;
migraphx
::
shape
ds
{
dtype
,
{
8
}};
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
is
{
itype
,
{
4
,
1
}};
migraphx
::
shape
ds
{
dtype
,
{
8
}};
migraphx
::
shape
us
{
dtype
,
{
2
,
2
}};
migraphx
::
shape
is
{
itype
,
{
4
,
1
},
{
4
,
0
}};
throws_shape
(
migraphx
::
make_op
(
"scatternd_none"
),
ds
,
is
,
us
);
migraphx
::
shape
us
{
dtype
,
{
4
}};
}
expect_shape
(
ds
,
migraphx
::
make_op
(
"scatternd_none"
),
ds
,
is
,
us
);
}
TEST_CASE
(
test_scatternd2
)
{
// too many inputs
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
8
}};
migraphx
::
shape
is
{
itype
,
{
4
,
1
}};
migraphx
::
shape
us
{
dtype
,
{
4
}};
migraphx
::
shape
zs
{
dtype
,
{
4
}};
throws_shape
(
migraphx
::
make_op
(
"scatternd_none"
),
ds
,
is
,
us
,
zs
);
}
TEST_CASE
(
test_scatternd3
)
{
// q + r - k - 1 matches upd_lens.size(), but k > r
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
8
}};
migraphx
::
shape
is
{
itype
,
{
5
,
4
,
2
}};
migraphx
::
shape
us
{
dtype
,
{
4
}};
throws_shape
(
migraphx
::
make_op
(
"scatternd_none"
),
ds
,
is
,
us
);
}
TEST_CASE
(
test_scatternd4
)
{
// q + r - k - 1 != upd_lens.size()
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
8
}};
migraphx
::
shape
is
{
itype
,
{
4
,
1
}};
migraphx
::
shape
us
{
dtype
,
{
2
,
2
}};
throws_shape
(
migraphx
::
make_op
(
"scatternd_none"
),
ds
,
is
,
us
);
}
TEST_CASE
(
test_scatternd5
)
{
// dimensions don't match: update.lens != indices.lens[0:q-1]
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
8
,
3
}};
migraphx
::
shape
is
{
itype
,
{
4
,
1
}};
migraphx
::
shape
us
{
dtype
,
{
2
,
2
}};
throws_shape
(
migraphx
::
make_op
(
"scatternd_none"
),
ds
,
is
,
us
);
}
TEST_CASE
(
test_scatternd_dyn0
)
{
// one dynamic input, invalid index
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
4
}};
migraphx
::
shape
is
{
itype
,
{
4
,
13
}};
migraphx
::
shape
::
dynamic_dimension
dd
{
4
,
4
,
0
};
migraphx
::
shape
us
{
dtype
,
{
dd
}};
throws_shape
(
migraphx
::
make_op
(
"scatternd_none"
),
ds
,
is
,
us
);
}
TEST_CASE
(
test_scatternd_dyn1
)
{
// one dynamic input
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
8
}};
migraphx
::
shape
is
{
itype
,
{
4
,
1
}};
migraphx
::
shape
::
dynamic_dimension
dd
{
4
,
4
,
0
};
migraphx
::
shape
us
{
dtype
,
{
dd
}};
expect_shape
(
ds
,
migraphx
::
make_op
(
"scatternd_none"
),
ds
,
is
,
us
);
}
TEST_CASE
(
test_scatternd_dyn2
)
{
// one dynamic input and broadcasted data
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
2
,
3
,
1
,
4
},
{
0
,
1
,
1
,
0
}};
migraphx
::
shape
ds_std
{
dtype
,
{
2
,
3
,
1
,
4
}};
migraphx
::
shape
is
{
itype
,
{
4
,
4
}};
migraphx
::
shape
::
dynamic_dimension
dd
{
4
,
4
,
0
};
migraphx
::
shape
us
{
dtype
,
{
dd
}};
expect_shape
(
ds_std
,
migraphx
::
make_op
(
"scatternd_none"
),
ds
,
is
,
us
);
}
TEST_CASE
(
test_scatternd_dyn3
)
{
// one dynamic input and standard, static data
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
2
,
3
,
1
,
4
}};
migraphx
::
shape
is
{
itype
,
{
4
,
4
}};
migraphx
::
shape
::
dynamic_dimension
dd
{
4
,
4
,
0
};
migraphx
::
shape
us
{
dtype
,
{
dd
}};
expect_shape
(
ds
,
migraphx
::
make_op
(
"scatternd_none"
),
ds
,
is
,
us
);
}
TEST_CASE
(
test_scatternd_dyn4
)
{
// index is dynamic with last dimension not fixed
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
2
,
3
,
1
,
4
}};
migraphx
::
shape
::
dynamic_dimension
dd
{
4
,
5
,
0
};
migraphx
::
shape
is
{
itype
,
{
dd
,
dd
}};
migraphx
::
shape
us
{
dtype
,
{
dd
}};
throws_shape
(
migraphx
::
make_op
(
"scatternd_none"
),
ds
,
is
,
us
);
}
TEST_CASE
(
test_scatternd_dyn5
)
{
// dimensions don't match: update.lens != indices.lens[0:q-1]
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
2
,
3
,
1
,
4
}};
migraphx
::
shape
::
dynamic_dimension
dd
{
4
,
4
,
0
};
migraphx
::
shape
::
dynamic_dimension
dbad
{
2
,
3
,
0
};
migraphx
::
shape
is
{
itype
,
{
dd
,
dd
}};
migraphx
::
shape
us
{
dtype
,
{
dbad
}};
throws_shape
(
migraphx
::
make_op
(
"scatternd_none"
),
ds
,
is
,
us
);
}
}
TEST_CASE
(
test_squeeze
)
TEST_CASE
(
test_squeeze
)
...
...
test/ref_ops_test.cpp
View file @
5d745540
...
@@ -7242,6 +7242,51 @@ TEST_CASE(scatternd_reduction_test)
...
@@ -7242,6 +7242,51 @@ TEST_CASE(scatternd_reduction_test)
}
}
}
}
TEST_CASE(scatternd_reduction_dyn_test)
{
// reduction = add, with dynamic input shapes
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape::dynamic_dimension dd{3, 6, 0};
migraphx::shape ds{migraphx::shape::float_type, {dd, dd, dd}};
migraphx::shape is{itype, {2, 1}};
migraphx::shape us{dtype, {{2, 2, 0}, dd, dd}};
auto xdata = mm->add_parameter("X", ds);
auto xindex = mm->add_parameter("I", is);
auto xupdates = mm->add_parameter("U", us);
auto scatternd_add_op = migraphx::make_op("scatternd_add");
auto scatternd = mm->add_instruction(scatternd_add_op, xdata, xindex, xupdates);
mm->add_return({scatternd});
p.compile(migraphx::ref::target{});
migraphx::parameter_map params;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {4, 4, 4}}; // data
std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4,
5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
std::vector<uint64_t> input_index{0, 2};
migraphx::shape input_fixed_shape1{migraphx::shape::float_type, {2, 4, 4}}; // updates
std::vector<float> input_updates{5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4};
params["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
params["I"] = migraphx::argument(is, input_index.data());
params["U"] = migraphx::argument(input_fixed_shape1, input_updates.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{6, 7, 8, 9, 11, 12, 13, 14, 15, 14, 13, 12, 12, 11, 10, 9,
1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
9, 8, 7, 6, 6, 5, 4, 3, 4, 5, 6, 7, 9, 10, 11, 12,
8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(sigmoid_test)
TEST_CASE(sigmoid_test)
{
{
migraphx::program p;
migraphx::program p;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment