Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b3af63ac
"vscode:/vscode.git/clone" did not exist on "34493a8da4355192bee19795f8f9e408d3ee30a5"
Commit
b3af63ac
authored
Aug 26, 2022
by
Paul
Browse files
Horizontally fuse gemms that share the same weights
parent
1704bb04
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
55 additions
and
0 deletions
+55
-0
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+55
-0
No files found.
src/simplify_algebra.cpp
View file @
b3af63ac
...
@@ -831,6 +831,60 @@ struct find_conv_dot_horiz_fusion
...
@@ -831,6 +831,60 @@ struct find_conv_dot_horiz_fusion
}
}
};
};
MIGRAPHX_PRED_MATCHER
(
horiz_dot_weights
,
instruction_ref
ins
)
{
auto
pred
=
[
&
](
auto
name
)
{
return
[
=
](
auto
i
)
{
return
i
->
name
()
==
name
and
i
->
inputs
().
back
()
==
ins
;
};
};
return
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"dot"
))
>
1
;
}
struct
find_dot_horiz_fusion_weights
{
auto
matcher
()
const
{
return
horiz_dot_weights
();
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
std
::
vector
<
instruction_ref
>
dots
;
std
::
copy_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
std
::
back_inserter
(
dots
),
[
&
](
auto
i
)
{
return
i
->
name
()
==
"dot"
and
i
->
inputs
().
back
()
==
ins
;
});
std
::
sort
(
dots
.
begin
(),
dots
.
end
(),
by
(
std
::
less
<>
{},
[
&
](
auto
i
)
{
return
std
::
distance
(
ins
,
i
);
}));
// Check if used between operators
const
bool
is_used
=
std
::
any_of
(
dots
.
front
(),
dots
.
back
(),
[
&
](
const
auto
&
i
)
{
return
std
::
any_of
(
i
.
inputs
().
begin
(),
i
.
inputs
().
end
(),
[
&
](
auto
input
)
{
return
contains
(
dots
,
input
);
});
});
if
(
is_used
)
return
;
std
::
vector
<
instruction_ref
>
args
;
std
::
transform
(
dots
.
begin
(),
dots
.
end
(),
std
::
back_inserter
(
args
),
[
&
](
auto
x
)
{
return
x
->
inputs
().
front
();
});
auto
axis
=
args
.
front
()
->
get_shape
().
lens
().
size
()
-
2
;
auto
last
=
dots
.
back
();
auto
weights
=
last
->
inputs
().
back
();
auto
concat
=
m
.
insert_instruction
(
last
,
make_op
(
"concat"
,
{{
"axis"
,
axis
}}),
args
);
auto
fused
=
m
.
insert_instruction
(
last
,
make_op
(
"dot"
),
concat
,
weights
);
int64_t
offset
=
0
;
for
(
auto
arg
:
dots
)
{
int64_t
len
=
arg
->
get_shape
().
lens
()[
axis
];
auto
slice
=
m
.
insert_instruction
(
last
,
make_op
(
"slice"
,
{{
"axes"
,
{
axis
}},
{
"starts"
,
{
offset
}},
{
"ends"
,
{
offset
+
len
}}}),
fused
);
m
.
replace_instruction
(
arg
,
slice
);
offset
+=
len
;
}
}
};
struct
find_div_const
struct
find_div_const
{
{
auto
matcher
()
const
auto
matcher
()
const
...
@@ -1045,6 +1099,7 @@ void simplify_algebra::apply(module& m) const
...
@@ -1045,6 +1099,7 @@ void simplify_algebra::apply(module& m) const
find_add_lit_broadcast
{},
find_add_lit_broadcast
{},
find_add_convs
{},
find_add_convs
{},
find_conv_dot_horiz_fusion
{},
find_conv_dot_horiz_fusion
{},
find_dot_horiz_fusion_weights
{},
find_mul_conv
{},
find_mul_conv
{},
find_mul_slice_conv
{},
find_mul_slice_conv
{},
find_mul_add
{},
find_mul_add
{},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment