Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1a9924c0
Commit
1a9924c0
authored
Aug 23, 2022
by
turneram
Browse files
Formatting
parent
540e262f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
51 additions
and
43 deletions
+51
-43
src/rewrite_batched_gemms.cpp
src/rewrite_batched_gemms.cpp
+51
-43
No files found.
src/rewrite_batched_gemms.cpp
View file @
1a9924c0
...
...
@@ -42,95 +42,103 @@ void rewrite_batched_gemms::apply(module& m) const
{
if
(
ins
->
name
()
!=
"dot"
)
continue
;
//std::cout << "Rewrite Batched GEMMS" << std::endl;
//ins->debug_print();
//m.debug_print();
//return;
//
std::cout << "Rewrite Batched GEMMS" << std::endl;
//
ins->debug_print();
//
m.debug_print();
//
return;
auto
inputs
=
ins
->
inputs
();
auto
a_mat
=
inputs
.
front
();
auto
b_mat
=
inputs
.
at
(
1
);
//.back()?
auto
a_mat
=
inputs
.
front
();
auto
b_mat
=
inputs
.
at
(
1
);
//.back()?
auto
a_lens
=
a_mat
->
get_shape
().
lens
();
auto
b_lens
=
b_mat
->
get_shape
().
lens
();
if
(
a_lens
.
size
()
>
2
)
if
(
a_lens
.
size
()
>
2
)
{
auto
batch_size
=
std
::
accumulate
(
a_lens
.
rbegin
()
+
2
,
a_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
auto
reshape_a
=
m
.
insert_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
{
batch_size
*
a_lens
[
a_lens
.
size
()
-
2
],
a_lens
.
back
()}}}),
a_mat
);
//reshape_a->debug_print();
//std::cout << b_mat->get_operator().name() << std::endl;
auto
reshape_a
=
m
.
insert_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
{
batch_size
*
a_lens
[
a_lens
.
size
()
-
2
],
a_lens
.
back
()}}}),
a_mat
);
// reshape_a->debug_print();
// std::cout << b_mat->get_operator().name() << std::endl;
instruction_ref
unbc_b
;
if
(
b_mat
->
get_operator
().
name
()
==
"concat"
)
if
(
b_mat
->
get_operator
().
name
()
==
"concat"
)
{
auto
concat_inputs
=
b_mat
->
inputs
();
std
::
vector
<
instruction_ref
>
concat_lits
;
int
concat_axis
=
1
;
int
concat_axis
=
1
;
bool
return_early
=
false
;
for
(
auto
c
:
concat_inputs
)
for
(
auto
c
:
concat_inputs
)
{
if
(
c
->
get_operator
().
name
()
==
"contiguous"
)
if
(
c
->
get_operator
().
name
()
==
"contiguous"
)
c
=
c
->
inputs
().
front
();
//std::cout << c->get_operator().name() << ", " << c->get_shape() << ", " << c->get_shape().broadcasted() <<std::endl;
if
(
c
->
get_shape
().
broadcasted
())
// std::cout << c->get_operator().name() << ", " << c->get_shape() << ", " <<
// c->get_shape().broadcasted() <<std::endl;
if
(
c
->
get_shape
().
broadcasted
())
{
//std::cout << c->inputs().front()->get_operator() <<std::endl;
auto
lit
=
c
->
inputs
().
front
();
//
std::cout << c->inputs().front()->get_operator() <<std::endl;
auto
lit
=
c
->
inputs
().
front
();
auto
lit_dims
=
lit
->
get_shape
().
lens
().
size
();
if
(
lit_dims
>
2
)
if
(
lit_dims
>
2
)
return_early
=
true
;
concat_axis
=
lit_dims
-
1
;
concat_lits
.
push_back
(
lit
);
}
}
if
(
return_early
)
if
(
return_early
)
continue
;
unbc_b
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
concat_axis
}}),
concat_lits
);
unbc_b
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
concat_axis
}}),
concat_lits
);
}
else
if
(
b_mat
->
get_operator
().
name
()
==
"contiguous"
)
else
if
(
b_mat
->
get_operator
().
name
()
==
"contiguous"
)
{
//std::cout << "Contiguous B" <<std::endl;
//b_mat->debug_print();
//
std::cout << "Contiguous B" <<std::endl;
//
b_mat->debug_print();
auto
b_input
=
b_mat
->
inputs
().
front
();
//std::cout << b_input->get_operator().name() << ", " << b_input->get_shape().broadcasted() << ", " << b_input->can_eval() << std::endl;
if
(
b_input
->
get_shape
().
broadcasted
())
// std::cout << b_input->get_operator().name() << ", " <<
// b_input->get_shape().broadcasted() << ", " << b_input->can_eval() << std::endl;
if
(
b_input
->
get_shape
().
broadcasted
())
{
auto
lit
=
b_input
->
inputs
().
front
();
auto
lit
=
b_input
->
inputs
().
front
();
auto
lit_dims
=
lit
->
get_shape
().
lens
().
size
();
if
(
lit_dims
>
2
)
if
(
lit_dims
>
2
)
continue
;
unbc_b
=
lit
;
//unbc_b->debug_print();
//
unbc_b->debug_print();
}
else
continue
;
}
else
else
{
//std::cout << "Else" << std::endl;
//
std::cout << "Else" << std::endl;
continue
;
}
auto
new_dot
=
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
reshape_a
,
unbc_b
);
auto
new_dot
=
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
reshape_a
,
unbc_b
);
auto
out_lens
=
a_lens
;
out_lens
.
pop_back
();
out_lens
.
push_back
(
b_lens
.
back
());
//std::cout << std::next(ins)->get_operator().name() << std::endl;
//
std::cout << std::next(ins)->get_operator().name() << std::endl;
auto
next_ins
=
std
::
next
(
ins
);
if
(
next_ins
->
get_operator
().
name
()
==
"add"
)
if
(
next_ins
->
get_operator
().
name
()
==
"add"
)
{
auto
add_in
=
next_ins
->
inputs
().
back
()
==
ins
?
next_ins
->
inputs
().
front
()
:
next_ins
->
inputs
().
back
();
//add_in->debug_print();
auto
reshape_add
=
m
.
insert_instruction
(
next_ins
,
make_op
(
"reshape"
,
{{
"dims"
,
{
batch_size
*
a_lens
[
a_lens
.
size
()
-
2
],
b_lens
.
back
()}}}),
add_in
);
auto
add_in
=
next_ins
->
inputs
().
back
()
==
ins
?
next_ins
->
inputs
().
front
()
:
next_ins
->
inputs
().
back
();
// add_in->debug_print();
auto
reshape_add
=
m
.
insert_instruction
(
next_ins
,
make_op
(
"reshape"
,
{{
"dims"
,
{
batch_size
*
a_lens
[
a_lens
.
size
()
-
2
],
b_lens
.
back
()}}}),
add_in
);
new_dot
=
m
.
replace_instruction
(
next_ins
,
make_op
(
"add"
),
reshape_add
,
new_dot
);
}
//std::cout << "here" <<std::endl;
//
std::cout << "here" <<std::endl;
m
.
replace_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
out_lens
}}),
new_dot
);
}
//m.debug_print();
// m.debug_print();
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment