gaoqiong / MIGraphX / Commits

Commit 25d6b2e2
Authored Oct 12, 2023 by Manupa Karunaratne
Parent: a3cf9951

[WIP] factoring out pointwise folding

* some skeleton code to handle attention patterns
Showing 1 changed file with 199 additions and 35 deletions.

src/targets/gpu/fuse_mlir.cpp (+199, -35)
@@ -164,18 +164,8 @@ MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
     return true;
 }
 
-struct find_mlir_fused_ops
+std::unordered_map<instruction_ref, instruction_ref>
+create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape)
 {
-    auto matcher() const
-    {
-        auto dot_or_conv = match::skip(match::name("contiguous"))(
-            match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv())
-                .bind("gemm_based_op"));
-        return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
-    }
-
-    std::unordered_map<instruction_ref, instruction_ref>
-    create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape) const
-    {
-        std::unordered_map<instruction_ref, instruction_ref> ins_map;
-        for(auto ins : iterator_for(*pm))
+    std::unordered_map<instruction_ref, instruction_ref> ins_map;
+    for(auto ins : iterator_for(*pm))
@@ -193,6 +183,37 @@ struct find_mlir_fused_ops
     return ins_map;
 }
 
+std::vector<instruction_ref>
+fold_pointwise_mod(instruction_ref pm_ins,
+                   module_ref parent_mod,
+                   const std::unordered_map<instruction_ref, instruction_ref>& ins_map)
+{
+    auto* pm   = pm_ins->module_inputs().front();
+    auto names = pm->get_parameter_names();
+    std::sort(names.begin(), names.end());
+    std::unordered_map<instruction_ref, instruction_ref> param_map =
+        create_param_map_with_literals(parent_mod, pm, pm_ins->get_shape());
+    std::transform(names.begin(),
+                   names.end(),
+                   pm_ins->inputs().begin(),
+                   std::inserter(param_map, param_map.end()),
+                   [&](auto name, auto input) {
+                       if(ins_map.count(input))
+                           return std::make_pair(pm->get_parameter(name), ins_map.at(input));
+                       return std::make_pair(pm->get_parameter(name),
+                                             parent_mod->add_parameter(name, input->get_shape()));
+                   });
+    return parent_mod->insert_instructions(parent_mod->end(), pm, param_map);
+}
+
+struct find_mlir_fused_ops
+{
+    auto matcher() const
+    {
+        auto dot_or_conv = match::skip(match::name("contiguous"))(
+            match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv())
+                .bind("gemm_based_op"));
+        return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
+    }
+
     // Whitelist supported fusion options, including imposing type constraints
     // for cases where MLIR only supports an operation (usually a pointwise function)
     // on particular types.
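Aside: the core mechanism in fold_pointwise_mod is the std::transform/std::inserter idiom that zips the pointwise module's sorted parameter names with the fused instruction's inputs. A minimal, self-contained sketch of that idiom with plain STL stand-ins (instruction_ref modeled as an int; every name here is hypothetical):

#include <algorithm>
#include <iostream>
#include <iterator>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

int main()
{
    using value_id = int; // stand-in for instruction_ref

    std::vector<std::string> names = {"x0", "x1"}; // sorted parameter names
    std::vector<value_id> inputs   = {10, 20};     // pm_ins->inputs(), in the same order
    std::unordered_map<value_id, value_id> ins_map = {{10, 99}}; // 10 was already fused as 99

    std::map<std::string, value_id> param_map;
    std::transform(names.begin(),
                   names.end(),
                   inputs.begin(), // zipped pairwise with the names
                   std::inserter(param_map, param_map.end()),
                   [&](const std::string& name, value_id input) {
                       if(ins_map.count(input) > 0)
                           return std::make_pair(name, ins_map.at(input)); // reuse the fused value
                       return std::make_pair(name, input); // otherwise keep the original input
                   });

    for(const auto& [name, v] : param_map)
        std::cout << name << " -> " << v << "\n"; // prints: x0 -> 99, x1 -> 20
}

In the real pass the fallback branch calls parent_mod->add_parameter(name, input->get_shape()) instead of passing the id through, hoisting the input as a parameter of the target module.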
@@ -260,7 +281,7 @@ struct find_mlir_fused_ops
         return false;
     }
 
-    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+    void rewrite(module_pass_manager& mpm, const match::matcher_result& r) const
     {
         auto ins           = r.result;
         auto gemm_based_op = r.instructions["gemm_based_op"];
@@ -276,20 +297,21 @@ struct find_mlir_fused_ops
         std::sort(names.begin(), names.end());
         module_ref mm = mpm.create_module("mlir_" + pm->name());
         mm->set_bypass();
-        std::unordered_map<instruction_ref, instruction_ref> param_map =
-            create_param_map_with_literals(mm, pm, gemm_based_op->get_shape());
+        // std::unordered_map<instruction_ref, instruction_ref> param_map =
+        //     create_param_map_with_literals(mm, pm, gemm_based_op->get_shape());
         auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, gemm_based_op);
-        std::transform(names.begin(),
-                       names.end(),
-                       ins->inputs().begin(),
-                       std::inserter(param_map, param_map.end()),
-                       [&, &anchor = anchor_op](auto name, auto input) {
-                           if(input == x_ins)
-                               return std::make_pair(pm->get_parameter(name), anchor);
-                           return std::make_pair(pm->get_parameter(name),
-                                                 mm->add_parameter(name, input->get_shape()));
-                       });
-        mm->add_return(mm->insert_instructions(mm->end(), pm, param_map));
+        // std::transform(names.begin(),
+        //                names.end(),
+        //                ins->inputs().begin(),
+        //                std::inserter(param_map, param_map.end()),
+        //                [&, &anchor = anchor_op](auto name, auto input) {
+        //                    if(input == x_ins)
+        //                        return std::make_pair(pm->get_parameter(name), anchor);
+        //                    return std::make_pair(pm->get_parameter(name),
+        //                                          mm->add_parameter(name, input->get_shape()));
+        //                });
+        // mm->add_return(mm->insert_instructions(mm->end(), pm, param_map));
+        mm->add_return(fold_pointwise_mod(ins, mm, {{x_ins, anchor_op}}));
         std::vector<instruction_ref> inputs;
         std::copy_if(ins->inputs().begin(),
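Net effect of this hunk: the fused-ops rewrite no longer assembles its parameter map inline. The commented-out transform is subsumed by the single fold_pointwise_mod call, where the one-entry map {{x_ins, anchor_op}} redirects the matched input to the anchor already materialized inside the MLIR module; every other pointwise input is hoisted as a fresh parameter of that module.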
@@ -300,10 +322,70 @@ struct find_mlir_fused_ops
         mpm.get_module().replace_instruction(
             ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm});
     }
+
+    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+        auto* pm = ins->module_inputs().front();
+        // Whitelist pointwise operators.
+        if(std::any_of(pm->begin(), pm->end(), [&](const auto& i) {
+               return not is_pointwise_op_supported_by_mlir(i);
+           }))
+            return;
+        rewrite(mpm, r);
+    }
+};
+
+struct find_mlir_attention_fused_ops : public find_mlir_fused_ops
+{
+    auto matcher() const
+    {
+        auto match_softmax_input = match::any_of[match::inputs()](
+            match::name("dot"),
+            match::name("pointwise")(match::any_of[match::inputs()](match::name("dot")))
+                .bind("scale"));
+        auto is_mlir_attention =
+            match::name("dot")(match::any_of[match::inputs()](match::name("softmax")));
+        return match::name("pointwise")(
+            match::any_of[match::inputs()](is_mlir_attention.bind("x")));
+    }
+
+    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+        auto* pm = ins->module_inputs().front();
+        // Check the pointwise mod only contains a single mul
+        if(r.instructions.find("scale") != r.instructions.end())
+        {
+            auto scale_pm  = r.instructions["scale"];
+            bool found_mul = false;
+            for(const auto& scale_ins : *scale_pm->module_inputs().front())
+            {
+                if(contains({"@param", "@literal", "@return"}, scale_ins.name()))
+                {
+                    continue;
+                }
+                if(scale_ins.name() == "mul" && !found_mul)
+                {
+                    found_mul = true;
+                    continue;
+                }
+                return;
+            }
+        }
+        // Whitelist pointwise operators.
+        if(std::any_of(pm->begin(), pm->end(), [&](const auto& i) {
+               return not is_pointwise_op_supported_by_mlir(i);
+           }))
+            return;
+        rewrite(mpm, r);
+    }
+};
 
 struct find_mlir_standalone_op
 {
+    void rewrite(module_pass_manager& mpm, instruction_ref top_ins) const
+    {
+        static size_t counter = 0;
+        module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
+        mm->set_bypass();
+        auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, top_ins);
+        mm->add_return({anchor_op});
+        mpm.get_module().replace_instruction(
+            top_ins, mlir_op{top_ins->get_operator()}, top_inputs, {mm});
+    }
+
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
         auto conv_based_op = r.result;
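The single-mul check above walks the captured "scale" pointwise module and bails out on anything other than bookkeeping instructions plus at most one mul. A standalone sketch of that scan (hypothetical harness; instructions modeled as op-name strings):

#include <algorithm>
#include <initializer_list>
#include <iostream>
#include <string>
#include <vector>

static bool contains(std::initializer_list<std::string> names, const std::string& name)
{
    return std::find(names.begin(), names.end(), name) != names.end();
}

static bool scale_is_single_mul(const std::vector<std::string>& ops)
{
    bool found_mul = false;
    for(const auto& op : ops)
    {
        if(contains({"@param", "@literal", "@return"}, op))
            continue; // bookkeeping instructions are ignored
        if(op == "mul" && !found_mul)
        {
            found_mul = true; // the first mul is allowed
            continue;
        }
        return false; // a second mul, or any other op, disqualifies the module
    }
    return true; // note: like the pass, a module with no mul at all also passes
}

int main()
{
    std::cout << scale_is_single_mul({"@param", "@literal", "mul", "@return"}) << "\n"; // 1
    std::cout << scale_is_single_mul({"@param", "mul", "mul", "@return"}) << "\n";      // 0
}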
@@ -314,14 +396,7 @@ struct find_mlir_standalone_op
                            i->get_shape().type());
            }))
             return;
-        static size_t counter = 0;
-        module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
-        mm->set_bypass();
-        auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, conv_based_op);
-        mm->add_return({anchor_op});
-        mpm.get_module().replace_instruction(
-            conv_based_op, mlir_op{conv_based_op->get_operator()}, top_inputs, {mm});
+        rewrite(mpm, conv_based_op);
     }
 };
@@ -335,6 +410,93 @@ struct find_mlir_standalone_dot_op : find_mlir_standalone_op
     auto matcher() const
     {
         return match::any_of(match::name("dot"), match::name("quant_dot"));
     }
 };
 
+struct find_mlir_standalone_attention_op : find_mlir_standalone_op
+{
+    void insert_to_map(std::unordered_map<instruction_ref, instruction_ref>& ins_map,
+                       instruction_ref old_ins,
+                       instruction_ref new_ins) const
+    {
+        if(ins_map.count(new_ins))
+        {
+            new_ins = ins_map[new_ins];
+        }
+        if(!ins_map.count(old_ins))
+        {
+            ins_map[old_ins] = new_ins;
+        }
+    }
+
+    instruction_ref get_from_map(std::unordered_map<instruction_ref, instruction_ref>& ins_map,
+                                 instruction_ref ins) const
+    {
+        if(ins_map.count(ins))
+        {
+            return ins_map[ins];
+        }
+        return ins;
+    }
+
+    void rewrite(module_pass_manager& mpm, const match::matcher_result& r) const
+    {
+        static size_t counter = 0;
+        module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
+        std::vector<instruction_ref> inputs;
+        mm->set_bypass();
+        std::unordered_map<instruction_ref, instruction_ref> ins_map;
+        auto top_ins                   = r.instructions["top_dot"];
+        auto [new_top_ins, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, top_ins);
+        insert_to_map(ins_map, top_ins, new_top_ins);
+        if(r.instructions.find("scale") != r.instructions.end())
+        {
+            auto scale_ins = r.instructions["scale"];
+            new_top_ins    = fold_pointwise_mod(scale_ins, mm, ins_map)[0];
+            std::copy_if(scale_ins->inputs().begin(),
+                         scale_ins->inputs().end(),
+                         std::back_inserter(inputs),
+                         [&](auto input) { return input != top_ins; });
+        }
+        auto softmax =
+            mm->add_instruction(r.instructions["softmax"]->get_operator(), new_top_ins);
+        insert_to_map(ins_map, r.instructions["softmax"], softmax);
+        auto bottom_dot_a = get_from_map(ins_map, r.instructions["bottom_dot"]->inputs().front());
+        auto bottom_dot_b = get_from_map(ins_map, r.instructions["bottom_dot"]->inputs().back());
+        auto new_bottom_dot = mm->add_instruction(make_op("dot"), {bottom_dot_a, bottom_dot_b});
+        mm->add_return({new_bottom_dot});
+        inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
+        mpm.get_module().replace_instruction(
+            top_ins, mlir_op{new_bottom_dot->get_operator()}, inputs, {mm});
+    }
+
+    auto matcher() const
+    {
+        auto match_softmax_input = match::any_of[match::inputs()](
+            match::name("dot").bind("top_dot"),
+            match::name("pointwise")(
+                match::any_of[match::inputs()](match::name("dot").bind("top_dot")))
+                .bind("scale"));
+        auto is_mlir_attention =
+            match::name("dot")(
+                match::any_of[match::inputs()](match::name("softmax").bind("softmax")))
+                .bind("bottom_dot");
+        return is_mlir_attention;
+    }
+
+    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+    {
+        auto top_dot = r.instructions["top_dot"];
+        // Check the pointwise mod only contains a single mul
+        std::cerr << "standalone attention found!\n";
+        if(r.instructions.find("scale") != r.instructions.end())
+        {
+            auto scale_pm  = r.instructions["scale"];
+            bool found_mul = false;
+            for(const auto& scale_ins : *scale_pm->module_inputs().front())
+            {
+                if(contains({"@param", "@literal", "@return"}, scale_ins.name()))
+                {
+                    continue;
+                }
+                if(scale_ins.name() == "mul" && !found_mul)
+                {
+                    found_mul = true;
+                    continue;
+                }
+                std::cerr << "standalone attention scale not compatible!\n";
+                return;
+            }
+        }
+        // enable only for fp32/fp16/i8 types
+        if(std::any_of(top_dot->inputs().begin(), top_dot->inputs().end(), [&](auto i) {
+               return not contains({shape::type_t::float_type,
+                                    shape::type_t::half_type,
+                                    shape::type_t::int8_type},
+                                   i->get_shape().type());
+           }))
+        {
+            std::cerr << "standalone attention dtype not compatible!\n";
+            return;
+        }
+        rewrite(mpm, r);
+    }
+};
+
 /**
  * @brief Declares a new MIGraphX environment variable which forces to generate
  * only specific MLIR operations.
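For orientation, the subgraph this standalone matcher hunts for is a dot fed by a softmax over another dot, with an optional single-mul scale in between. A sketch of such a graph built with the same module-construction calls the pass itself uses; the shapes, parameter names, and softmax axis are illustrative assumptions, and the optional scale stage is omitted:

// Hypothetical example graph: dot -> softmax -> dot.
module m{"attention_example"};
auto q  = m.add_parameter("q", shape{shape::half_type, {1, 12, 256, 64}});
auto kt = m.add_parameter("kt", shape{shape::half_type, {1, 12, 64, 256}});
auto v  = m.add_parameter("v", shape{shape::half_type, {1, 12, 256, 64}});
auto top_dot = m.add_instruction(make_op("dot"), q, kt);                   // binds as "top_dot"
auto sm      = m.add_instruction(make_op("softmax", {{"axis", 3}}), top_dot); // binds as "softmax"
auto out     = m.add_instruction(make_op("dot"), sm, v);                   // binds as "bottom_dot"
m.add_return({out});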
@@ -393,9 +555,11 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
 {
 #ifdef MIGRAPHX_MLIR
     if(is_enabled("fused", this->ctx))
     {
+        match::find_matches(mpm, find_mlir_attention_fused_ops{});
         match::find_matches(mpm, find_mlir_fused_ops{});
     }
+    match::find_matches(mpm, find_mlir_standalone_attention_op{});
     if(is_enabled("convolution", this->ctx))
     {
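Registration order matters here: find_mlir_attention_fused_ops runs ahead of the generic find_mlir_fused_ops inside the "fused" gate, so attention-shaped pointwise chains are claimed by the more specific matcher before the generic one can absorb them, and (at this WIP stage) find_mlir_standalone_attention_op is registered unconditionally after that gate.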