Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7110eb0e
"...composable_kernel_rocm.git" did not exist on "a93d07c74aa6f87b9fd2cb0ec81b04d16498c7d2"
Commit
7110eb0e
authored
Apr 03, 2023
by
Paul
Browse files
Format
parent
49dc6d12
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
16 deletions
+23
-16
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+23
-16
No files found.
src/simplify_reshapes.cpp
View file @
7110eb0e
...
@@ -587,15 +587,15 @@ struct find_reshape_cont
...
@@ -587,15 +587,15 @@ struct find_reshape_cont
};
};
// match sequence of transpose --> contiguous --> reshaper_op
// match sequence of transpose --> contiguous --> reshaper_op
template
<
class
...
Ms
>
template
<
class
...
Ms
>
auto
match_transpose_contiguous_reshaper
(
Ms
...
ms
)
auto
match_transpose_contiguous_reshaper
(
Ms
...
ms
)
{
{
return
match
::
name
({
"reshape"
,
"squeeze"
,
"unsqueeze"
})(
return
match
::
name
({
"reshape"
,
"squeeze"
,
"unsqueeze"
})(
match
::
used_once
(),
match
::
used_once
(),
match
::
args
(
match
::
args
(
match
::
name
(
"contiguous"
)(
match
::
name
(
"contiguous"
)(
match
::
used_once
(),
match
::
used_once
(),
match
::
args
(
match
::
transpose_shape
(
ms
...).
bind
(
"trans_ins"
)))
match
::
args
(
match
::
transpose_shape
(
ms
...).
bind
(
"trans_ins"
)))
.
bind
(
"cont_ins"
)))
.
bind
(
"cont_ins"
)))
.
bind
(
"reshaper_ins"
);
.
bind
(
"reshaper_ins"
);
};
};
...
@@ -631,29 +631,36 @@ struct find_mul_add_transpose_contiguous_reshaper_gemm
...
@@ -631,29 +631,36 @@ struct find_mul_add_transpose_contiguous_reshaper_gemm
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
auto
pw
=
match
::
name
(
"mul"
,
"add"
)(
match
::
used_once
(),
match
::
either_arg
(
0
,
1
)(
match
::
is_constant
().
bind
(
"c"
),
match
::
any
().
bind
(
"x"
)));
auto
pw
=
match
::
name
(
"mul"
,
"add"
)(
return
match
::
name
(
"dot"
)(
match
::
either_arg
(
0
,
1
)(
match_transpose_contiguous_reshaper
(
pw
.
bind
(
"pointwise"
)),
match
::
is_constant
()));
match
::
used_once
(),
match
::
either_arg
(
0
,
1
)(
match
::
is_constant
().
bind
(
"c"
),
match
::
any
().
bind
(
"x"
)));
return
match
::
name
(
"dot"
)(
match
::
either_arg
(
0
,
1
)(
match_transpose_contiguous_reshaper
(
pw
.
bind
(
"pointwise"
)),
match
::
is_constant
()));
}
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
reshaper_ins
=
r
.
instructions
[
"reshaper_ins"
];
auto
reshaper_ins
=
r
.
instructions
[
"reshaper_ins"
];
auto
trans_ins
=
r
.
instructions
[
"trans_ins"
];
auto
trans_ins
=
r
.
instructions
[
"trans_ins"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
c_ins
=
r
.
instructions
[
"c"
];
auto
c_ins
=
r
.
instructions
[
"c"
];
auto
pw_ins
=
r
.
instructions
[
"pointwise"
];
auto
pw_ins
=
r
.
instructions
[
"pointwise"
];
auto
insert_reshapes
=
[
&
](
auto
x
)
{
auto
insert_reshapes
=
[
&
](
auto
x
)
{
auto
t
=
m
.
insert_instruction
(
ins
,
trans_ins
->
get_operator
(),
x
);
auto
t
=
m
.
insert_instruction
(
ins
,
trans_ins
->
get_operator
(),
x
);
auto
c
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
t
);
auto
c
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
t
);
return
m
.
insert_instruction
(
ins
,
reshaper_ins
->
get_operator
(),
c
);
return
m
.
insert_instruction
(
ins
,
reshaper_ins
->
get_operator
(),
c
);
};
};
if
(
x_ins
->
name
()
==
"mul"
)
if
(
x_ins
->
name
()
==
"mul"
)
{
{
x_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
{
insert_reshapes
(
x_ins
->
inputs
()[
0
]),
insert_reshapes
(
x_ins
->
inputs
()[
1
])});
x_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
{
insert_reshapes
(
x_ins
->
inputs
()[
0
]),
insert_reshapes
(
x_ins
->
inputs
()[
1
])});
}
}
auto
y_ins
=
m
.
insert_instruction
(
ins
,
pw_ins
->
get_operator
(),
{
x_ins
,
insert_reshapes
(
c_ins
)});
auto
y_ins
=
m
.
insert_instruction
(
ins
,
pw_ins
->
get_operator
(),
{
x_ins
,
insert_reshapes
(
c_ins
)});
m
.
replace_instruction
(
reshaper_ins
,
y_ins
);
m
.
replace_instruction
(
reshaper_ins
,
y_ins
);
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment