Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c72d53ba
Unverified
Commit
c72d53ba
authored
Jan 30, 2023
by
Brian Pickrell
Committed by
GitHub
Jan 30, 2023
Browse files
Dyn gather (#1513)
Dynamic shape support for gather op.
parent
5bcb7ce8
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
257 additions
and
15 deletions
+257
-15
src/include/migraphx/op/gather.hpp
src/include/migraphx/op/gather.hpp
+40
-15
test/onnx/gather_dyn_test.onnx
test/onnx/gather_dyn_test.onnx
+0
-0
test/onnx/gather_scalar_test.onnx
test/onnx/gather_scalar_test.onnx
+0
-0
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+34
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+40
-0
test/op_shape_test.cpp
test/op_shape_test.cpp
+71
-0
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+72
-0
No files found.
src/include/migraphx/op/gather.hpp
View file @
c72d53ba
...
...
@@ -26,6 +26,7 @@
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
...
...
@@ -61,13 +62,36 @@ struct gather
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
auto
lens
=
inputs
[
0
].
lens
();
auto
type
=
inputs
[
0
].
type
();
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
2
);
shape
data
=
inputs
[
0
];
shape
indices
=
inputs
[
1
];
auto
type
=
data
.
type
();
// If index_dims is dynamic, convert the data to dynamic too.
if
(
indices
.
dynamic
())
{
data
=
data
.
to_dynamic
();
}
if
(
data
.
dynamic
())
{
auto
dims
=
data
.
dyn_dims
();
dims
.
erase
(
dims
.
begin
()
+
axis
);
if
(
not
indices
.
scalar
())
{
auto
index_dims
=
indices
.
to_dynamic
().
dyn_dims
();
dims
.
insert
(
dims
.
begin
()
+
axis
,
index_dims
.
begin
(),
index_dims
.
end
());
}
return
{
type
,
dims
};
}
else
{
// Both data and indices are static. indices may be scalar
auto
lens
=
data
.
lens
();
lens
.
erase
(
lens
.
begin
()
+
axis
);
if
(
not
inputs
[
1
].
scalar
())
if
(
not
indices
.
scalar
())
{
auto
ind_lens
=
in
puts
[
1
]
.
lens
();
auto
ind_lens
=
in
dices
.
lens
();
lens
.
insert
(
lens
.
begin
()
+
axis
,
ind_lens
.
begin
(),
ind_lens
.
end
());
}
...
...
@@ -79,17 +103,18 @@ struct gather
return
{
type
,
lens
};
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
// negative axis means counting dimensions from back
auto
lens
=
args
[
0
].
get_shape
().
lens
();
std
::
size_t
axis_dim_size
=
lens
[
axis
];
// max dimension in axis
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
data
)
{
args
[
1
].
visit
([
&
](
auto
indices
)
{
if
(
out
put_shape
.
scalar
())
if
(
dyn_out
.
com
put
ed
_shape
.
scalar
())
{
auto
in_index
=
indices
.
front
();
in_index
=
(
in_index
<
0
)
?
in_index
+
axis_dim_size
:
in_index
;
...
...
test/onnx/gather_dyn_test.onnx
0 → 100644
View file @
c72d53ba
File added
test/onnx/gather_scalar_test.onnx
0 → 100644
View file @
c72d53ba
File added
test/onnx/gen_onnx.py
View file @
c72d53ba
...
...
@@ -2053,6 +2053,40 @@ def gather_test():
return
([
node
],
[
x
,
i
],
[
y
])
@
onnx_test
()
def
gather_scalar_test
():
x
=
helper
.
make_tensor_value_info
(
'data'
,
TensorProto
.
FLOAT
,
[
3
,
4
,
5
,
6
])
i
=
helper
.
make_tensor_value_info
(
'indices'
,
TensorProto
.
INT32
,
[])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
4
,
5
,
6
])
node
=
onnx
.
helper
.
make_node
(
'Gather'
,
inputs
=
[
'data'
,
'indices'
],
outputs
=
[
'y'
],
axis
=
1
,
)
return
([
node
],
[
x
,
i
],
[
y
])
@
onnx_test
()
def
gather_dyn_test
():
x
=
helper
.
make_tensor_value_info
(
'data'
,
TensorProto
.
FLOAT
,
[
None
,
4
,
5
,
6
])
i
=
helper
.
make_tensor_value_info
(
'indices'
,
TensorProto
.
INT32
,
[
None
,
3
,
4
,
5
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
2
,
3
,
4
,
5
])
node
=
onnx
.
helper
.
make_node
(
'Gather'
,
inputs
=
[
'data'
,
'indices'
],
outputs
=
[
'y'
],
axis
=
1
,
)
return
([
node
],
[
x
,
i
],
[
y
])
@
onnx_test
()
def
gather_elements_axis0_test
():
x
=
helper
.
make_tensor_value_info
(
'data'
,
TensorProto
.
FLOAT
,
[
3
,
4
])
...
...
test/onnx/onnx_test.cpp
View file @
c72d53ba
...
...
@@ -2048,6 +2048,46 @@ TEST_CASE(gather_test)
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
gather_scalar_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
l0
=
mm
->
add_parameter
(
"data"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
5
,
6
}});
std
::
vector
<
size_t
>
idims
{
1
};
auto
l1
=
mm
->
add_parameter
(
"indices"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
idims
,
{
0
}});
int
axis
=
1
;
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
axis
}}),
l0
,
l1
);
auto
prog
=
optimize_onnx
(
"gather_scalar_test.onnx"
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
gather_dyn_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
l0
=
mm
->
add_parameter
(
"data"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
4
,
4
,
0
},
{
5
,
5
,
0
},
{
6
,
6
,
0
}}});
auto
l1
=
mm
->
add_parameter
(
"indices"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{{
1
,
4
,
0
},
{
3
,
3
,
0
},
{
4
,
4
,
0
},
{
5
,
5
,
0
}}});
auto
cont_l0
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l0
);
auto
cont_l1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
l1
);
int
axis
=
1
;
auto
gather_op
=
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
axis
}});
auto
ret
=
mm
->
add_instruction
(
gather_op
,
cont_l0
,
cont_l1
);
mm
->
add_return
({
ret
});
migraphx
::
onnx_options
options
;
options
.
default_dyn_dim_value
=
{
1
,
4
,
0
};
auto
prog
=
parse_onnx
(
"gather_dyn_test.onnx"
,
options
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
gather_elements_axis0_test
)
{
migraphx
::
program
p
;
...
...
test/op_shape_test.cpp
View file @
c72d53ba
...
...
@@ -831,6 +831,77 @@ TEST_CASE(gather)
}
}
TEST_CASE
(
gather_dyn0
)
{
// Insert dynamic index into dynamic shape
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
2
,
3
,
2
},
{
3
,
4
,
3
},
{
6
,
9
,
7
},
{
12
,
14
,
13
}}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{{
2
,
7
,
3
},
{
3
,
3
,
0
}}};
int
axis
=
1
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
2
,
3
,
2
},
{
2
,
7
,
3
},
{
3
,
3
,
0
},
{
6
,
9
,
7
},
{
12
,
14
,
13
}}},
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
axis
}}),
input
,
indices
);
}
TEST_CASE
(
gather_dyn1
)
{
// Insert static index into dynamic shape
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
2
,
3
,
2
},
{
3
,
4
,
3
},
{
6
,
9
,
7
},
{
12
,
14
,
13
}}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}};
int
axis
=
1
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
2
,
3
,
2
},
{
2
,
2
,
0
},
{
3
,
3
,
0
},
{
6
,
9
,
7
},
{
12
,
14
,
13
}}},
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
axis
}}),
input
,
indices
);
}
TEST_CASE
(
gather_dyn2
)
{
// Insert scalar (static) index into dynamic shape
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
2
,
3
,
2
},
{
3
,
4
,
3
},
{
6
,
9
,
7
},
{
12
,
14
,
13
}}};
std
::
vector
<
std
::
size_t
>
mins
;
std
::
vector
<
std
::
size_t
>
maxes
;
std
::
vector
<
std
::
size_t
>
opts
;
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
mins
,
maxes
,
opts
};
int
axis
=
1
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
2
,
3
,
2
},
{
6
,
9
,
7
},
{
12
,
14
,
13
}}},
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
axis
}}),
input
,
indices
);
}
TEST_CASE
(
gather_dyn3
)
{
// Insert dynamic index into static shape, axis 1
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
6
,
12
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{{
2
,
3
,
2
},
{
3
,
4
,
3
}}};
int
axis
=
1
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
2
,
2
,
0
},
{
2
,
3
,
2
},
{
3
,
4
,
3
},
{
6
,
6
,
0
},
{
12
,
12
,
0
}}},
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
axis
}}),
input
,
indices
);
}
TEST_CASE
(
gather_dyn4
)
{
// Insert dynamic index into static shape, axis 0
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
6
,
12
}};
migraphx
::
shape
indices
{
migraphx
::
shape
::
int32_type
,
{{
2
,
3
,
2
},
{
3
,
4
,
3
}}};
int
axis
=
0
;
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
2
,
3
,
2
},
{
3
,
4
,
3
},
{
3
,
3
,
0
},
{
6
,
6
,
0
},
{
12
,
12
,
0
}}},
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
axis
}}),
input
,
indices
);
}
TEST_CASE
(
get_tuple_elem_test
)
{
migraphx
::
shape
s0
{
migraphx
::
shape
::
bool_type
,
{
1
,
1
}};
...
...
test/ref_ops_test.cpp
View file @
c72d53ba
...
...
@@ -2524,6 +2524,78 @@ TEST_CASE(gather_test)
}
}
TEST_CASE
(
gather_dyn_test0
)
{
// Dynamic data, static indices
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
int32_type
,
{{
2
,
5
,
0
},
{
3
,
3
,
0
}}};
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
std
::
vector
<
int
>
indices
{
1
,
2
};
migraphx
::
shape
s_ind
{
migraphx
::
shape
::
int32_type
,
{
1
,
2
}};
auto
ind
=
mm
->
add_parameter
(
"indices"
,
s_ind
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
1
}}),
x
,
ind
);
migraphx
::
shape
sresult
{
migraphx
::
shape
::
int32_type
,
{{
2
,
5
,
0
},
{
1
,
1
,
0
},
{
2
,
2
,
0
}}};
EXPECT
(
p
.
get_output_shapes
().
back
()
==
sresult
);
p
.
compile
(
migraphx
::
ref
::
target
{});
migraphx
::
shape
input_fixed_shape
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}};
migraphx
::
shape
input_indices
{
migraphx
::
shape
::
int32_type
,
{
1
,
2
}};
migraphx
::
parameter_map
params
;
std
::
vector
<
int
>
data
(
2
*
3
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0
);
params
[
"x"
]
=
migraphx
::
argument
(
input_fixed_shape
,
data
.
data
());
params
[
"indices"
]
=
migraphx
::
argument
(
input_indices
,
indices
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
std
::
vector
<
int
>
gold
=
{
1
,
2
,
4
,
5
};
std
::
vector
<
int
>
results_vector
(
2
*
1
*
2
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
migraphx
::
shape
sfinal
{
migraphx
::
shape
::
int32_type
,
{
2
,
1
,
2
}};
EXPECT
(
result
.
get_shape
()
==
sfinal
);
}
TEST_CASE
(
gather_dyn_test1
)
{
// Dynamic data, dynamic indices
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
int32_type
,
{{
2
,
5
,
0
},
{
4
,
4
,
0
}}};
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
migraphx
::
shape
s_ind
{
migraphx
::
shape
::
int32_type
,
{{
1
,
8
,
7
},
{
2
,
3
,
3
}}};
auto
ind
=
mm
->
add_parameter
(
"indices"
,
s_ind
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
x
,
ind
);
migraphx
::
shape
sresult
{
migraphx
::
shape
::
int32_type
,
{{
1
,
8
,
7
},
{
2
,
3
,
3
},
{
4
,
4
,
0
}}};
EXPECT
(
p
.
get_output_shapes
().
back
()
==
sresult
);
p
.
compile
(
migraphx
::
ref
::
target
{});
migraphx
::
shape
input_fixed_shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
4
}};
migraphx
::
shape
input_indices_shape
{
migraphx
::
shape
::
int32_type
,
{
1
,
2
}};
std
::
vector
<
int
>
indices
{
2
,
0
};
migraphx
::
parameter_map
params
;
std
::
vector
<
int
>
data
(
3
*
4
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0
);
params
[
"x"
]
=
migraphx
::
argument
(
input_fixed_shape
,
data
.
data
());
params
[
"indices"
]
=
migraphx
::
argument
(
input_indices_shape
,
indices
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
std
::
vector
<
int
>
gold
=
{
8
,
9
,
10
,
11
,
0
,
1
,
2
,
3
};
std
::
vector
<
int
>
results_vector
(
1
*
2
*
4
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
migraphx
::
shape
sfinal
{
migraphx
::
shape
::
int32_type
,
{
1
,
2
,
4
}};
EXPECT
(
result
.
get_shape
()
==
sfinal
);
}
TEST_CASE
(
gathernd_test
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment