Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
59a53257
"vscode:/vscode.git/clone" did not exist on "0a86e3690342cb1af94ea6a1ab229fa2e1300231"
Commit
59a53257
authored
Oct 18, 2023
by
Manupa Karunaratne
Browse files
* clang-format
parent
83b9164b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
136 additions
and
110 deletions
+136
-110
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+136
-110
No files found.
src/targets/gpu/fuse_mlir.cpp
View file @
59a53257
...
@@ -206,8 +206,8 @@ auto is_mlir_conv(mlir_mode mode)
...
@@ -206,8 +206,8 @@ auto is_mlir_conv(mlir_mode mode)
}
}
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
create_param_map_with_literals
(
module_ref
mm
,
const
module
*
pm
,
const
shape
&
shape
)
create_param_map_with_literals
(
module_ref
mm
,
const
module
*
pm
,
const
shape
&
shape
)
{
{
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
ins_map
;
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
ins_map
;
for
(
auto
ins
:
iterator_for
(
*
pm
))
for
(
auto
ins
:
iterator_for
(
*
pm
))
{
{
...
@@ -217,20 +217,24 @@ std::unordered_map<instruction_ref, instruction_ref>
...
@@ -217,20 +217,24 @@ std::unordered_map<instruction_ref, instruction_ref>
}
}
literal
r
=
ins
->
get_literal
();
literal
r
=
ins
->
get_literal
();
instruction_ref
literal
=
mm
->
add_literal
(
r
);
instruction_ref
literal
=
mm
->
add_literal
(
r
);
instruction_ref
mbcast
=
mm
->
add_instruction
(
instruction_ref
mbcast
=
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
shape
.
lens
()}}),
literal
);
mm
->
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
shape
.
lens
()}}),
literal
);
ins_map
[
ins
]
=
mbcast
;
ins_map
[
ins
]
=
mbcast
;
}
}
return
ins_map
;
return
ins_map
;
}
}
std
::
vector
<
instruction_ref
>
fold_pointwise_mod
(
instruction_ref
pm_ins
,
module_ref
parent_mod
,
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
ins_map
){
std
::
vector
<
instruction_ref
>
fold_pointwise_mod
(
instruction_ref
pm_ins
,
module_ref
parent_mod
,
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
ins_map
)
{
auto
*
pm
=
pm_ins
->
module_inputs
().
front
();
auto
*
pm
=
pm_ins
->
module_inputs
().
front
();
auto
names
=
pm
->
get_parameter_names
();
auto
names
=
pm
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
std
::
sort
(
names
.
begin
(),
names
.
end
());
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
param_map
=
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
param_map
=
create_param_map_with_literals
(
parent_mod
,
pm
,
pm_ins
->
get_shape
());
create_param_map_with_literals
(
parent_mod
,
pm
,
pm_ins
->
get_shape
());
std
::
transform
(
names
.
begin
(),
std
::
transform
(
names
.
begin
(),
names
.
end
(),
names
.
end
(),
pm_ins
->
inputs
().
begin
(),
pm_ins
->
inputs
().
begin
(),
std
::
inserter
(
param_map
,
param_map
.
end
()),
std
::
inserter
(
param_map
,
param_map
.
end
()),
...
@@ -310,8 +314,6 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
...
@@ -310,8 +314,6 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
return
false
;
return
false
;
}
}
struct
find_mlir_fused_ops
struct
find_mlir_fused_ops
{
{
mlir_mode
conv_mode
=
mlir_mode
::
none
;
mlir_mode
conv_mode
=
mlir_mode
::
none
;
...
@@ -415,7 +417,8 @@ struct find_mlir_standalone_attention_op
...
@@ -415,7 +417,8 @@ struct find_mlir_standalone_attention_op
auto
[
new_top_ins
,
top_inputs
]
=
fuse_input_ops_and_gemm_based_op
(
mm
,
top_ins
);
auto
[
new_top_ins
,
top_inputs
]
=
fuse_input_ops_and_gemm_based_op
(
mm
,
top_ins
);
inputs
.
insert
(
inputs
.
begin
(),
top_inputs
.
begin
(),
top_inputs
.
end
());
inputs
.
insert
(
inputs
.
begin
(),
top_inputs
.
begin
(),
top_inputs
.
end
());
ins_map
[
top_ins
]
=
new_top_ins
;
ins_map
[
top_ins
]
=
new_top_ins
;
if
(
r
.
instructions
.
find
(
"scale"
)
!=
r
.
instructions
.
end
()){
if
(
r
.
instructions
.
find
(
"scale"
)
!=
r
.
instructions
.
end
())
{
auto
scale_ins
=
r
.
instructions
[
"scale"
];
auto
scale_ins
=
r
.
instructions
[
"scale"
];
new_top_ins
=
fold_pointwise_mod
(
scale_ins
,
mm
,
ins_map
)[
0
];
new_top_ins
=
fold_pointwise_mod
(
scale_ins
,
mm
,
ins_map
)[
0
];
std
::
copy_if
(
scale_ins
->
inputs
().
begin
(),
std
::
copy_if
(
scale_ins
->
inputs
().
begin
(),
...
@@ -428,7 +431,8 @@ struct find_mlir_standalone_attention_op
...
@@ -428,7 +431,8 @@ struct find_mlir_standalone_attention_op
r
.
instructions
[
"gemm1"
]
->
inputs
().
end
(),
r
.
instructions
[
"gemm1"
]
->
inputs
().
end
(),
std
::
inserter
(
ins_map
,
ins_map
.
end
()),
std
::
inserter
(
ins_map
,
ins_map
.
end
()),
[
&
](
auto
old_ins
)
{
[
&
](
auto
old_ins
)
{
if
(
old_ins
==
r
.
instructions
[
"softmax"
]){
if
(
old_ins
==
r
.
instructions
[
"softmax"
])
{
return
std
::
make_pair
(
old_ins
,
softmax
);
return
std
::
make_pair
(
old_ins
,
softmax
);
}
}
inputs
.
push_back
(
old_ins
);
inputs
.
push_back
(
old_ins
);
...
@@ -441,7 +445,8 @@ struct find_mlir_standalone_attention_op
...
@@ -441,7 +445,8 @@ struct find_mlir_standalone_attention_op
ins_map
[
r
.
instructions
[
"gemm1"
]]
=
new_gemm1
;
ins_map
[
r
.
instructions
[
"gemm1"
]]
=
new_gemm1
;
auto
ins_to_replace
=
new_gemm1
;
auto
ins_to_replace
=
new_gemm1
;
auto
ins_to_be_replaced
=
r
.
instructions
[
"gemm1"
];
auto
ins_to_be_replaced
=
r
.
instructions
[
"gemm1"
];
if
(
r
.
instructions
.
find
(
"trailing_pm"
)
!=
r
.
instructions
.
end
()){
if
(
r
.
instructions
.
find
(
"trailing_pm"
)
!=
r
.
instructions
.
end
())
{
ins_to_replace
=
fold_pointwise_mod
(
r
.
instructions
[
"trailing_pm"
],
mm
,
ins_map
)[
0
];
ins_to_replace
=
fold_pointwise_mod
(
r
.
instructions
[
"trailing_pm"
],
mm
,
ins_map
)[
0
];
std
::
copy_if
(
r
.
instructions
[
"trailing_pm"
]
->
inputs
().
begin
(),
std
::
copy_if
(
r
.
instructions
[
"trailing_pm"
]
->
inputs
().
begin
(),
r
.
instructions
[
"trailing_pm"
]
->
inputs
().
end
(),
r
.
instructions
[
"trailing_pm"
]
->
inputs
().
end
(),
...
@@ -454,28 +459,42 @@ struct find_mlir_standalone_attention_op
...
@@ -454,28 +459,42 @@ struct find_mlir_standalone_attention_op
ins_to_be_replaced
,
mlir_op
{
new_gemm1
->
get_operator
()},
inputs
,
{
mm
});
ins_to_be_replaced
,
mlir_op
{
new_gemm1
->
get_operator
()},
inputs
,
{
mm
});
}
}
auto
matcher
()
const
{
auto
matcher
()
const
auto
match_softmax_input
=
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"dot"
).
bind
(
"gemm0"
),
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"dot"
).
bind
(
"gemm0"
))).
bind
(
"scale"
));
{
auto
is_mlir_attention
=
match
::
name
(
"dot"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"softmax"
)(
match_softmax_input
).
bind
(
"softmax"
))).
bind
(
"gemm1"
);
auto
match_softmax_input
=
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"dot"
).
bind
(
"gemm0"
),
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"dot"
).
bind
(
"gemm0"
)))
.
bind
(
"scale"
));
auto
is_mlir_attention
=
match
::
name
(
"dot"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"softmax"
)(
match_softmax_input
).
bind
(
"softmax"
)))
.
bind
(
"gemm1"
);
return
is_mlir_attention
;
return
is_mlir_attention
;
}
}
bool
check
(
const
match
::
matcher_result
&
r
)
const
{
bool
check
(
const
match
::
matcher_result
&
r
)
const
{
// We are only enabling attention
// We are only enabling attention
// in the highest enablement mode for now
// in the highest enablement mode for now
if
(
mode
!=
mlir_mode
::
all
){
if
(
mode
!=
mlir_mode
::
all
)
{
return
false
;
return
false
;
}
}
auto
gemm0
=
r
.
instructions
[
"gemm0"
];
auto
gemm0
=
r
.
instructions
[
"gemm0"
];
// Check the pointwise mod only contains a single mul
// Check the pointwise mod only contains a single mul
if
(
r
.
instructions
.
find
(
"scale"
)
!=
r
.
instructions
.
end
()){
if
(
r
.
instructions
.
find
(
"scale"
)
!=
r
.
instructions
.
end
())
{
auto
scale_pm
=
r
.
instructions
[
"scale"
];
auto
scale_pm
=
r
.
instructions
[
"scale"
];
bool
found_mul
=
false
;
bool
found_mul
=
false
;
for
(
const
auto
&
scale_ins
:
*
scale_pm
->
module_inputs
().
front
()){
for
(
const
auto
&
scale_ins
:
*
scale_pm
->
module_inputs
().
front
())
if
(
contains
({
"@param"
,
"@literal"
,
"@return"
},
scale_ins
.
name
())){
{
if
(
contains
({
"@param"
,
"@literal"
,
"@return"
},
scale_ins
.
name
()))
{
continue
;
continue
;
}
}
if
(
scale_ins
.
name
()
==
"mul"
&&
!
found_mul
){
if
(
scale_ins
.
name
()
==
"mul"
&&
!
found_mul
)
{
found_mul
=
true
;
found_mul
=
true
;
continue
;
continue
;
}
}
...
@@ -487,7 +506,8 @@ struct find_mlir_standalone_attention_op
...
@@ -487,7 +506,8 @@ struct find_mlir_standalone_attention_op
return
not
contains
(
return
not
contains
(
{
shape
::
type_t
::
float_type
,
shape
::
type_t
::
half_type
,
shape
::
type_t
::
int8_type
},
{
shape
::
type_t
::
float_type
,
shape
::
type_t
::
half_type
,
shape
::
type_t
::
int8_type
},
i
->
get_shape
().
type
());
i
->
get_shape
().
type
());
})){
}))
{
return
false
;
return
false
;
}
}
return
true
;
return
true
;
...
@@ -495,7 +515,8 @@ struct find_mlir_standalone_attention_op
...
@@ -495,7 +515,8 @@ struct find_mlir_standalone_attention_op
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
{
if
(
!
check
(
r
)){
if
(
!
check
(
r
))
{
return
;
return
;
}
}
rewrite
(
mpm
,
r
);
rewrite
(
mpm
,
r
);
...
@@ -504,13 +525,18 @@ struct find_mlir_standalone_attention_op
...
@@ -504,13 +525,18 @@ struct find_mlir_standalone_attention_op
struct
find_mlir_attention_fused_ops
:
public
find_mlir_standalone_attention_op
struct
find_mlir_attention_fused_ops
:
public
find_mlir_standalone_attention_op
{
{
auto
matcher
()
const
{
auto
matcher
()
const
{
auto
standalone_matcher
=
find_mlir_standalone_attention_op
::
matcher
();
auto
standalone_matcher
=
find_mlir_standalone_attention_op
::
matcher
();
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
standalone_matcher
).
bind
(
"trailing_pm"
));;
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
standalone_matcher
).
bind
(
"trailing_pm"
));
;
}
}
bool
check
(
const
match
::
matcher_result
&
r
)
const
{
bool
check
(
const
match
::
matcher_result
&
r
)
const
if
(
!
find_mlir_standalone_attention_op
::
check
(
r
)){
{
if
(
!
find_mlir_standalone_attention_op
::
check
(
r
))
{
return
false
;
return
false
;
}
}
auto
trailing_pm_ins
=
r
.
instructions
[
"trailing_pm"
];
// input after contiguous
auto
trailing_pm_ins
=
r
.
instructions
[
"trailing_pm"
];
// input after contiguous
...
@@ -525,7 +551,8 @@ struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op
...
@@ -525,7 +551,8 @@ struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
{
if
(
!
check
(
r
)){
if
(
!
check
(
r
))
{
return
;
return
;
}
}
rewrite
(
mpm
,
r
);
rewrite
(
mpm
,
r
);
...
@@ -575,8 +602,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
...
@@ -575,8 +602,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
(
enabled
(
MIGRAPHX_ENABLE_EXTRA_MLIR
{})
or
enable_extra
)
?
mlir_mode
::
fast
:
mlir_mode
::
none
;
(
enabled
(
MIGRAPHX_ENABLE_EXTRA_MLIR
{})
or
enable_extra
)
?
mlir_mode
::
fast
:
mlir_mode
::
none
;
// Attention offloads; default disabled
// Attention offloads; default disabled
match
::
find_matches
(
mpm
,
match
::
find_matches
(
mpm
,
find_mlir_attention_fused_ops
{
get_mode
(
"attention"
,
mlir_mode
::
none
)});
find_mlir_attention_fused_ops
{
get_mode
(
"attention"
,
mlir_mode
::
none
)});
match
::
find_matches
(
mpm
,
match
::
find_matches
(
mpm
,
find_mlir_standalone_attention_op
{
get_mode
(
"attention"
,
mlir_mode
::
none
)});
find_mlir_standalone_attention_op
{
get_mode
(
"attention"
,
mlir_mode
::
none
)});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment