Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
02915432
Commit
02915432
authored
Nov 10, 2023
by
Artur Wojcik
Browse files
Merge branch 'develop' into uif2-initial
parents
a4c028ce
35e5298e
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
729 additions
and
151 deletions
+729
-151
src/include/migraphx/op/normalize_attribute.hpp
src/include/migraphx/op/normalize_attribute.hpp
+2
-0
src/include/migraphx/op/slice.hpp
src/include/migraphx/op/slice.hpp
+322
-136
src/normalize_attributes.cpp
src/normalize_attributes.cpp
+8
-8
src/onnx/parse_slice.cpp
src/onnx/parse_slice.cpp
+3
-0
test/op_shape_test.cpp
test/op_shape_test.cpp
+230
-5
test/ref/slice.cpp
test/ref/slice.cpp
+163
-1
tools/accuracy/requirements.txt
tools/accuracy/requirements.txt
+1
-1
No files found.
src/include/migraphx/op/normalize_attribute.hpp
View file @
02915432
...
@@ -40,6 +40,8 @@ namespace op {
...
@@ -40,6 +40,8 @@ namespace op {
* 2. use_rank (default) vs use_len:
* 2. use_rank (default) vs use_len:
* `use_rank` sets the max value/index of the attribute as the rank of lens.
* `use_rank` sets the max value/index of the attribute as the rank of lens.
* `use_lens` sets the max value/index as the corresponding value in lens at the axes index.
* `use_lens` sets the max value/index as the corresponding value in lens at the axes index.
* Uses the dynamic_dimension.max value for dynamic shapes. Returns the original vector
* (no normalization) if any of dynamic_dimension[axes] are not fixed.
* 3. `clip_min` vs. `not_clip_min` (default):
* 3. `clip_min` vs. `not_clip_min` (default):
* Clip values less than the minimum to the minimum or not.
* Clip values less than the minimum to the minimum or not.
* 4. `include_min` vs. `exclude_min` (default):
* 4. `include_min` vs. `exclude_min` (default):
...
...
src/include/migraphx/op/slice.hpp
View file @
02915432
This diff is collapsed.
Click to expand it.
src/normalize_attributes.cpp
View file @
02915432
...
@@ -66,15 +66,15 @@ auto tune_attribute(const std::vector<int64_t>& vec,
...
@@ -66,15 +66,15 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{
{
if
(
input_shape
.
dynamic
())
if
(
input_shape
.
dynamic
())
{
{
// return the unchanged `vec` if the dynamic_dimensions at `axes` are not fixed
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
ax
)
{
return
not
input_shape
.
dyn_dims
().
at
(
ax
).
is_fixed
();
}))
{
return
vec
;
}
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
max_vals
.
begin
(),
[
&
](
auto
i
)
{
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
max_vals
.
begin
(),
[
&
](
auto
i
)
{
const
auto
&
dd
=
input_shape
.
dyn_dims
().
at
(
i
);
return
input_shape
.
dyn_dims
().
at
(
i
).
max
;
if
(
not
dd
.
is_fixed
())
{
MIGRAPHX_THROW
(
"NORMALIZE_ATTR: 'use_lens' on a non-fixed dynamic dimension, axis="
+
std
::
to_string
(
i
));
}
return
dd
.
max
;
});
});
}
}
else
else
...
...
src/onnx/parse_slice.cpp
View file @
02915432
...
@@ -46,6 +46,9 @@ struct parse_slice : op_parser<parse_slice>
...
@@ -46,6 +46,9 @@ struct parse_slice : op_parser<parse_slice>
void
always_insert
(
instruction_ref
arg
)
{
op_args
.
insert
(
op_args
.
begin
(),
arg
);
}
void
always_insert
(
instruction_ref
arg
)
{
op_args
.
insert
(
op_args
.
begin
(),
arg
);
}
/**
* Either insert argument into `this->op_args` or return the constant value of the argument
*/
std
::
vector
<
int64_t
>
insert
(
instruction_ref
arg
)
std
::
vector
<
int64_t
>
insert
(
instruction_ref
arg
)
{
{
std
::
vector
<
int64_t
>
result
;
std
::
vector
<
int64_t
>
result
;
...
...
test/op_shape_test.cpp
View file @
02915432
...
@@ -3233,6 +3233,64 @@ TEST_CASE(slice_static_shape)
...
@@ -3233,6 +3233,64 @@ TEST_CASE(slice_static_shape)
TEST_CASE
(
slice_var_inputs_static_shape0
)
TEST_CASE
(
slice_var_inputs_static_shape0
)
{
{
// attr ends and axes set; inputs are (data, input_starts)
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
3
,
3
},
{
0
,
4
},
{
0
,
4
}}},
migraphx
::
make_op
(
"slice"
,
{{
"ends"
,
{
2
,
3
}},
{
"axes"
,
{
1
,
2
}}}),
input
,
starts
);
}
TEST_CASE
(
slice_var_inputs_static_mismatch_error0
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
throws_shape
(
migraphx
::
make_op
(
"slice"
,
{{
"ends"
,
{
2
,
3
,
4
}},
{
"axes"
,
{
0
,
1
,
2
}}}),
input
,
starts
);
}
TEST_CASE
(
slice_var_inputs_static_shape1
)
{
// attr starts and axes set; inputs are (data, input_ends)
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
ends
{
migraphx
::
shape
::
int64_type
,
{
2
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
3
,
3
},
{
0
,
4
},
{
0
,
4
}}},
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
,
1
}},
{
"axes"
,
{
1
,
2
}}}),
input
,
ends
);
}
TEST_CASE
(
slice_var_inputs_static_mismatch_error1
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
ends
{
migraphx
::
shape
::
int64_type
,
{
2
}};
throws_shape
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
,
1
,
2
}},
{
"axes"
,
{
0
,
1
,
2
}}}),
input
,
ends
);
}
TEST_CASE
(
slice_var_inputs_static_shape2
)
{
// attr starts and ends set; inputs are (data, input_axes)
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
axes
{
migraphx
::
shape
::
int64_type
,
{
2
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
0
,
3
},
{
0
,
4
},
{
0
,
4
}}},
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
,
1
}},
{
"ends"
,
{
1
,
2
}}}),
input
,
axes
);
}
TEST_CASE
(
slice_var_inputs_static_mismatch_error2
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
axes
{
migraphx
::
shape
::
int64_type
,
{
2
}};
throws_shape
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
,
1
,
2
}},
{
"ends"
,
{
3
,
4
,
4
}}}),
input
,
axes
);
}
TEST_CASE
(
slice_var_inputs_static_shape3
)
{
// attr axes set; inputs are (data, input_starts, input_ends)
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
ends
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
ends
{
migraphx
::
shape
::
int64_type
,
{
2
}};
...
@@ -3243,7 +3301,57 @@ TEST_CASE(slice_var_inputs_static_shape0)
...
@@ -3243,7 +3301,57 @@ TEST_CASE(slice_var_inputs_static_shape0)
ends
);
ends
);
}
}
TEST_CASE
(
slice_var_inputs_static_shape1
)
TEST_CASE
(
slice_var_inputs_static_mismatch_error3
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
ends
{
migraphx
::
shape
::
int64_type
,
{
2
}};
throws_shape
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
,
1
,
2
}}}),
input
,
starts
,
ends
);
}
TEST_CASE
(
slice_var_inputs_static_shape4
)
{
// attr ends set; inputs are (data, input_starts, input_axes)
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
axes
{
migraphx
::
shape
::
int64_type
,
{
2
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
0
,
3
},
{
0
,
4
},
{
0
,
4
}}},
migraphx
::
make_op
(
"slice"
,
{{
"ends"
,
{
3
,
4
}}}),
input
,
starts
,
axes
);
}
TEST_CASE
(
slice_var_inputs_static_mismatch_error4
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
axes
{
migraphx
::
shape
::
int64_type
,
{
2
}};
throws_shape
(
migraphx
::
make_op
(
"slice"
,
{{
"ends"
,
{
3
,
3
,
3
}}}),
input
,
starts
,
axes
);
}
TEST_CASE
(
slice_var_inputs_static_shape5
)
{
// attr starts set; inputs are (data, input_ends, input_axes)
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
ends
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
axes
{
migraphx
::
shape
::
int64_type
,
{
2
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
0
,
3
},
{
0
,
4
},
{
0
,
4
}}},
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
,
2
}}}),
input
,
ends
,
axes
);
}
TEST_CASE
(
slice_var_inputs_static_mismatch_error5
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
ends
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
axes
{
migraphx
::
shape
::
int64_type
,
{
2
}};
throws_shape
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
,
1
,
2
}}}),
input
,
ends
,
axes
);
}
TEST_CASE
(
slice_var_inputs_static_shape6
)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
...
@@ -3257,7 +3365,7 @@ TEST_CASE(slice_var_inputs_static_shape1)
...
@@ -3257,7 +3365,7 @@ TEST_CASE(slice_var_inputs_static_shape1)
axes
);
axes
);
}
}
TEST_CASE
(
slice_var_inputs_static_error
0
)
TEST_CASE
(
slice_var_inputs_static_
mismatch_
error
6
)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
...
@@ -3268,17 +3376,125 @@ TEST_CASE(slice_var_inputs_static_error0)
...
@@ -3268,17 +3376,125 @@ TEST_CASE(slice_var_inputs_static_error0)
TEST_CASE
(
slice_var_inputs_dyn_shape0
)
TEST_CASE
(
slice_var_inputs_dyn_shape0
)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
2
,
4
,
{
2
,
4
}},
{
2
,
4
,
{
2
,
4
}}}};
// attr ends and axes set; inputs are (data, input_starts)
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
4
,
6
},
{
4
,
6
}}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
0
,
6
},
{
0
,
6
}}},
migraphx
::
make_op
(
"slice"
,
{{
"ends"
,
{
2
,
3
}},
{
"axes"
,
{
1
,
2
}}}),
input
,
starts
);
}
TEST_CASE
(
slice_var_inputs_dyn_mismatch_error0
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
4
,
6
},
{
4
,
6
}}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
throws_shape
(
migraphx
::
make_op
(
"slice"
,
{{
"ends"
,
{
2
,
3
,
4
}},
{
"axes"
,
{
0
,
1
,
2
}}}),
input
,
starts
);
}
TEST_CASE
(
slice_var_inputs_dyn_shape1
)
{
// attr starts and axes set; inputs are (data, input_ends)
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
4
,
6
},
{
4
,
6
}}};
migraphx
::
shape
ends
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
ends
{
migraphx
::
shape
::
int64_type
,
{
2
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
0
,
4
},
{
0
,
4
}}},
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
0
,
6
},
{
0
,
6
}}},
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
,
1
}},
{
"axes"
,
{
1
,
2
}}}),
input
,
ends
);
}
TEST_CASE
(
slice_var_inputs_dyn_mismatch_error1
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
4
,
6
},
{
4
,
6
}}};
migraphx
::
shape
ends
{
migraphx
::
shape
::
int64_type
,
{
2
}};
throws_shape
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
,
1
,
2
}},
{
"axes"
,
{
0
,
1
,
2
}}}),
input
,
ends
);
}
TEST_CASE
(
slice_var_inputs_dyn_shape2
)
{
// attr starts and ends set; inputs are (data, input_axes)
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
4
,
6
},
{
4
,
6
}}};
migraphx
::
shape
axes
{
migraphx
::
shape
::
int64_type
,
{
2
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
0
,
6
},
{
0
,
6
},
{
0
,
6
}}},
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
,
1
}},
{
"ends"
,
{
8
,
8
}}}),
input
,
axes
);
}
TEST_CASE
(
slice_var_inputs_dyn_mismatch_error2
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
4
,
6
},
{
4
,
6
}}};
migraphx
::
shape
axes
{
migraphx
::
shape
::
int64_type
,
{
2
}};
throws_shape
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
,
1
,
2
}},
{
"ends"
,
{
3
,
4
,
4
}}}),
input
,
axes
);
}
TEST_CASE
(
slice_var_inputs_dyn_shape3
)
{
// attr axes set; inputs are (data, input_starts, input_ends)
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
4
,
6
},
{
4
,
6
}}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
ends
{
migraphx
::
shape
::
int64_type
,
{
2
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
0
,
6
},
{
0
,
6
}}},
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
,
2
}}}),
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
,
2
}}}),
input
,
input
,
starts
,
starts
,
ends
);
ends
);
}
}
TEST_CASE
(
slice_var_inputs_dyn_shape1
)
TEST_CASE
(
slice_var_inputs_dyn_mismatch_error3
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
4
,
6
},
{
4
,
6
}}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
ends
{
migraphx
::
shape
::
int64_type
,
{
2
}};
throws_shape
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
,
1
,
2
}}}),
input
,
starts
,
ends
);
}
TEST_CASE
(
slice_var_inputs_dyn_shape4
)
{
// attr ends set; inputs are (data, input_starts, input_axes)
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
4
,
6
},
{
4
,
6
}}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
axes
{
migraphx
::
shape
::
int64_type
,
{
2
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
0
,
6
},
{
0
,
6
},
{
0
,
6
}}},
migraphx
::
make_op
(
"slice"
,
{{
"ends"
,
{
3
,
4
}}}),
input
,
starts
,
axes
);
}
TEST_CASE
(
slice_var_inputs_dyn_mismatch_error4
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
4
,
6
},
{
4
,
6
}}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
axes
{
migraphx
::
shape
::
int64_type
,
{
2
}};
throws_shape
(
migraphx
::
make_op
(
"slice"
,
{{
"ends"
,
{
3
,
3
,
3
}}}),
input
,
starts
,
axes
);
}
TEST_CASE
(
slice_var_inputs_dyn_shape5
)
{
// attr starts set; inputs are (data, input_ends, input_axes)
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
4
,
6
},
{
4
,
6
}}};
migraphx
::
shape
ends
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
axes
{
migraphx
::
shape
::
int64_type
,
{
2
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
0
,
6
},
{
0
,
6
},
{
0
,
6
}}},
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
,
2
}}}),
input
,
ends
,
axes
);
}
TEST_CASE
(
slice_var_inputs_dyn_mismatch_error5
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
4
,
6
},
{
4
,
6
}}};
migraphx
::
shape
ends
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
axes
{
migraphx
::
shape
::
int64_type
,
{
2
}};
throws_shape
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
,
1
,
2
}}}),
input
,
ends
,
axes
);
}
TEST_CASE
(
slice_var_inputs_dyn_shape6
)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
2
,
4
,
{
2
,
4
}},
{
2
,
4
,
{
2
,
4
}}}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
2
,
4
,
{
2
,
4
}},
{
2
,
4
,
{
2
,
4
}}}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
...
@@ -3292,6 +3508,15 @@ TEST_CASE(slice_var_inputs_dyn_shape1)
...
@@ -3292,6 +3508,15 @@ TEST_CASE(slice_var_inputs_dyn_shape1)
axes
);
axes
);
}
}
TEST_CASE
(
slice_var_inputs_dyn_mismatch_error6
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
3
,
6
},
{
4
,
6
},
{
4
,
6
}}};
migraphx
::
shape
starts
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
ends
{
migraphx
::
shape
::
int64_type
,
{
2
}};
migraphx
::
shape
axes
{
migraphx
::
shape
::
int64_type
,
{
3
}};
throws_shape
(
migraphx
::
make_op
(
"slice"
),
input
,
starts
,
ends
,
axes
);
}
TEST_CASE
(
slice_dyn_shape0
)
TEST_CASE
(
slice_dyn_shape0
)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
int32_type
,
{{
2
,
3
},
{
7
,
7
},
{
2
,
3
}}};
migraphx
::
shape
input
{
migraphx
::
shape
::
int32_type
,
{{
2
,
3
},
{
7
,
7
},
{
2
,
3
}}};
...
...
test/ref/slice.cpp
View file @
02915432
...
@@ -157,7 +157,169 @@ TEST_CASE(slice_var_inputs_static2)
...
@@ -157,7 +157,169 @@ TEST_CASE(slice_var_inputs_static2)
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
}
}
TEST_CASE
(
slice_var_inputs_dyn
)
TEST_CASE
(
slice_var_inputs_dyn0
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s0
{
migraphx
::
shape
::
int32_type
,
{{
2
,
4
,
{
2
,
4
}},
{
2
,
4
,
{
2
,
4
}},
{
3
,
8
}}};
auto
input
=
mm
->
add_parameter
(
"input"
,
s0
);
migraphx
::
shape
s1
{
migraphx
::
shape
::
int32_type
,
{
1
}};
auto
starts
=
mm
->
add_parameter
(
"starts"
,
s1
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"ends"
,
{
10
}}}),
input
,
starts
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
migraphx
::
parameter_map
params
;
migraphx
::
shape
s2
{
migraphx
::
shape
::
int32_type
,
{
2
,
2
,
3
}};
std
::
vector
<
int
>
input_data
(
2
*
2
*
3
);
std
::
iota
(
input_data
.
begin
(),
input_data
.
end
(),
0
);
std
::
vector
<
int
>
start_data
=
{
1
};
params
[
"input"
]
=
migraphx
::
argument
(
s2
,
input_data
.
data
());
params
[
"starts"
]
=
migraphx
::
argument
(
s1
,
start_data
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
std
::
vector
<
int
>
gold
=
{
1
,
2
,
4
,
5
,
7
,
8
,
10
,
11
};
std
::
vector
<
int
>
results_vector
(
2
*
2
*
2
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
}
TEST_CASE
(
slice_var_inputs_dyn1
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s0
{
migraphx
::
shape
::
int32_type
,
{{
2
,
4
,
{
2
,
4
}},
{
2
,
4
,
{
2
,
4
}},
{
3
,
8
}}};
auto
input
=
mm
->
add_parameter
(
"input"
,
s0
);
migraphx
::
shape
s1
{
migraphx
::
shape
::
int32_type
,
{
1
}};
auto
ends
=
mm
->
add_parameter
(
"ends"
,
s1
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
-
5
}}}),
input
,
ends
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
migraphx
::
parameter_map
params
;
migraphx
::
shape
s2
{
migraphx
::
shape
::
int32_type
,
{
2
,
2
,
3
}};
std
::
vector
<
int
>
input_data
(
2
*
2
*
3
);
std
::
iota
(
input_data
.
begin
(),
input_data
.
end
(),
0
);
std
::
vector
<
int
>
ends_data
=
{
3
};
params
[
"input"
]
=
migraphx
::
argument
(
s2
,
input_data
.
data
());
params
[
"ends"
]
=
migraphx
::
argument
(
s1
,
ends_data
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
std
::
vector
<
int
>
gold
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
};
std
::
vector
<
int
>
results_vector
(
2
*
2
*
3
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
}
TEST_CASE
(
slice_var_inputs_dyn2
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s0
{
migraphx
::
shape
::
int32_type
,
{{
2
,
4
,
{
2
,
4
}},
{
2
,
4
,
{
2
,
4
}},
{
3
,
8
}}};
auto
input
=
mm
->
add_parameter
(
"input"
,
s0
);
migraphx
::
shape
s1
{
migraphx
::
shape
::
int32_type
,
{
1
}};
auto
axes
=
mm
->
add_parameter
(
"axes"
,
s1
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
1
}},
{
"ends"
,
{
-
1
}}}),
input
,
axes
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
migraphx
::
parameter_map
params
;
migraphx
::
shape
s2
{
migraphx
::
shape
::
int32_type
,
{
2
,
2
,
3
}};
std
::
vector
<
int
>
input_data
(
2
*
2
*
3
);
std
::
iota
(
input_data
.
begin
(),
input_data
.
end
(),
0
);
std
::
vector
<
int
>
axes_data
=
{
2
};
params
[
"input"
]
=
migraphx
::
argument
(
s2
,
input_data
.
data
());
params
[
"axes"
]
=
migraphx
::
argument
(
s1
,
axes_data
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
std
::
vector
<
int
>
gold
=
{
1
,
4
,
7
,
10
};
std
::
vector
<
int
>
results_vector
(
2
*
2
*
1
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
}
TEST_CASE
(
slice_var_inputs_dyn3
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s0
{
migraphx
::
shape
::
int32_type
,
{{
2
,
4
,
{
2
,
4
}},
{
2
,
4
,
{
2
,
4
}},
{
3
,
8
}}};
auto
input
=
mm
->
add_parameter
(
"input"
,
s0
);
migraphx
::
shape
s1
{
migraphx
::
shape
::
int32_type
,
{
1
}};
auto
starts
=
mm
->
add_parameter
(
"starts"
,
s1
);
auto
ends
=
mm
->
add_parameter
(
"ends"
,
s1
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}}}),
input
,
starts
,
ends
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
migraphx
::
parameter_map
params
;
migraphx
::
shape
s2
{
migraphx
::
shape
::
int32_type
,
{
2
,
2
,
3
}};
std
::
vector
<
int
>
input_data
(
2
*
2
*
3
);
std
::
iota
(
input_data
.
begin
(),
input_data
.
end
(),
0
);
std
::
vector
<
int
>
starts_data
=
{
1
};
std
::
vector
<
int
>
ends_data
=
{
std
::
numeric_limits
<
int
>::
max
()};
params
[
"input"
]
=
migraphx
::
argument
(
s2
,
input_data
.
data
());
params
[
"starts"
]
=
migraphx
::
argument
(
s1
,
starts_data
.
data
());
params
[
"ends"
]
=
migraphx
::
argument
(
s1
,
ends_data
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
std
::
vector
<
int
>
gold
=
{
1
,
2
,
4
,
5
,
7
,
8
,
10
,
11
};
std
::
vector
<
int
>
results_vector
(
2
*
2
*
2
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
}
TEST_CASE
(
slice_var_inputs_dyn4
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s0
{
migraphx
::
shape
::
int32_type
,
{{
2
,
4
,
{
2
,
4
}},
{
2
,
4
,
{
2
,
4
}},
{
3
,
8
}}};
auto
input
=
mm
->
add_parameter
(
"input"
,
s0
);
migraphx
::
shape
s1
{
migraphx
::
shape
::
int32_type
,
{
1
}};
auto
starts
=
mm
->
add_parameter
(
"starts"
,
s1
);
auto
axes
=
mm
->
add_parameter
(
"axes"
,
s1
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"ends"
,
{
std
::
numeric_limits
<
int
>::
max
()}}}),
input
,
starts
,
axes
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
migraphx
::
parameter_map
params
;
migraphx
::
shape
s2
{
migraphx
::
shape
::
int32_type
,
{
2
,
2
,
3
}};
std
::
vector
<
int
>
input_data
(
2
*
2
*
3
);
std
::
iota
(
input_data
.
begin
(),
input_data
.
end
(),
0
);
std
::
vector
<
int
>
starts_data
=
{
1
};
std
::
vector
<
int
>
axes_data
=
{
2
};
params
[
"input"
]
=
migraphx
::
argument
(
s2
,
input_data
.
data
());
params
[
"starts"
]
=
migraphx
::
argument
(
s1
,
starts_data
.
data
());
params
[
"axes"
]
=
migraphx
::
argument
(
s1
,
axes_data
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
std
::
vector
<
int
>
gold
=
{
1
,
2
,
4
,
5
,
7
,
8
,
10
,
11
};
std
::
vector
<
int
>
results_vector
(
2
*
2
*
2
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
}
TEST_CASE
(
slice_var_inputs_dyn5
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s0
{
migraphx
::
shape
::
int32_type
,
{{
2
,
4
,
{
2
,
4
}},
{
2
,
4
,
{
2
,
4
}},
{
3
,
8
}}};
auto
input
=
mm
->
add_parameter
(
"input"
,
s0
);
migraphx
::
shape
s1
{
migraphx
::
shape
::
int32_type
,
{
1
}};
auto
ends
=
mm
->
add_parameter
(
"ends"
,
s1
);
auto
axes
=
mm
->
add_parameter
(
"axes"
,
s1
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
-
4
}}}),
input
,
ends
,
axes
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
migraphx
::
parameter_map
params
;
migraphx
::
shape
s2
{
migraphx
::
shape
::
int32_type
,
{
2
,
2
,
3
}};
std
::
vector
<
int
>
input_data
(
2
*
2
*
3
);
std
::
iota
(
input_data
.
begin
(),
input_data
.
end
(),
0
);
std
::
vector
<
int
>
ends_data
=
{
2
};
std
::
vector
<
int
>
axes_data
=
{
2
};
params
[
"input"
]
=
migraphx
::
argument
(
s2
,
input_data
.
data
());
params
[
"ends"
]
=
migraphx
::
argument
(
s1
,
ends_data
.
data
());
params
[
"axes"
]
=
migraphx
::
argument
(
s1
,
axes_data
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
std
::
vector
<
int
>
gold
=
{
0
,
1
,
3
,
4
,
6
,
7
,
9
,
10
};
std
::
vector
<
int
>
results_vector
(
2
*
2
*
2
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
results_vector
,
gold
));
}
TEST_CASE
(
slice_var_inputs_dyn6
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
...
...
tools/accuracy/requirements.txt
View file @
02915432
...
@@ -22,4 +22,4 @@
...
@@ -22,4 +22,4 @@
# THE SOFTWARE.
# THE SOFTWARE.
#####################################################################################
#####################################################################################
numpy==1.21.6
numpy==1.21.6
onnxruntime==1.16.
1
onnxruntime==1.16.
2
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment