Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
59a53257
Commit
59a53257
authored
Oct 18, 2023
by
Manupa Karunaratne
Browse files
* clang-format
parent
83b9164b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
136 additions
and
110 deletions
+136
-110
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+136
-110
No files found.
src/targets/gpu/fuse_mlir.cpp
View file @
59a53257
...
...
@@ -206,40 +206,44 @@ auto is_mlir_conv(mlir_mode mode)
}
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
create_param_map_with_literals
(
module_ref
mm
,
const
module
*
pm
,
const
shape
&
shape
)
create_param_map_with_literals
(
module_ref
mm
,
const
module
*
pm
,
const
shape
&
shape
)
{
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
ins_map
;
for
(
auto
ins
:
iterator_for
(
*
pm
))
{
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
ins_map
;
for
(
auto
ins
:
iterator_for
(
*
pm
))
if
(
ins
->
name
()
!=
"@literal"
)
{
if
(
ins
->
name
()
!=
"@literal"
)
{
continue
;
}
literal
r
=
ins
->
get_literal
();
instruction_ref
literal
=
mm
->
add_literal
(
r
);
instruction_ref
mbcast
=
mm
->
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
shape
.
lens
()}}),
literal
);
ins_map
[
ins
]
=
mbcast
;
continue
;
}
return
ins_map
;
literal
r
=
ins
->
get_literal
();
instruction_ref
literal
=
mm
->
add_literal
(
r
);
instruction_ref
mbcast
=
mm
->
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
shape
.
lens
()}}),
literal
);
ins_map
[
ins
]
=
mbcast
;
}
return
ins_map
;
}
std
::
vector
<
instruction_ref
>
fold_pointwise_mod
(
instruction_ref
pm_ins
,
module_ref
parent_mod
,
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
ins_map
){
auto
*
pm
=
pm_ins
->
module_inputs
().
front
();
auto
names
=
pm
->
get_parameter_names
();
std
::
vector
<
instruction_ref
>
fold_pointwise_mod
(
instruction_ref
pm_ins
,
module_ref
parent_mod
,
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
ins_map
)
{
auto
*
pm
=
pm_ins
->
module_inputs
().
front
();
auto
names
=
pm
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
param_map
=
create_param_map_with_literals
(
parent_mod
,
pm
,
pm_ins
->
get_shape
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
pm_ins
->
inputs
().
begin
(),
std
::
inserter
(
param_map
,
param_map
.
end
()),
[
&
](
auto
name
,
auto
input
)
{
if
(
ins_map
.
count
(
input
))
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
ins_map
.
at
(
input
));
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
parent_mod
->
add_parameter
(
name
,
input
->
get_shape
()));
});
create_param_map_with_literals
(
parent_mod
,
pm
,
pm_ins
->
get_shape
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
pm_ins
->
inputs
().
begin
(),
std
::
inserter
(
param_map
,
param_map
.
end
()),
[
&
](
auto
name
,
auto
input
)
{
if
(
ins_map
.
count
(
input
))
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
ins_map
.
at
(
input
));
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
parent_mod
->
add_parameter
(
name
,
input
->
get_shape
()));
});
return
parent_mod
->
insert_instructions
(
parent_mod
->
end
(),
pm
,
param_map
);
}
...
...
@@ -252,10 +256,10 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
const
auto
&
name
=
i
.
name
();
const
auto
result_type
=
i
.
get_shape
().
type
();
const
std
::
initializer_list
<
type_t
>
allowed_types
=
{
type_t
::
float_type
,
type_t
::
half_type
,
type_t
::
int8_type
,
type_t
::
int32_type
,
type_t
::
bool_type
};
type_t
::
half_type
,
type_t
::
int8_type
,
type_t
::
int32_type
,
type_t
::
bool_type
};
// Preliminary type check.
if
(
not
contains
(
allowed_types
,
result_type
))
{
...
...
@@ -310,8 +314,6 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
return
false
;
}
struct
find_mlir_fused_ops
{
mlir_mode
conv_mode
=
mlir_mode
::
none
;
...
...
@@ -354,8 +356,8 @@ struct find_mlir_fused_ops
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
ins
=
r
.
result
;
auto
*
pm
=
ins
->
module_inputs
().
front
();
// Whitelist pointwise operators.
if
(
std
::
any_of
(
pm
->
begin
(),
pm
->
end
(),
[
&
](
const
auto
&
i
)
{
return
not
is_pointwise_op_supported_by_mlir
(
i
);
...
...
@@ -373,13 +375,13 @@ struct find_mlir_standalone_op
void
rewrite
(
module_pass_manager
&
mpm
,
instruction_ref
top_ins
)
const
{
static
size_t
counter
=
0
;
module_ref
mm
=
mpm
.
create_module
(
"mlir_"
+
std
::
to_string
(
counter
++
));
mm
->
set_bypass
();
auto
[
anchor_op
,
top_inputs
]
=
fuse_input_ops_and_gemm_based_op
(
mm
,
top_ins
);
mm
->
add_return
({
anchor_op
});
mpm
.
get_module
().
replace_instruction
(
top_ins
,
mlir_op
{
top_ins
->
get_operator
()},
top_inputs
,
{
mm
});
static
size_t
counter
=
0
;
module_ref
mm
=
mpm
.
create_module
(
"mlir_"
+
std
::
to_string
(
counter
++
));
mm
->
set_bypass
();
auto
[
anchor_op
,
top_inputs
]
=
fuse_input_ops_and_gemm_based_op
(
mm
,
top_ins
);
mm
->
add_return
({
anchor_op
});
mpm
.
get_module
().
replace_instruction
(
top_ins
,
mlir_op
{
top_ins
->
get_operator
()},
top_inputs
,
{
mm
});
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
...
...
@@ -405,77 +407,94 @@ struct find_mlir_standalone_attention_op
mlir_mode
mode
=
mlir_mode
::
none
;
void
rewrite
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
static
size_t
counter
=
0
;
module_ref
mm
=
mpm
.
create_module
(
"mlir_"
+
std
::
to_string
(
counter
++
));
std
::
vector
<
instruction_ref
>
inputs
;
mm
->
set_bypass
();
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
ins_map
;
auto
top_ins
=
r
.
instructions
[
"gemm0"
];
auto
[
new_top_ins
,
top_inputs
]
=
fuse_input_ops_and_gemm_based_op
(
mm
,
top_ins
);
inputs
.
insert
(
inputs
.
begin
(),
top_inputs
.
begin
(),
top_inputs
.
end
());
ins_map
[
top_ins
]
=
new_top_ins
;
if
(
r
.
instructions
.
find
(
"scale"
)
!=
r
.
instructions
.
end
()){
auto
scale_ins
=
r
.
instructions
[
"scale"
];
new_top_ins
=
fold_pointwise_mod
(
scale_ins
,
mm
,
ins_map
)[
0
];
std
::
copy_if
(
scale_ins
->
inputs
().
begin
(),
scale_ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
input
)
{
return
input
!=
top_ins
;
});
}
auto
softmax
=
mm
->
add_instruction
(
r
.
instructions
[
"softmax"
]
->
get_operator
(),
new_top_ins
);
std
::
transform
(
r
.
instructions
[
"gemm1"
]
->
inputs
().
begin
(),
r
.
instructions
[
"gemm1"
]
->
inputs
().
end
(),
std
::
inserter
(
ins_map
,
ins_map
.
end
()),
[
&
](
auto
old_ins
)
{
if
(
old_ins
==
r
.
instructions
[
"softmax"
]){
return
std
::
make_pair
(
old_ins
,
softmax
);
}
inputs
.
push_back
(
old_ins
);
return
std
::
make_pair
(
old_ins
,
mm
->
add_parameter
(
"v"
,
old_ins
->
get_shape
()));
});
auto
gemm1_a
=
ins_map
[
r
.
instructions
[
"gemm1"
]
->
inputs
().
front
()];
auto
gemm1_b
=
ins_map
[
r
.
instructions
[
"gemm1"
]
->
inputs
().
back
()];
auto
new_gemm1
=
mm
->
add_instruction
(
make_op
(
"dot"
),
{
gemm1_a
,
gemm1_b
});
ins_map
[
r
.
instructions
[
"gemm1"
]]
=
new_gemm1
;
auto
ins_to_replace
=
new_gemm1
;
auto
ins_to_be_replaced
=
r
.
instructions
[
"gemm1"
];
if
(
r
.
instructions
.
find
(
"trailing_pm"
)
!=
r
.
instructions
.
end
()){
ins_to_replace
=
fold_pointwise_mod
(
r
.
instructions
[
"trailing_pm"
],
mm
,
ins_map
)[
0
];
std
::
copy_if
(
r
.
instructions
[
"trailing_pm"
]
->
inputs
().
begin
(),
r
.
instructions
[
"trailing_pm"
]
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
input
)
{
return
input
!=
r
.
instructions
[
"gemm1"
];
});
ins_to_be_replaced
=
r
.
instructions
[
"trailing_pm"
];
}
mm
->
add_return
({
ins_to_replace
});
mpm
.
get_module
().
replace_instruction
(
ins_to_be_replaced
,
mlir_op
{
new_gemm1
->
get_operator
()},
inputs
,
{
mm
});
static
size_t
counter
=
0
;
module_ref
mm
=
mpm
.
create_module
(
"mlir_"
+
std
::
to_string
(
counter
++
));
std
::
vector
<
instruction_ref
>
inputs
;
mm
->
set_bypass
();
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
ins_map
;
auto
top_ins
=
r
.
instructions
[
"gemm0"
];
auto
[
new_top_ins
,
top_inputs
]
=
fuse_input_ops_and_gemm_based_op
(
mm
,
top_ins
);
inputs
.
insert
(
inputs
.
begin
(),
top_inputs
.
begin
(),
top_inputs
.
end
());
ins_map
[
top_ins
]
=
new_top_ins
;
if
(
r
.
instructions
.
find
(
"scale"
)
!=
r
.
instructions
.
end
())
{
auto
scale_ins
=
r
.
instructions
[
"scale"
];
new_top_ins
=
fold_pointwise_mod
(
scale_ins
,
mm
,
ins_map
)[
0
];
std
::
copy_if
(
scale_ins
->
inputs
().
begin
(),
scale_ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
input
)
{
return
input
!=
top_ins
;
});
}
auto
softmax
=
mm
->
add_instruction
(
r
.
instructions
[
"softmax"
]
->
get_operator
(),
new_top_ins
);
std
::
transform
(
r
.
instructions
[
"gemm1"
]
->
inputs
().
begin
(),
r
.
instructions
[
"gemm1"
]
->
inputs
().
end
(),
std
::
inserter
(
ins_map
,
ins_map
.
end
()),
[
&
](
auto
old_ins
)
{
if
(
old_ins
==
r
.
instructions
[
"softmax"
])
{
return
std
::
make_pair
(
old_ins
,
softmax
);
}
inputs
.
push_back
(
old_ins
);
return
std
::
make_pair
(
old_ins
,
mm
->
add_parameter
(
"v"
,
old_ins
->
get_shape
()));
});
auto
gemm1_a
=
ins_map
[
r
.
instructions
[
"gemm1"
]
->
inputs
().
front
()];
auto
gemm1_b
=
ins_map
[
r
.
instructions
[
"gemm1"
]
->
inputs
().
back
()];
auto
new_gemm1
=
mm
->
add_instruction
(
make_op
(
"dot"
),
{
gemm1_a
,
gemm1_b
});
ins_map
[
r
.
instructions
[
"gemm1"
]]
=
new_gemm1
;
auto
ins_to_replace
=
new_gemm1
;
auto
ins_to_be_replaced
=
r
.
instructions
[
"gemm1"
];
if
(
r
.
instructions
.
find
(
"trailing_pm"
)
!=
r
.
instructions
.
end
())
{
ins_to_replace
=
fold_pointwise_mod
(
r
.
instructions
[
"trailing_pm"
],
mm
,
ins_map
)[
0
];
std
::
copy_if
(
r
.
instructions
[
"trailing_pm"
]
->
inputs
().
begin
(),
r
.
instructions
[
"trailing_pm"
]
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
input
)
{
return
input
!=
r
.
instructions
[
"gemm1"
];
});
ins_to_be_replaced
=
r
.
instructions
[
"trailing_pm"
];
}
mm
->
add_return
({
ins_to_replace
});
mpm
.
get_module
().
replace_instruction
(
ins_to_be_replaced
,
mlir_op
{
new_gemm1
->
get_operator
()},
inputs
,
{
mm
});
}
auto
matcher
()
const
{
auto
match_softmax_input
=
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"dot"
).
bind
(
"gemm0"
),
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"dot"
).
bind
(
"gemm0"
))).
bind
(
"scale"
));
auto
is_mlir_attention
=
match
::
name
(
"dot"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"softmax"
)(
match_softmax_input
).
bind
(
"softmax"
))).
bind
(
"gemm1"
);
auto
matcher
()
const
{
auto
match_softmax_input
=
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"dot"
).
bind
(
"gemm0"
),
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"dot"
).
bind
(
"gemm0"
)))
.
bind
(
"scale"
));
auto
is_mlir_attention
=
match
::
name
(
"dot"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"softmax"
)(
match_softmax_input
).
bind
(
"softmax"
)))
.
bind
(
"gemm1"
);
return
is_mlir_attention
;
}
bool
check
(
const
match
::
matcher_result
&
r
)
const
{
bool
check
(
const
match
::
matcher_result
&
r
)
const
{
// We are only enabling attention
// in the highest enablement mode for now
if
(
mode
!=
mlir_mode
::
all
){
if
(
mode
!=
mlir_mode
::
all
)
{
return
false
;
}
auto
gemm0
=
r
.
instructions
[
"gemm0"
];
// Check the pointwise mod only contains a single mul
if
(
r
.
instructions
.
find
(
"scale"
)
!=
r
.
instructions
.
end
()){
auto
scale_pm
=
r
.
instructions
[
"scale"
];
if
(
r
.
instructions
.
find
(
"scale"
)
!=
r
.
instructions
.
end
())
{
auto
scale_pm
=
r
.
instructions
[
"scale"
];
bool
found_mul
=
false
;
for
(
const
auto
&
scale_ins
:
*
scale_pm
->
module_inputs
().
front
()){
if
(
contains
({
"@param"
,
"@literal"
,
"@return"
},
scale_ins
.
name
())){
for
(
const
auto
&
scale_ins
:
*
scale_pm
->
module_inputs
().
front
())
{
if
(
contains
({
"@param"
,
"@literal"
,
"@return"
},
scale_ins
.
name
()))
{
continue
;
}
if
(
scale_ins
.
name
()
==
"mul"
&&
!
found_mul
){
if
(
scale_ins
.
name
()
==
"mul"
&&
!
found_mul
)
{
found_mul
=
true
;
continue
;
}
...
...
@@ -487,7 +506,8 @@ struct find_mlir_standalone_attention_op
return
not
contains
(
{
shape
::
type_t
::
float_type
,
shape
::
type_t
::
half_type
,
shape
::
type_t
::
int8_type
},
i
->
get_shape
().
type
());
})){
}))
{
return
false
;
}
return
true
;
...
...
@@ -495,7 +515,8 @@ struct find_mlir_standalone_attention_op
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
if
(
!
check
(
r
)){
if
(
!
check
(
r
))
{
return
;
}
rewrite
(
mpm
,
r
);
...
...
@@ -504,17 +525,22 @@ struct find_mlir_standalone_attention_op
struct
find_mlir_attention_fused_ops
:
public
find_mlir_standalone_attention_op
{
auto
matcher
()
const
{
auto
matcher
()
const
{
auto
standalone_matcher
=
find_mlir_standalone_attention_op
::
matcher
();
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
standalone_matcher
).
bind
(
"trailing_pm"
));;
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
standalone_matcher
).
bind
(
"trailing_pm"
));
;
}
bool
check
(
const
match
::
matcher_result
&
r
)
const
{
if
(
!
find_mlir_standalone_attention_op
::
check
(
r
)){
bool
check
(
const
match
::
matcher_result
&
r
)
const
{
if
(
!
find_mlir_standalone_attention_op
::
check
(
r
))
{
return
false
;
}
auto
trailing_pm_ins
=
r
.
instructions
[
"trailing_pm"
];
// input after contiguous
auto
*
trailing_pm
=
trailing_pm_ins
->
module_inputs
().
front
();
auto
*
trailing_pm
=
trailing_pm_ins
->
module_inputs
().
front
();
// Whitelist pointwise operators.
if
(
std
::
any_of
(
trailing_pm
->
begin
(),
trailing_pm
->
end
(),
[
&
](
const
auto
&
i
)
{
return
not
is_pointwise_op_supported_by_mlir
(
i
);
...
...
@@ -525,7 +551,8 @@ struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
if
(
!
check
(
r
)){
if
(
!
check
(
r
))
{
return
;
}
rewrite
(
mpm
,
r
);
...
...
@@ -547,7 +574,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
bool
is_requested
(
std
::
string_view
option
,
bool
fallback
=
false
)
{
auto
string_value
=
string_value_of
(
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
{},
""
);
auto
string_value
=
string_value_of
(
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
{},
""
);
if
(
string_value
.
empty
())
return
fallback
;
const
auto
options
=
split_string
(
string_value
,
','
);
...
...
@@ -575,8 +602,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
(
enabled
(
MIGRAPHX_ENABLE_EXTRA_MLIR
{})
or
enable_extra
)
?
mlir_mode
::
fast
:
mlir_mode
::
none
;
// Attention offloads; default disabled
match
::
find_matches
(
mpm
,
find_mlir_attention_fused_ops
{
get_mode
(
"attention"
,
mlir_mode
::
none
)});
match
::
find_matches
(
mpm
,
find_mlir_attention_fused_ops
{
get_mode
(
"attention"
,
mlir_mode
::
none
)});
match
::
find_matches
(
mpm
,
find_mlir_standalone_attention_op
{
get_mode
(
"attention"
,
mlir_mode
::
none
)});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment