Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
77b89a20
"vscode:/vscode.git/clone" did not exist on "7e465e4ad6e706398410395c8ae9a6d9259c6dd7"
Commit
77b89a20
authored
Nov 15, 2023
by
charlie
Browse files
finish matcher and make tests
parent
954bca60
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
173 additions
and
16 deletions
+173
-16
src/simplify_dyn_ops.cpp
src/simplify_dyn_ops.cpp
+7
-8
test/simplify_dyn_ops_test.cpp
test/simplify_dyn_ops_test.cpp
+166
-8
No files found.
src/simplify_dyn_ops.cpp
View file @
77b89a20
...
...
@@ -76,16 +76,15 @@ struct find_const_2in_slice
{
auto
matcher
()
const
{
return
match
::
name
(
"slice"
)(
match
::
nargs
(
3
),
match
::
arg
(
1
)(
match
::
is_constant
()));
return
match
::
name
(
"slice"
)(
match
::
nargs
(
3
),
match
::
arg
(
1
)(
match
::
is_constant
()));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
inputs
=
ins
->
inputs
();
auto
slice_op
=
any_cast
<
op
::
slice
>
(
ins
->
get_operator
());
auto
set_attrs
=
slice_op
.
get_set_attributes
();
auto
ins
=
mr
.
result
;
auto
inputs
=
ins
->
inputs
();
auto
slice_op
=
any_cast
<
op
::
slice
>
(
ins
->
get_operator
());
auto
set_attrs
=
slice_op
.
get_set_attributes
();
std
::
vector
<
int64_t
>
starts_vec
;
std
::
vector
<
int64_t
>
ends_vec
;
std
::
vector
<
int64_t
>
axes_vec
;
...
...
@@ -103,7 +102,7 @@ struct find_const_2in_slice
inputs
.
at
(
1
)
->
eval
().
visit
(
[
&
](
auto
output
)
{
ends_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
starts_vec
=
slice_op
.
starts
;
axes_vec
=
slice_op
.
axes
;
axes_vec
=
slice_op
.
axes
;
}
else
{
...
...
@@ -111,7 +110,7 @@ struct find_const_2in_slice
inputs
.
at
(
1
)
->
eval
().
visit
(
[
&
](
auto
output
)
{
axes_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
starts_vec
=
slice_op
.
starts
;
ends_vec
=
slice_op
.
ends
;
ends_vec
=
slice_op
.
ends
;
}
m
.
replace_instruction
(
ins
,
...
...
test/simplify_dyn_ops_test.cpp
View file @
77b89a20
...
...
@@ -155,29 +155,187 @@ TEST_CASE(after_split_dyn_broadcast_match)
EXPECT
(
p0
==
p1
);
}
TEST_CASE
(
const_slice_
3
input
)
TEST_CASE
(
const_slice_
2
input
_ends_axes
)
{
migraphx
::
module
m0
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
6
,
4
,
4
}};
auto
input
=
m0
.
add_parameter
(
"data"
,
s
);
auto
slice_ins
=
m0
.
add_instruction
(
auto
input
=
m0
.
add_parameter
(
"data"
,
s
);
migraphx
::
shape
s1
{
migraphx
::
shape
::
int32_type
,
{
1
}};
auto
input_starts
=
m0
.
add_literal
(
migraphx
::
literal
{
s1
,
{
0
}});
auto
slice_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"ends"
,
{
3
}},
{
"axes"
,
{
0
}}}),
input
,
input_starts
);
m0
.
add_return
({
slice_ins
});
}
run_pass
(
m0
);
migraphx
::
module
m1
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
6
,
4
,
4
}};
auto
input
=
m1
.
add_parameter
(
"data"
,
s
);
auto
slice_ins
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
}},
{
"ends"
,
{
3
}},
{
"axes"
,
{
0
}}}),
input
);
m1
.
add_return
({
slice_ins
});
}
EXPECT
(
m0
==
m1
);
}
TEST_CASE
(
const_slice_2input_starts_axes
)
{
migraphx
::
module
m0
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
6
,
4
,
4
}};
auto
input
=
m0
.
add_parameter
(
"data"
,
s
);
migraphx
::
shape
s1
{
migraphx
::
shape
::
int32_type
,
{
1
}};
auto
input_ends
=
m0
.
add_literal
(
migraphx
::
literal
{
s1
,
{
3
}});
auto
slice_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
}},
{
"axes"
,
{
0
}}}),
input
,
input_ends
);
m0
.
add_return
({
slice_ins
});
}
run_pass
(
m0
);
migraphx
::
module
m1
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
6
,
4
,
4
}};
auto
input
=
m1
.
add_parameter
(
"data"
,
s
);
auto
input
=
m1
.
add_parameter
(
"data"
,
s
);
auto
slice_ins
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
}},
{
"ends"
,
{
3
}},
{
"axes"
,
{
0
}}}),
input
);
m1
.
add_return
({
slice_ins
});
}
EXPECT
(
m0
==
m1
);
}
TEST_CASE
(
const_slice_2input_starts_ends
)
{
migraphx
::
module
m0
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
6
,
4
,
4
}};
auto
input
=
m0
.
add_parameter
(
"data"
,
s
);
migraphx
::
shape
s1
{
migraphx
::
shape
::
int32_type
,
{
1
}};
auto
input_starts
=
m1
.
add_literal
(
migraphx
::
literal
{
s1
,
{
0
}});
auto
input_ends
=
m1
.
add_literal
(
migraphx
::
literal
{
s1
,
{
3
}});
auto
slice_ins
=
m1
.
add_instruction
(
auto
input_axes
=
m0
.
add_literal
(
migraphx
::
literal
{
s1
,
{
0
}});
auto
slice_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
}},
{
"ends"
,
{
3
}}}),
input
,
input_axes
);
m0
.
add_return
({
slice_ins
});
}
run_pass
(
m0
);
migraphx
::
module
m1
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
6
,
4
,
4
}};
auto
input
=
m1
.
add_parameter
(
"data"
,
s
);
auto
slice_ins
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
}},
{
"ends"
,
{
3
}},
{
"axes"
,
{
0
}}}),
input
);
m1
.
add_return
({
slice_ins
});
}
EXPECT
(
m0
==
m1
);
}
TEST_CASE
(
const_slice_3input_axes_only
)
{
migraphx
::
module
m0
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
6
,
4
,
4
}};
auto
input
=
m0
.
add_parameter
(
"data"
,
s
);
migraphx
::
shape
s1
{
migraphx
::
shape
::
int32_type
,
{
1
}};
auto
input_starts
=
m0
.
add_literal
(
migraphx
::
literal
{
s1
,
{
0
}});
auto
input_ends
=
m0
.
add_literal
(
migraphx
::
literal
{
s1
,
{
3
}});
auto
slice_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}}}),
input
,
input_starts
,
input_ends
);
m0
.
add_return
({
slice_ins
});
}
run_pass
(
m0
);
migraphx
::
module
m1
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
6
,
4
,
4
}};
auto
input
=
m1
.
add_parameter
(
"data"
,
s
);
auto
slice_ins
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
}},
{
"ends"
,
{
3
}},
{
"axes"
,
{
0
}}}),
input
);
m1
.
add_return
({
slice_ins
});
}
EXPECT
(
m0
==
m1
);
}
TEST_CASE
(
const_slice_3input_ends_only
)
{
migraphx
::
module
m0
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
6
,
4
,
4
}};
auto
input
=
m0
.
add_parameter
(
"data"
,
s
);
migraphx
::
shape
s1
{
migraphx
::
shape
::
int32_type
,
{
1
}};
auto
input_starts
=
m0
.
add_literal
(
migraphx
::
literal
{
s1
,
{
0
}});
auto
input_axes
=
m0
.
add_literal
(
migraphx
::
literal
{
s1
,
{
0
}});
auto
slice_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"ends"
,
{
3
}}}),
input
,
input_starts
,
input_axes
);
m0
.
add_return
({
slice_ins
});
}
run_pass
(
m0
);
migraphx
::
module
m1
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
6
,
4
,
4
}};
auto
input
=
m1
.
add_parameter
(
"data"
,
s
);
auto
slice_ins
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
}},
{
"ends"
,
{
3
}},
{
"axes"
,
{
0
}}}),
input
);
m1
.
add_return
({
slice_ins
});
}
EXPECT
(
m0
==
m1
);
}
TEST_CASE
(
const_slice_3inputs_starts_only
)
{
migraphx
::
module
m0
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
6
,
4
,
4
}};
auto
input
=
m0
.
add_parameter
(
"data"
,
s
);
migraphx
::
shape
s1
{
migraphx
::
shape
::
int32_type
,
{
1
}};
auto
input_ends
=
m0
.
add_literal
(
migraphx
::
literal
{
s1
,
{
3
}});
auto
input_axes
=
m0
.
add_literal
(
migraphx
::
literal
{
s1
,
{
0
}});
auto
slice_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
}}}),
input
,
input_ends
,
input_axes
);
m0
.
add_return
({
slice_ins
});
}
run_pass
(
m0
);
migraphx
::
module
m1
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
6
,
4
,
4
}};
auto
input
=
m1
.
add_parameter
(
"data"
,
s
);
auto
slice_ins
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
}},
{
"ends"
,
{
3
}},
{
"axes"
,
{
0
}}}),
input
);
m1
.
add_return
({
slice_ins
});
}
EXPECT
(
m0
==
m1
);
}
TEST_CASE
(
const_slice_2input_ends_axes_dyn
)
{
migraphx
::
module
m0
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{{
6
,
6
},
{
2
,
4
,
{
2
,
4
}},
{
2
,
4
,
{
2
,
4
}}}};
auto
input
=
m0
.
add_parameter
(
"data"
,
s
);
migraphx
::
shape
s1
{
migraphx
::
shape
::
int32_type
,
{
1
}};
auto
input_starts
=
m0
.
add_literal
(
migraphx
::
literal
{
s1
,
{
0
}});
auto
slice_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"ends"
,
{
3
}},
{
"axes"
,
{
0
}}}),
input
,
input_starts
);
m0
.
add_return
({
slice_ins
});
}
run_pass
(
m0
);
migraphx
::
module
m1
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{{
6
,
6
},
{
2
,
4
,
{
2
,
4
}},
{
2
,
4
,
{
2
,
4
}}}};
auto
input
=
m1
.
add_parameter
(
"data"
,
s
);
auto
slice_ins
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
0
}},
{
"ends"
,
{
3
}},
{
"axes"
,
{
0
}}}),
input
);
m1
.
add_return
({
slice_ins
});
}
run_pass
(
m1
);
EXPECT
(
m0
==
m1
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment