Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
73f78d8d
Commit
73f78d8d
authored
Mar 10, 2023
by
Paul
Browse files
Improve handling of reshape around gemm
parent
69799cb8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
1 deletion
+44
-1
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+44
-1
No files found.
src/simplify_reshapes.cpp
View file @
73f78d8d
...
@@ -797,6 +797,48 @@ struct find_transpose_slice
...
@@ -797,6 +797,48 @@ struct find_transpose_slice
}
}
};
};
struct
find_reshape_gemm
{
auto
matcher
()
const
{
return
match
::
name
(
"reshape"
)(
match
::
arg
(
0
)(
match
::
name
(
"dot"
)));
}
static
bool
is_batched_unsqueeze
(
instruction_ref
ins
)
{
auto
input
=
ins
->
inputs
().
front
()
->
get_shape
().
lens
();
auto
output
=
ins
->
get_shape
().
lens
();
if
(
output
.
size
()
<=
input
.
size
())
return
false
;
if
(
not
std
::
equal
(
input
.
end
()
-
2
,
input
.
end
(),
output
.
end
()
-
2
,
output
.
end
()))
return
false
;
return
true
;
}
static
operation
make_reshape
(
std
::
vector
<
std
::
size_t
>
batches
,
instruction_ref
ins
)
{
batches
.
insert
(
batches
.
end
(),
ins
->
get_shape
().
lens
().
end
()
-
2
,
ins
->
get_shape
().
lens
().
end
());
return
make_op
(
"reshape"
,
{{
"dims"
,
batches
}});
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
reshape_ins
=
r
.
result
;
auto
dot_ins
=
reshape_ins
->
inputs
().
front
();
// TODO: Put this in the matcher
if
(
not
is_batched_unsqueeze
(
reshape_ins
))
return
;
std
::
vector
<
std
::
size_t
>
batches
;
std
::
copy
(
reshape_ins
->
get_shape
().
lens
().
begin
(),
reshape_ins
->
get_shape
().
lens
().
end
()
-
2
,
std
::
back_inserter
(
batches
));
auto
input0
=
m
.
insert_instruction
(
dot_ins
,
make_reshape
(
batches
,
dot_ins
->
inputs
()[
0
]),
dot_ins
->
inputs
()[
0
]);
auto
input1
=
m
.
insert_instruction
(
dot_ins
,
make_reshape
(
batches
,
dot_ins
->
inputs
()[
1
]),
dot_ins
->
inputs
()[
1
]);
m
.
replace_instruction
(
dot_ins
,
make_op
(
"dot"
),
input0
,
input1
);
}
};
void
simplify_reshapes
::
apply
(
module
&
m
)
const
void
simplify_reshapes
::
apply
(
module
&
m
)
const
{
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
for
(
int
i
=
0
;
i
<
4
;
i
++
)
...
@@ -815,7 +857,8 @@ void simplify_reshapes::apply(module& m) const
...
@@ -815,7 +857,8 @@ void simplify_reshapes::apply(module& m) const
find_nested_concat
{},
find_nested_concat
{},
find_transpose_slice
{},
find_transpose_slice
{},
find_slice_transpose
{},
find_slice_transpose
{},
find_transpose_contiguous_reshaper_unary
{});
find_transpose_contiguous_reshaper_unary
{},
find_reshape_gemm
{});
dead_code_elimination
{}.
apply
(
m
);
dead_code_elimination
{}.
apply
(
m
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment