Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
59f6009b
Commit
59f6009b
authored
Aug 25, 2023
by
Khalique Ahmed
Browse files
initial testing workaround
parent
e46a6a52
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
74 additions
and
0 deletions
+74
-0
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+49
-0
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+25
-0
No files found.
src/simplify_reshapes.cpp
View file @
59f6009b
...
@@ -627,6 +627,54 @@ struct find_transpose_contiguous_reshaper_unary
...
@@ -627,6 +627,54 @@ struct find_transpose_contiguous_reshaper_unary
}
}
};
};
struct
find_broadcast_transpose
{
auto
matcher
()
const
{
return
match
::
name
(
"multibroadcast"
)(
match
::
all_of
[
match
::
outputs
()](
match
::
name
(
"transpose"
).
bind
(
"trans_ins"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
ins_lens
=
ins
->
get_shape
().
lens
();
auto
input
=
ins
->
inputs
().
front
();
auto
input_lens
=
input
->
get_shape
().
lens
();
auto
trans_ins
=
r
.
instructions
[
"trans_ins"
];
auto
trans_shape
=
trans_ins
->
get_shape
();
auto
permutation
=
trans_ins
->
get_operator
().
to_value
()[
"permutation"
].
to_vector
<
int64_t
>
();
// {1, 3072, 1024}
// unsqueeze axes {0, 1}
// need unsqueeze to be {0, 2}
// transpose permutation {0, 2, 1}
std
::
vector
<
int64_t
>
unsqueeze_axes
{};
auto
unsqueeze_count
=
-
1
;
for
(
auto
dim
:
ins_lens
)
{
unsqueeze_count
++
;
if
(
contains
(
input_lens
,
dim
))
{
continue
;
}
unsqueeze_axes
.
push_back
(
unsqueeze_count
);
}
for
(
auto
i
=
0
;
i
<
unsqueeze_axes
.
size
();
i
++
)
{
if
(
permutation
[
i
]
>
unsqueeze_axes
[
i
])
{
unsqueeze_axes
[
i
]
+=
permutation
[
i
]
-
i
;
}
}
auto
unsqueeze_ins
=
m
.
insert_instruction
(
trans_ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
unsqueeze_axes
}}),
input
);
auto
mbcast_ins
=
m
.
insert_instruction
(
trans_ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
trans_shape
.
lens
()}}),
unsqueeze_ins
);
m
.
replace_instruction
(
trans_ins
,
mbcast_ins
);
}
};
struct
find_slice_transpose
struct
find_slice_transpose
{
{
auto
matcher
()
const
auto
matcher
()
const
...
@@ -793,6 +841,7 @@ void simplify_reshapes::apply(module& m) const
...
@@ -793,6 +841,7 @@ void simplify_reshapes::apply(module& m) const
find_reshaper
{},
find_reshaper
{},
find_reshape_cont
{},
find_reshape_cont
{},
find_transpose
{},
find_transpose
{},
find_broadcast_transpose
{},
find_concat_transpose
{},
find_concat_transpose
{},
find_concat_multibroadcasts
{},
find_concat_multibroadcasts
{},
find_nested_convert
{},
find_nested_convert
{},
...
...
test/simplify_reshapes_test.cpp
View file @
59f6009b
...
@@ -201,6 +201,31 @@ TEST_CASE(reshape_transpose)
...
@@ -201,6 +201,31 @@ TEST_CASE(reshape_transpose)
EXPECT
(
std
::
distance
(
m
.
begin
(),
m
.
end
())
==
n
);
EXPECT
(
std
::
distance
(
m
.
begin
(),
m
.
end
())
==
n
);
}
}
TEST_CASE
(
broadcast_transpose
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
l
=
mm
->
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1024
}});
auto
mb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
3072
,
1024
}}}),
l
);
auto
t1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
}}}),
mb
);
mm
->
add_return
({
t1
});
run_pass
(
*
mm
);
migraphx
::
program
p2
;
mm
=
p2
.
get_main_module
();
l
=
mm
->
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1024
}});
auto
unsqueeze
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
2
}}}),
l
);
mb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
1024
,
3072
}}}),
unsqueeze
);
mm
->
add_return
({
mb
});
EXPECT
(
p
==
p2
);
// EXPECT(not mm->get_output_shapes().back().standard());
// EXPECT(mm->get_output_shapes().back().transposed());
// EXPECT(std::distance(mm->begin(), mm->end()) == 3);
// auto result = p.eval({}).back();
// EXPECT(result != get_2x2());
}
TEST_CASE
(
transpose_contiguous
)
TEST_CASE
(
transpose_contiguous
)
{
{
migraphx
::
module
m
;
migraphx
::
module
m
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment