Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ad6dcf26
Commit
ad6dcf26
authored
Nov 22, 2023
by
Manupa Karunaratne
Browse files
* add input fusion for v input
parent
f69d828d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
20 deletions
+36
-20
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+36
-20
No files found.
src/targets/gpu/fuse_mlir.cpp
View file @
ad6dcf26
...
@@ -64,14 +64,14 @@ struct mlir_op
...
@@ -64,14 +64,14 @@ struct mlir_op
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
const
{
{
module_ref
mod
=
mods
[
0
];
check_shapes
{
inputs
,
*
this
}.
packed_or_broadcasted
();
check_shapes
{
inputs
,
*
this
}.
packed_or_broadcasted
();
if
(
mods
.
size
()
!=
1
)
if
(
mods
.
size
()
!=
1
)
MIGRAPHX_THROW
(
"should have one submodule."
);
MIGRAPHX_THROW
(
"should have one submodule."
);
if
(
inputs
.
size
()
<
2
)
if
(
inputs
.
size
()
<
2
)
MIGRAPHX_THROW
(
"should have at least two inputs."
);
MIGRAPHX_THROW
(
"should have at least two inputs."
);
module_ref
mod
=
mods
[
0
];
auto
type
=
mod
->
get_output_shapes
().
front
().
type
();
auto
type
=
mod
->
get_output_shapes
().
front
().
type
();
std
::
unordered_map
<
instruction_ref
,
shape
>
ins_shapes
;
std
::
unordered_map
<
instruction_ref
,
shape
>
ins_shapes
;
for
(
auto
ins
:
iterator_for
(
*
mod
))
for
(
auto
ins
:
iterator_for
(
*
mod
))
{
{
...
@@ -101,6 +101,27 @@ struct mlir_op
...
@@ -101,6 +101,27 @@ struct mlir_op
MIGRAPHX_REGISTER_OP
(
mlir_op
);
MIGRAPHX_REGISTER_OP
(
mlir_op
);
namespace
{
namespace
{
std
::
tuple
<
instruction_ref
,
std
::
vector
<
operation
>>
get_fusable_input_op_stream
(
instruction_ref
lower_input
)
{
instruction_ref
upper_input
=
lower_input
;
std
::
vector
<
operation
>
op_stream
;
while
(
contains
({
"slice"
,
"transpose"
,
"contiguous"
,
"reshape"
,
"squeeze"
,
"flatten"
,
"unsqueeze"
},
upper_input
->
name
()))
{
operation
op
=
upper_input
->
get_operator
();
if
(
contains
({
"squeeze"
,
"flatten"
,
"unsqueeze"
},
upper_input
->
name
()))
{
op
=
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
upper_input
->
get_shape
().
lens
()}});
}
op_stream
.
push_back
(
op
);
upper_input
=
upper_input
->
inputs
().
at
(
0
);
}
return
{
upper_input
,
op_stream
};
}
std
::
tuple
<
instruction_ref
,
std
::
vector
<
instruction_ref
>>
std
::
tuple
<
instruction_ref
,
std
::
vector
<
instruction_ref
>>
fuse_input_ops_and_gemm_based_op
(
module_ref
mm
,
instruction_ref
gemm_based_op
)
fuse_input_ops_and_gemm_based_op
(
module_ref
mm
,
instruction_ref
gemm_based_op
)
{
{
...
@@ -109,22 +130,10 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
...
@@ -109,22 +130,10 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
size_t
input_cnt
=
0
;
size_t
input_cnt
=
0
;
for
(
instruction_ref
input
:
gemm_based_op
->
inputs
())
for
(
instruction_ref
input
:
gemm_based_op
->
inputs
())
{
{
std
::
vector
<
operation
>
op_stream
;
auto
[
upper_input
,
op_stream
]
=
get_fusable_input_op_stream
(
input
);
while
(
contains
(
top_inputs
.
push_back
(
upper_input
);
{
"slice"
,
"transpose"
,
"contiguous"
,
"reshape"
,
"squeeze"
,
"flatten"
,
"unsqueeze"
},
input
->
name
()))
{
operation
op
=
input
->
get_operator
();
if
(
contains
({
"squeeze"
,
"flatten"
,
"unsqueeze"
},
input
->
name
()))
{
op
=
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
input
->
get_shape
().
lens
()}});
}
op_stream
.
push_back
(
op
);
input
=
input
->
inputs
().
at
(
0
);
}
top_inputs
.
push_back
(
input
);
instruction_ref
prev_input
=
instruction_ref
prev_input
=
mm
->
add_parameter
(
"y"
+
std
::
to_string
(
input_cnt
++
),
input
->
get_shape
());
mm
->
add_parameter
(
"y"
+
std
::
to_string
(
input_cnt
++
),
upper_
input
->
get_shape
());
for
(
const
auto
&
op
:
reverse
(
op_stream
))
for
(
const
auto
&
op
:
reverse
(
op_stream
))
{
{
prev_input
=
mm
->
add_instruction
(
op
,
{
prev_input
});
prev_input
=
mm
->
add_instruction
(
op
,
{
prev_input
});
...
@@ -424,9 +433,16 @@ struct find_mlir_standalone_attention_op
...
@@ -424,9 +433,16 @@ struct find_mlir_standalone_attention_op
{
{
return
std
::
make_pair
(
old_ins
,
softmax
);
return
std
::
make_pair
(
old_ins
,
softmax
);
}
}
inputs
.
push_back
(
old_ins
);
auto
[
old_upper_ins
,
op_stream
]
=
get_fusable_input_op_stream
(
old_ins
);
return
std
::
make_pair
(
old_ins
,
instruction_ref
new_upper_ins
=
mm
->
add_parameter
(
"v"
,
old_ins
->
get_shape
()));
mm
->
add_parameter
(
"v"
,
old_upper_ins
->
get_shape
());
instruction_ref
prev_input
=
new_upper_ins
;
for
(
const
auto
&
op
:
reverse
(
op_stream
))
{
prev_input
=
mm
->
add_instruction
(
op
,
{
prev_input
});
}
inputs
.
push_back
(
old_upper_ins
);
return
std
::
make_pair
(
old_ins
,
prev_input
);
});
});
auto
gemm1_a
=
ins_map
[
r
.
instructions
[
"gemm1"
]
->
inputs
().
front
()];
auto
gemm1_a
=
ins_map
[
r
.
instructions
[
"gemm1"
]
->
inputs
().
front
()];
auto
gemm1_b
=
ins_map
[
r
.
instructions
[
"gemm1"
]
->
inputs
().
back
()];
auto
gemm1_b
=
ins_map
[
r
.
instructions
[
"gemm1"
]
->
inputs
().
back
()];
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment