Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e213e87e
Commit
e213e87e
authored
Mar 10, 2023
by
Paul
Browse files
Format
parent
73f78d8d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
14 deletions
+16
-14
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+16
-14
No files found.
src/simplify_reshapes.cpp
View file @
e213e87e
...
@@ -799,25 +799,23 @@ struct find_transpose_slice
...
@@ -799,25 +799,23 @@ struct find_transpose_slice
struct
find_reshape_gemm
struct
find_reshape_gemm
{
{
auto
matcher
()
const
auto
matcher
()
const
{
return
match
::
name
(
"reshape"
)(
match
::
arg
(
0
)(
match
::
name
(
"dot"
)));
}
{
return
match
::
name
(
"reshape"
)(
match
::
arg
(
0
)(
match
::
name
(
"dot"
)));
}
static
bool
is_batched_unsqueeze
(
instruction_ref
ins
)
static
bool
is_batched_unsqueeze
(
instruction_ref
ins
)
{
{
auto
input
=
ins
->
inputs
().
front
()
->
get_shape
().
lens
();
auto
input
=
ins
->
inputs
().
front
()
->
get_shape
().
lens
();
auto
output
=
ins
->
get_shape
().
lens
();
auto
output
=
ins
->
get_shape
().
lens
();
if
(
output
.
size
()
<=
input
.
size
())
if
(
output
.
size
()
<=
input
.
size
())
return
false
;
return
false
;
if
(
not
std
::
equal
(
input
.
end
()
-
2
,
input
.
end
(),
output
.
end
()
-
2
,
output
.
end
()))
if
(
not
std
::
equal
(
input
.
end
()
-
2
,
input
.
end
(),
output
.
end
()
-
2
,
output
.
end
()))
return
false
;
return
false
;
return
true
;
return
true
;
}
}
static
operation
make_reshape
(
std
::
vector
<
std
::
size_t
>
batches
,
instruction_ref
ins
)
static
operation
make_reshape
(
std
::
vector
<
std
::
size_t
>
batches
,
instruction_ref
ins
)
{
{
batches
.
insert
(
batches
.
end
(),
ins
->
get_shape
().
lens
().
end
()
-
2
,
ins
->
get_shape
().
lens
().
end
());
batches
.
insert
(
batches
.
end
(),
ins
->
get_shape
().
lens
().
end
()
-
2
,
ins
->
get_shape
().
lens
().
end
());
return
make_op
(
"reshape"
,
{{
"dims"
,
batches
}});
return
make_op
(
"reshape"
,
{{
"dims"
,
batches
}});
}
}
...
@@ -827,14 +825,18 @@ struct find_reshape_gemm
...
@@ -827,14 +825,18 @@ struct find_reshape_gemm
auto
dot_ins
=
reshape_ins
->
inputs
().
front
();
auto
dot_ins
=
reshape_ins
->
inputs
().
front
();
// TODO: Put this in the matcher
// TODO: Put this in the matcher
if
(
not
is_batched_unsqueeze
(
reshape_ins
))
if
(
not
is_batched_unsqueeze
(
reshape_ins
))
return
;
return
;
std
::
vector
<
std
::
size_t
>
batches
;
std
::
vector
<
std
::
size_t
>
batches
;
std
::
copy
(
reshape_ins
->
get_shape
().
lens
().
begin
(),
reshape_ins
->
get_shape
().
lens
().
end
()
-
2
,
std
::
back_inserter
(
batches
));
std
::
copy
(
reshape_ins
->
get_shape
().
lens
().
begin
(),
reshape_ins
->
get_shape
().
lens
().
end
()
-
2
,
auto
input0
=
m
.
insert_instruction
(
dot_ins
,
make_reshape
(
batches
,
dot_ins
->
inputs
()[
0
]),
dot_ins
->
inputs
()[
0
]);
std
::
back_inserter
(
batches
));
auto
input1
=
m
.
insert_instruction
(
dot_ins
,
make_reshape
(
batches
,
dot_ins
->
inputs
()[
1
]),
dot_ins
->
inputs
()[
1
]);
auto
input0
=
m
.
insert_instruction
(
dot_ins
,
make_reshape
(
batches
,
dot_ins
->
inputs
()[
0
]),
dot_ins
->
inputs
()[
0
]);
auto
input1
=
m
.
insert_instruction
(
dot_ins
,
make_reshape
(
batches
,
dot_ins
->
inputs
()[
1
]),
dot_ins
->
inputs
()[
1
]);
m
.
replace_instruction
(
dot_ins
,
make_op
(
"dot"
),
input0
,
input1
);
m
.
replace_instruction
(
dot_ins
,
make_op
(
"dot"
),
input0
,
input1
);
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment