Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b3af63ac
Commit
b3af63ac
authored
Aug 26, 2022
by
Paul
Browse files
Horizontally fuse gemms that share the same weights
parent
1704bb04
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
55 additions
and
0 deletions
+55
-0
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+55
-0
No files found.
src/simplify_algebra.cpp
View file @
b3af63ac
...
@@ -831,6 +831,60 @@ struct find_conv_dot_horiz_fusion
...
@@ -831,6 +831,60 @@ struct find_conv_dot_horiz_fusion
}
}
};
};
MIGRAPHX_PRED_MATCHER
(
horiz_dot_weights
,
instruction_ref
ins
)
{
auto
pred
=
[
&
](
auto
name
)
{
return
[
=
](
auto
i
)
{
return
i
->
name
()
==
name
and
i
->
inputs
().
back
()
==
ins
;
};
};
return
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"dot"
))
>
1
;
}
struct
find_dot_horiz_fusion_weights
{
auto
matcher
()
const
{
return
horiz_dot_weights
();
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
std
::
vector
<
instruction_ref
>
dots
;
std
::
copy_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
std
::
back_inserter
(
dots
),
[
&
](
auto
i
)
{
return
i
->
name
()
==
"dot"
and
i
->
inputs
().
back
()
==
ins
;
});
std
::
sort
(
dots
.
begin
(),
dots
.
end
(),
by
(
std
::
less
<>
{},
[
&
](
auto
i
)
{
return
std
::
distance
(
ins
,
i
);
}));
// Check if used between operators
const
bool
is_used
=
std
::
any_of
(
dots
.
front
(),
dots
.
back
(),
[
&
](
const
auto
&
i
)
{
return
std
::
any_of
(
i
.
inputs
().
begin
(),
i
.
inputs
().
end
(),
[
&
](
auto
input
)
{
return
contains
(
dots
,
input
);
});
});
if
(
is_used
)
return
;
std
::
vector
<
instruction_ref
>
args
;
std
::
transform
(
dots
.
begin
(),
dots
.
end
(),
std
::
back_inserter
(
args
),
[
&
](
auto
x
)
{
return
x
->
inputs
().
front
();
});
auto
axis
=
args
.
front
()
->
get_shape
().
lens
().
size
()
-
2
;
auto
last
=
dots
.
back
();
auto
weights
=
last
->
inputs
().
back
();
auto
concat
=
m
.
insert_instruction
(
last
,
make_op
(
"concat"
,
{{
"axis"
,
axis
}}),
args
);
auto
fused
=
m
.
insert_instruction
(
last
,
make_op
(
"dot"
),
concat
,
weights
);
int64_t
offset
=
0
;
for
(
auto
arg
:
dots
)
{
int64_t
len
=
arg
->
get_shape
().
lens
()[
axis
];
auto
slice
=
m
.
insert_instruction
(
last
,
make_op
(
"slice"
,
{{
"axes"
,
{
axis
}},
{
"starts"
,
{
offset
}},
{
"ends"
,
{
offset
+
len
}}}),
fused
);
m
.
replace_instruction
(
arg
,
slice
);
offset
+=
len
;
}
}
};
struct
find_div_const
struct
find_div_const
{
{
auto
matcher
()
const
auto
matcher
()
const
...
@@ -1045,6 +1099,7 @@ void simplify_algebra::apply(module& m) const
...
@@ -1045,6 +1099,7 @@ void simplify_algebra::apply(module& m) const
find_add_lit_broadcast
{},
find_add_lit_broadcast
{},
find_add_convs
{},
find_add_convs
{},
find_conv_dot_horiz_fusion
{},
find_conv_dot_horiz_fusion
{},
find_dot_horiz_fusion_weights
{},
find_mul_conv
{},
find_mul_conv
{},
find_mul_slice_conv
{},
find_mul_slice_conv
{},
find_mul_add
{},
find_mul_add
{},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment