Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d478675c
Unverified
Commit
d478675c
authored
Feb 02, 2023
by
Brian Pickrell
Committed by
GitHub
Feb 02, 2023
Browse files
Dynamic gathernd (#1480)
Dynamic shape support for gathernd op.
parent
dee20c6c
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
516 additions
and
18 deletions
+516
-18
src/include/migraphx/op/gathernd.hpp
src/include/migraphx/op/gathernd.hpp
+88
-17
src/program.cpp
src/program.cpp
+2
-1
test/onnx/gathernd_dyn_test.onnx
test/onnx/gathernd_dyn_test.onnx
+0
-0
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+13
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+18
-0
test/op_shape_test.cpp
test/op_shape_test.cpp
+214
-0
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+181
-0
No files found.
src/include/migraphx/op/gathernd.hpp
View file @
d478675c
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
...
@@ -47,33 +48,103 @@ struct gathernd
...
@@ -47,33 +48,103 @@ struct gathernd
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
2
);
auto
r
=
inputs
.
front
().
lens
().
size
();
auto
i_shape
=
inputs
.
back
();
auto
q
=
inputs
.
back
().
lens
().
size
();
auto
data_shape
=
inputs
.
front
();
auto
k
=
inputs
.
back
().
lens
().
back
();
auto
r
=
data_shape
.
ndim
();
auto
q
=
i_shape
.
ndim
();
size_t
k
;
if
(
i_shape
.
dynamic
())
{
// the rank of the output is a function of k, so it must be fixed.
if
(
not
i_shape
.
dyn_dims
().
back
().
is_fixed
())
{
MIGRAPHX_THROW
(
"GATHERND: last dimension of indices tensor must be fixed (min=max)"
);
}
k
=
i_shape
.
dyn_dims
().
back
().
min
;
}
else
k
=
i_shape
.
lens
().
back
();
// Begin input validation checks.
int
output_ndim
=
int
(
q
)
+
r
-
k
-
batch_dims
-
1
;
if
(
k
>
r
-
batch_dims
)
if
(
k
>
r
-
batch_dims
)
{
{
MIGRAPHX_THROW
(
"GATHERND: Indices of length "
+
std
::
to_string
(
k
)
+
MIGRAPHX_THROW
(
"GATHERND: Indices of length "
+
std
::
to_string
(
k
)
+
" cannot be used to access data of rank "
+
" cannot be used to access data of rank "
+
std
::
to_string
(
r
-
batch_dims
));
std
::
to_string
(
r
-
batch_dims
));
}
}
auto
indices_lens_iter
=
inputs
.
back
().
lens
().
begin
();
auto
output_lens_size
=
q
+
r
-
k
-
batch_dims
-
1
;
if
(
batch_dims
>=
q
or
batch_dims
>=
r
)
std
::
vector
<
std
::
size_t
>
output_lens
(
output_lens_size
);
{
MIGRAPHX_THROW
(
"GATHERND: rank of an input cannot be less than batch_dims="
+
std
::
to_string
(
batch_dims
));
}
if
(
output_ndim
<
0
)
{
MIGRAPHX_THROW
(
"GATHERND: Indices too large for static data input: k="
+
std
::
to_string
(
k
));
}
if
(
migraphx
::
none_of
(
inputs
,
[](
auto
v
)
{
return
v
.
dynamic
();
}))
{
auto
indices_lens_iter
=
i_shape
.
lens
().
begin
();
// A rank 0 output is a scalar
if
(
output_ndim
==
0
)
return
shape
{
data_shape
.
type
(),
{
1
}};
// Part of the output shape comes from indices tensor, part from data tensor
std
::
vector
<
std
::
size_t
>
output_lens
(
output_ndim
);
std
::
copy
(
indices_lens_iter
,
indices_lens_iter
+
(
q
-
1
),
output_lens
.
begin
());
std
::
copy
(
indices_lens_iter
,
indices_lens_iter
+
(
q
-
1
),
output_lens
.
begin
());
if
(
k
<
r
-
batch_dims
)
// fill the rest of output shape from data tensor
if
(
k
+
batch_dims
<
r
)
{
{
auto
data_lens
=
inputs
.
front
().
lens
();
auto
data_lens
=
data_shape
.
lens
();
std
::
copy
(
std
::
copy
(
data_lens
.
begin
()
+
batch_dims
+
k
,
data_lens
.
begin
()
+
batch_dims
+
k
,
data_lens
.
end
(),
output_lens
.
begin
()
+
q
-
1
);
data_lens
.
end
(),
output_lens
.
begin
()
+
q
-
1
);
}
}
shape
output_shape
{
inputs
.
front
()
.
type
(),
output_lens
};
shape
output_shape
{
data_shape
.
type
(),
output_lens
};
return
output_shape
;
return
output_shape
;
}
}
else
{
// If one or both inputs are dynamic shapes, the output is dynamic.
// Make both inputs dynamic to simplify computations.
data_shape
=
data_shape
.
to_dynamic
();
i_shape
=
i_shape
.
to_dynamic
();
// A rank 0 output is a scalar
if
(
output_ndim
==
0
)
return
shape
(
data_shape
.
type
(),
{
shape
::
dynamic_dimension
({
1
,
1
,
0
})});
// Part of the output shape comes from indices tensor, part from data tensor
std
::
vector
<
shape
::
dynamic_dimension
>
output_dims
(
output_ndim
);
std
::
copy
(
i_shape
.
dyn_dims
().
begin
(),
i_shape
.
dyn_dims
().
begin
()
+
q
-
1
,
output_dims
.
begin
());
// fill the rest of output shape from data tensor
if
(
k
+
batch_dims
<
r
)
{
auto
data_dims
=
data_shape
.
dyn_dims
();
std
::
copy
(
data_dims
.
begin
()
+
batch_dims
+
k
,
data_dims
.
begin
()
+
r
,
output_dims
.
begin
()
+
q
-
1
);
}
shape
output_shape
(
data_shape
.
type
(),
output_dims
);
return
output_shape
;
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
data
)
{
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
data
)
{
args
[
1
].
visit
([
&
](
auto
indices
)
{
args
[
1
].
visit
([
&
](
auto
indices
)
{
auto
indices_shape
=
indices
.
get_shape
();
auto
indices_shape
=
indices
.
get_shape
();
...
...
src/program.cpp
View file @
d478675c
...
@@ -336,7 +336,8 @@ std::vector<argument> generic_eval(const module* mod,
...
@@ -336,7 +336,8 @@ std::vector<argument> generic_eval(const module* mod,
if
(
not
ins
->
get_shape
().
dynamic
()
and
param
.
get_shape
()
!=
ins
->
get_shape
())
if
(
not
ins
->
get_shape
().
dynamic
()
and
param
.
get_shape
()
!=
ins
->
get_shape
())
{
{
MIGRAPHX_THROW
(
"Incorrect shape {"
+
to_string
(
param
.
get_shape
())
+
MIGRAPHX_THROW
(
"Incorrect shape {"
+
to_string
(
param
.
get_shape
())
+
"} for parameter: "
+
param_name
);
"} for parameter: "
+
param_name
+
" should be: "
+
to_string
(
ins
->
get_shape
()));
}
}
return
param
;
return
param
;
}));
}));
...
...
test/onnx/gathernd_dyn_test.onnx
0 → 100644
View file @
d478675c
File added
test/onnx/gen_onnx.py
View file @
d478675c
...
@@ -2132,6 +2132,19 @@ def gathernd_test():
...
@@ -2132,6 +2132,19 @@ def gathernd_test():
return
([
node
],
[
x
,
i
],
[
y
])
return
([
node
],
[
x
,
i
],
[
y
])
@
onnx_test
()
def
gathernd_dyn_test
():
x
=
helper
.
make_tensor_value_info
(
'data'
,
TensorProto
.
FLOAT
,
[
None
,
2
])
i
=
helper
.
make_tensor_value_info
(
'indices'
,
TensorProto
.
INT64
,
[
2
,
2
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
2
])
node
=
onnx
.
helper
.
make_node
(
'GatherND'
,
inputs
=
[
'data'
,
'indices'
],
outputs
=
[
'y'
])
return
([
node
],
[
x
,
i
],
[
y
])
@
onnx_test
()
@
onnx_test
()
def
gathernd_batch_dims_test
():
def
gathernd_batch_dims_test
():
x
=
helper
.
make_tensor_value_info
(
'data'
,
TensorProto
.
FLOAT
,
[
2
,
2
,
2
])
x
=
helper
.
make_tensor_value_info
(
'data'
,
TensorProto
.
FLOAT
,
[
2
,
2
,
2
])
...
...
test/onnx/onnx_test.cpp
View file @
d478675c
...
@@ -2158,6 +2158,24 @@ TEST_CASE(gathernd_test)
...
@@ -2158,6 +2158,24 @@ TEST_CASE(gathernd_test)
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
gathernd_dyn_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
l0
=
mm
->
add_parameter
(
"data"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
2
,
4
,
2
},
{
2
,
4
}}});
auto
l1
=
mm
->
add_parameter
(
"indices"
,
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{{
1
,
3
},
{
2
,
2
}}});
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"gathernd"
),
l0
,
l1
);
mm
->
add_return
({
r
});
migraphx
::
onnx_options
options
;
options
.
map_dyn_input_dims
[
"data"
]
=
{{
2
,
4
,
2
},
{
2
,
4
}};
options
.
map_dyn_input_dims
[
"indices"
]
=
{{
1
,
3
},
{
2
,
2
}};
auto
prog
=
migraphx
::
parse_onnx
(
"gathernd_dyn_test.onnx"
,
options
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
gathernd_batch_dims_test
)
TEST_CASE
(
gathernd_batch_dims_test
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
...
test/op_shape_test.cpp
View file @
d478675c
...
@@ -2477,6 +2477,220 @@ TEST_CASE(test_scalar_nelemnts)
...
@@ -2477,6 +2477,220 @@ TEST_CASE(test_scalar_nelemnts)
throws_shape
(
migraphx
::
make_op
(
"scalar"
,
{{
"scalar_bcst_dims"
,
{
2
,
3
,
4
,
5
}}}),
input
);
throws_shape
(
migraphx
::
make_op
(
"scalar"
,
{{
"scalar_bcst_dims"
,
{
2
,
3
,
4
,
5
}}}),
input
);
}
}
TEST_CASE
(
test_gathernd
)
{
{
// k > r
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
is
{
itype
,
{
2
,
4
}};
migraphx
::
shape
ds
{
dtype
,
{
8
}};
int
batch_dims
(
1
);
throws_shape
(
migraphx
::
make_op
(
"gathernd"
,
{{
"batch_dims"
,
batch_dims
}}),
ds
,
is
);
}
{
// k > r - batch_dims
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
is
{
itype
,
{
2
,
4
}};
migraphx
::
shape
ds
{
dtype
,
{
2
}};
int
batch_dims
(
1
);
throws_shape
(
migraphx
::
make_op
(
"gathernd"
,
{{
"batch_dims"
,
batch_dims
}}),
ds
,
is
);
}
{
// batch_dims >= r
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
is
{
itype
,
{
2
,
1
}};
migraphx
::
shape
ds
{
dtype
,
{
2
,
5
,
6
,
7
}};
int
batch_dims
(
3
);
throws_shape
(
migraphx
::
make_op
(
"gathernd"
,
{{
"batch_dims"
,
batch_dims
}}),
ds
,
is
);
}
{
// int(q) + r - k - batch_dims - 1 = 0 => returns a scalar
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
is
{
itype
,
{
1
}};
migraphx
::
shape
ds
{
dtype
,
{
2
}};
migraphx
::
shape
s0
{
dtype
,
{
1
}};
expect_shape
(
s0
,
migraphx
::
make_op
(
"gathernd"
),
ds
,
is
);
}
{
// See Example 4 at https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
is
{
itype
,
{
2
,
2
}};
migraphx
::
shape
ds
{
dtype
,
{
2
,
2
}};
migraphx
::
shape
s0
{
dtype
,
{
2
}};
expect_shape
(
s0
,
migraphx
::
make_op
(
"gathernd"
),
ds
,
is
);
}
{
// See Example 5 at https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
is
{
itype
,
{
2
,
1
}};
migraphx
::
shape
ds
{
dtype
,
{
2
,
2
,
2
}};
int
batch_dims
(
1
);
migraphx
::
shape
s0
{
dtype
,
{
2
,
2
}};
expect_shape
(
s0
,
migraphx
::
make_op
(
"gathernd"
,
{{
"batch_dims"
,
batch_dims
}}),
ds
,
is
);
}
}
TEST_CASE
(
test_gathernd_dynamic0
)
{
// k > r
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
is
{
itype
,
{
2
,
4
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
8
,
8
,
0
}};
migraphx
::
shape
ds
{
dtype
,
b
};
int
batch_dims
(
1
);
throws_shape
(
migraphx
::
make_op
(
"gathernd"
,
{{
"batch_dims"
,
batch_dims
}}),
ds
,
is
);
}
TEST_CASE
(
test_gathernd_dynamic1
)
{
// k > r - batch_dims
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
is
{
itype
,
{
2
,
4
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
2
,
2
,
0
}};
migraphx
::
shape
ds
{
dtype
,
b
};
int
batch_dims
(
1
);
throws_shape
(
migraphx
::
make_op
(
"gathernd"
,
{{
"batch_dims"
,
batch_dims
}}),
ds
,
is
);
}
TEST_CASE
(
test_gathernd_dynamic2
)
{
// batch_dims >= r
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
is
{
itype
,
{
2
,
1
}};
migraphx
::
shape
ds
{
dtype
,
{{
2
,
3
,
3
},
{
5
,
6
,
5
},
{
6
,
9
,
7
},
{
7
,
8
,
8
}}};
int
batch_dims
(
3
);
throws_shape
(
migraphx
::
make_op
(
"gathernd"
,
{{
"batch_dims"
,
batch_dims
}}),
ds
,
is
);
}
TEST_CASE
(
test_gathernd_dynamic3
)
{
// int(q) + r - k - batch_dims - 1 = 0 => returns a scalar
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
is
{
itype
,
{
1
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
2
,
2
,
0
}};
migraphx
::
shape
ds
{
dtype
,
b
};
migraphx
::
shape
::
dynamic_dimension
ddout
{
1
,
1
,
0
};
migraphx
::
shape
s0
{
dtype
,
{
ddout
}};
expect_shape
(
s0
,
migraphx
::
make_op
(
"gathernd"
),
ds
,
is
);
}
TEST_CASE
(
test_gathernd_dynamic4
)
{
// See Example 1 at https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
is
{
itype
,
{
2
,
2
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
2
,
2
,
0
},
{
2
,
2
,
0
}};
migraphx
::
shape
ds
{
dtype
,
b
};
migraphx
::
shape
::
dynamic_dimension
ddout
{
2
,
2
,
0
};
migraphx
::
shape
s0
{
dtype
,
{
ddout
}};
expect_shape
(
s0
,
migraphx
::
make_op
(
"gathernd"
),
ds
,
is
);
}
TEST_CASE
(
test_gathernd_dynamic5
)
{
// See Example 5 at https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND
// index static shape, data dynamic
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
is
{
itype
,
{
2
,
1
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
2
,
2
,
0
},
{
2
,
2
,
0
},
{
2
,
2
,
0
}};
migraphx
::
shape
ds
{
dtype
,
b
};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
ddout
{{
2
,
2
,
0
},
{
2
,
2
,
0
}};
int
batch_dims
(
1
);
migraphx
::
shape
s0
{
dtype
,
{
ddout
}};
expect_shape
(
s0
,
migraphx
::
make_op
(
"gathernd"
,
{{
"batch_dims"
,
batch_dims
}}),
ds
,
is
);
}
TEST_CASE
(
test_gathernd_dynamic6
)
{
// See Example 5 at https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND
// index dynamic shape, data static
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
2
,
3
,
0
},
{
1
,
1
,
0
}};
migraphx
::
shape
is
{
itype
,
b
};
migraphx
::
shape
ds
{
dtype
,
{
2
,
2
,
2
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
ddout
{{
2
,
3
,
0
},
{
2
,
2
,
0
}};
int
batch_dims
(
1
);
migraphx
::
shape
s0
{
dtype
,
{
ddout
}};
expect_shape
(
s0
,
migraphx
::
make_op
(
"gathernd"
,
{{
"batch_dims"
,
batch_dims
}}),
ds
,
is
);
}
TEST_CASE
(
test_gathernd_dynamic6a
)
{
// indices with non-fixed dynamic dimension k
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
2
,
2
,
0
},
{
1
,
3
,
0
}};
migraphx
::
shape
is
{
itype
,
b
};
migraphx
::
shape
ds
{
dtype
,
{
2
,
2
,
2
}};
int
batch_dims
(
1
);
throws_shape
(
migraphx
::
make_op
(
"gathernd"
,
{{
"batch_dims"
,
batch_dims
}}),
ds
,
is
);
}
TEST_CASE
(
test_gathernd_dynamic7
)
{
// See Example 5 at https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND
// index and data both dynamic shapes
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
idyn
{{
2
,
5
,
0
},
{
1
,
1
,
0
}};
migraphx
::
shape
is
{
itype
,
idyn
};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
bdyn
{{
1
,
2
,
0
},
{
1
,
2
,
0
},
{
1
,
2
,
0
}};
migraphx
::
shape
ds
{
dtype
,
bdyn
};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
ddout
{{
2
,
5
,
0
},
{
1
,
2
,
0
}};
int
batch_dims
(
1
);
migraphx
::
shape
s0
{
dtype
,
{
ddout
}};
expect_shape
(
s0
,
migraphx
::
make_op
(
"gathernd"
,
{{
"batch_dims"
,
batch_dims
}}),
ds
,
is
);
}
TEST_CASE
(
test_gathernd_dynamic8
)
{
// Same shapes as ref_ops_test gathernd_dynamic
// index static shape, data dynamic
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
is
{
itype
,
{
2
,
5
,
1
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
6
,
7
,
7
},
{
3
,
3
,
0
},
{
1
,
4
,
0
}};
migraphx
::
shape
ds
{
dtype
,
b
};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
ddout
{{
2
,
2
,
0
},
{
5
,
5
,
0
},
{
1
,
4
,
0
}};
int
batch_dims
(
1
);
migraphx
::
shape
s0
{
dtype
,
{
ddout
}};
expect_shape
(
s0
,
migraphx
::
make_op
(
"gathernd"
,
{{
"batch_dims"
,
batch_dims
}}),
ds
,
is
);
}
TEST_CASE
(
test_scatternd
)
TEST_CASE
(
test_scatternd
)
{
{
{
{
...
...
test/ref_ops_test.cpp
View file @
d478675c
...
@@ -2746,6 +2746,187 @@ TEST_CASE(gathernd_test)
...
@@ -2746,6 +2746,187 @@ TEST_CASE(gathernd_test)
}
}
}
}
TEST_CASE(gathernd_dynamic0)
{
// dynamic data, all dimensions fixed
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {{2, 2, 2}, {3, 3, 0}, {1, 1, 0}}};
migraphx::shape is{migraphx::shape::int64_type, {2, 2, 1}};
auto xdata = mm->add_parameter("X", ds);
auto xindex = mm->add_parameter("I", is);
auto gathernd_op = migraphx::make_op("gathernd");
auto gathernd = mm->add_instruction(gathernd_op, xdata, xindex);
mm->add_return({gathernd});
p.compile(migraphx::ref::target{});
migraphx::parameter_map params;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {2, 3, 1}}; // data
migraphx::shape input_fixed_shape1{migraphx::shape::int64_type, {2, 2, 1}}; // index
std::vector<float> data_vec(2 * 3 * 1);
std::iota(data_vec.begin(), data_vec.end(), 0);
std::vector<int64_t> indices_vec{1, 0, 0, 1};
params["X"] = migraphx::argument(input_fixed_shape0, data_vec.data());
params["I"] = migraphx::argument(input_fixed_shape1, indices_vec.data());
auto result = p.eval(params).back();
std::vector<float> res_data{};
std::vector<float> gold{3, 4, 5, 0, 1, 2, 0, 1, 2, 3, 4, 5};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, gold));
}
TEST_CASE(gathernd_dynamic1)
{
// dynamic data, dims not fixed
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {{2, 5, 2}, {1, 5, 0}, {1, 5, 0}}};
migraphx::shape is{migraphx::shape::int64_type, {2, 2, 1}};
auto xdata = mm->add_parameter("X", ds);
auto xindex = mm->add_parameter("I", is);
auto gathernd_op = migraphx::make_op("gathernd");
auto gathernd = mm->add_instruction(gathernd_op, xdata, xindex);
mm->add_return({gathernd});
p.compile(migraphx::ref::target{});
migraphx::parameter_map params;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {2, 3, 1}}; // data
migraphx::shape input_fixed_shape1{migraphx::shape::int64_type, {2, 2, 1}}; // index
std::vector<float> data_vec(2 * 3 * 1);
std::iota(data_vec.begin(), data_vec.end(), 0);
std::vector<int64_t> indices_vec{1, 0, 0, 1};
params["X"] = migraphx::argument(input_fixed_shape0, data_vec.data());
params["I"] = migraphx::argument(input_fixed_shape1, indices_vec.data());
auto result = p.eval(params).back();
std::vector<float> res_data{};
std::vector<float> gold{3, 4, 5, 0, 1, 2, 0, 1, 2, 3, 4, 5};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, gold));
}
TEST_CASE(gathernd_dynamic2)
{
// dynamic both index and data
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {{2, 5, 2}, {1, 5, 0}, {1, 5, 0}}};
migraphx::shape is{migraphx::shape::int64_type, {{2, 5, 3}, {2, 3, 3}, {1, 1}}};
auto xdata = mm->add_parameter("X", ds);
auto xindex = mm->add_parameter("I", is);
auto gathernd_op = migraphx::make_op("gathernd");
auto gathernd = mm->add_instruction(gathernd_op, xdata, xindex);
mm->add_return({gathernd});
p.compile(migraphx::ref::target{});
migraphx::parameter_map params;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {2, 3, 1}}; // data
migraphx::shape input_fixed_shape1{migraphx::shape::int64_type, {2, 2, 1}}; // index
std::vector<float> data_vec(2 * 3 * 1);
std::iota(data_vec.begin(), data_vec.end(), 0);
std::vector<int64_t> indices_vec{1, 0, 0, 1};
params["X"] = migraphx::argument(input_fixed_shape0, data_vec.data());
params["I"] = migraphx::argument(input_fixed_shape1, indices_vec.data());
auto result = p.eval(params).back();
std::vector<float> res_data{};
std::vector<float> gold{3, 4, 5, 0, 1, 2, 0, 1, 2, 3, 4, 5};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, gold));
}
TEST_CASE(gathernd_dynamic3)
{
// dynamic index, static data and a batch_dims input
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 3, 1}};
migraphx::shape is{migraphx::shape::int64_type, {{2, 5, 3}, {2, 3, 3}, {1, 1}}};
auto xdata = mm->add_parameter("X", ds);
auto xindex = mm->add_parameter("I", is);
int batch_dims{1};
auto gathernd_op = migraphx::make_op("gathernd", {{"batch_dims", batch_dims}});
auto gathernd = mm->add_instruction(gathernd_op, xdata, xindex);
mm->add_return({gathernd});
p.compile(migraphx::ref::target{});
migraphx::parameter_map params;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {2, 3, 1}}; // data
migraphx::shape input_fixed_shape1{migraphx::shape::int64_type, {2, 2, 1}}; // index
std::vector<float> data_vec(2 * 3 * 1);
std::iota(data_vec.begin(), data_vec.end(), 0);
std::vector<int64_t> indices_vec{1, 0, 0, 1};
params["X"] = migraphx::argument(input_fixed_shape0, data_vec.data());
params["I"] = migraphx::argument(input_fixed_shape1, indices_vec.data());
auto result = p.eval(params).back();
std::vector<float> res_data{};
std::vector<float> gold{1, 0, 3, 4};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, gold));
}
TEST_CASE(gathernd_dynamic4)
{
// int(q) + r - k - batch_dims - 1 = 0 => returns a scalar
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type,
{migraphx::shape::dynamic_dimension({2, 2, 0})}};
migraphx::shape is{migraphx::shape::int64_type, {1}};
auto xdata = mm->add_parameter("X", ds);
auto xindex = mm->add_parameter("I", is);
auto gathernd_op = migraphx::make_op("gathernd");
auto gathernd = mm->add_instruction(gathernd_op, xdata, xindex);
mm->add_return({gathernd});
p.compile(migraphx::ref::target{});
migraphx::parameter_map params;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {2}}; // data
migraphx::shape input_fixed_shape1{migraphx::shape::int64_type, {1}}; // index
std::vector<float> data_vec(2);
std::iota(data_vec.begin(), data_vec.end(), 4);
std::vector<int64_t> indices_vec{1};
params["X"] = migraphx::argument(input_fixed_shape0, data_vec.data());
params["I"] = migraphx::argument(input_fixed_shape1, indices_vec.data());
auto result = p.eval(params).back();
std::vector<float> res_data{};
std::vector<float> gold{5};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, gold));
}
TEST_CASE(gathernd_negative_index_test)
TEST_CASE(gathernd_negative_index_test)
{
{
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment