Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e055c6e5
"testing/vscode:/vscode.git/clone" did not exist on "1f4bfe6695a2a858796ab7d6c916641ef8cc492b"
Commit
e055c6e5
authored
Nov 09, 2023
by
charlie
Browse files
make pass and add tests
parent
2c8ba41a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
104 additions
and
1 deletion
+104
-1
src/simplify_dyn_ops.cpp
src/simplify_dyn_ops.cpp
+37
-1
test/simplify_dyn_ops_test.cpp
test/simplify_dyn_ops_test.cpp
+67
-0
No files found.
src/simplify_dyn_ops.cpp
View file @
e055c6e5
...
@@ -33,6 +33,10 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -33,6 +33,10 @@ inline namespace MIGRAPHX_INLINE_NS {
* Convert 2 input static shape broadcast/multibroadcast into 1 input version.
* Convert 2 input static shape broadcast/multibroadcast into 1 input version.
* Some compiler passes (ex. simplify_algebra) only support the 1 input versions
* Some compiler passes (ex. simplify_algebra) only support the 1 input versions
* of the broadcasting operators.
* of the broadcasting operators.
* From:
* broadcast_op(argument_with_static_shape, argument_with_static_shape)
* To:
* broadcast_op(argument_with_static_shape); broadcast_op.out_lens = constant_output_dims
*/
*/
struct
find_static_2in_broadcasts
struct
find_static_2in_broadcasts
{
{
...
@@ -172,11 +176,43 @@ struct find_static_dimensions_of
...
@@ -172,11 +176,43 @@ struct find_static_dimensions_of
}
}
};
};
/**
* Simplify allocate into 2 argument reshape that has constant output dimensions into a static 1
* argument reshape. Intended to simplify what ONNX parse_reshape creates for dynamic reshapes.
* From:
* x = allocate(constant_output_dims) -> reshape(data, x)
* To:
* reshape(data); reshape.dims = constant_output_dims
*/
struct
find_const_alloc_reshapes
{
auto
matcher
()
const
{
return
match
::
name
(
"reshape"
)(
match
::
nargs
(
2
),
match
::
arg
(
1
)(
match
::
name
(
"allocate"
)(
match
::
is_constant
())));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
reshape_ins
=
mr
.
result
;
auto
reshape_inputs
=
reshape_ins
->
inputs
();
auto
alloc_ins
=
reshape_inputs
.
at
(
1
);
argument
output_dims_arg
=
alloc_ins
->
inputs
().
at
(
0
)
->
eval
();
std
::
vector
<
int64_t
>
output_dims_vec
;
output_dims_arg
.
visit
(
[
&
](
auto
output
)
{
output_dims_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
m
.
replace_instruction
(
reshape_ins
,
make_op
(
"reshape"
,
{{
"dims"
,
output_dims_vec
}}),
reshape_inputs
.
at
(
0
));
// have dead_code_elimination remove the previous allocate
}
};
void
simplify_dyn_ops
::
apply
(
module
&
m
)
const
void
simplify_dyn_ops
::
apply
(
module
&
m
)
const
{
{
match
::
find_matches
(
m
,
find_static_dimensions_of
{});
match
::
find_matches
(
m
,
match
::
find_matches
(
m
,
find_const_alloc_reshapes
{},
find_static_2in_broadcasts
{},
find_static_2in_broadcasts
{},
find_static_dimensions_of
{},
find_const_3in_slice
{},
find_const_3in_slice
{},
find_const_4in_slice
{});
find_const_4in_slice
{});
}
}
...
...
test/simplify_dyn_ops_test.cpp
View file @
e055c6e5
...
@@ -319,4 +319,71 @@ TEST_CASE(static_dimensions_of_nonfixed)
...
@@ -319,4 +319,71 @@ TEST_CASE(static_dimensions_of_nonfixed)
EXPECT
(
m0
==
m1
);
EXPECT
(
m0
==
m1
);
}
}
TEST_CASE
(
constant_alloc_reshape
)
{
migraphx
::
module
m0
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
32
}};
auto
input
=
m0
.
add_parameter
(
"data"
,
s
);
migraphx
::
shape
lit_s
{
migraphx
::
shape
::
int64_type
,
{
3
}};
auto
literal_ins
=
m0
.
add_literal
(
migraphx
::
literal
{
lit_s
,
{
3
,
4
,
8
}});
auto
alloc_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"allocate"
,
{{
"buf_type"
,
migraphx
::
shape
::
float_type
}}),
literal_ins
);
auto
reshape_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
),
input
,
alloc_ins
);
m0
.
add_return
({
reshape_ins
});
}
run_pass
(
m0
);
migraphx
::
module
m1
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
32
}};
auto
input
=
m1
.
add_parameter
(
"data"
,
s
);
auto
reshape_ins
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
3
,
4
,
8
}}}),
input
);
m1
.
add_return
({
reshape_ins
});
}
EXPECT
(
m0
==
m1
);
}
// A more contrived example to test static dimensions_of and constant reshape
TEST_CASE
(
static_dimensions_of_to_constant_alloc_reshape
)
{
migraphx
::
module
m0
;
{
migraphx
::
shape
input_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
8
}};
auto
x_param
=
m0
.
add_parameter
(
"x"
,
input_shape
);
auto
dimensions_of_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"dimensions_of"
,
{{
"end"
,
3
}}),
x_param
);
migraphx
::
shape
lit_shape
{
migraphx
::
shape
::
int64_type
,
{
1
}};
auto
lit0
=
m0
.
add_literal
(
migraphx
::
literal
{
lit_shape
,
{
0
}});
auto
gather_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
dimensions_of_ins
,
lit0
);
auto
slice_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"starts"
,
{
1
}},
{
"ends"
,
{
3
}},
{
"axes"
,
{
0
}}}),
dimensions_of_ins
);
auto
reduce_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"reduce_prod"
,
{{
"axes"
,
{
0
}}}),
slice_ins
);
auto
concat_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
gather_ins
,
reduce_ins
);
auto
alloc_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"allocate"
,
{{
"buf_type"
,
migraphx
::
shape
::
float_type
}}),
concat_ins
);
auto
reshape_ins
=
m0
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
),
x_param
,
alloc_ins
);
m0
.
add_return
({
reshape_ins
});
}
run_pass
(
m0
);
migraphx
::
module
m1
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
8
}};
auto
x_param
=
m1
.
add_parameter
(
"x"
,
s
);
auto
reshape_ins
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
3
,
32
}}}),
x_param
);
m1
.
add_return
({
reshape_ins
});
}
EXPECT
(
m0
==
m1
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment