Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ea97ce52
Unverified
Commit
ea97ce52
authored
Sep 12, 2023
by
ravil-mobile
Committed by
GitHub
Sep 11, 2023
Browse files
[Re-Opened] Added support for standalone mlir-conv (used to be #2110) (#2142)
parent
72b691a1
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
111 additions
and
31 deletions
+111
-31
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+111
-31
No files found.
src/targets/gpu/fuse_mlir.cpp
View file @
ea97ce52
...
@@ -119,6 +119,33 @@ struct mlir_op
...
@@ -119,6 +119,33 @@ struct mlir_op
MIGRAPHX_REGISTER_OP
(
mlir_op
);
MIGRAPHX_REGISTER_OP
(
mlir_op
);
namespace
{
namespace
{
std
::
tuple
<
instruction_ref
,
std
::
vector
<
instruction_ref
>>
fuse_input_ops_and_gemm_based_op
(
module_ref
mm
,
instruction_ref
gemm_based_op
)
{
std
::
vector
<
instruction_ref
>
top_inputs
;
std
::
vector
<
instruction_ref
>
imm_inputs
;
size_t
input_cnt
=
0
;
for
(
instruction_ref
input
:
gemm_based_op
->
inputs
())
{
std
::
vector
<
operation
>
op_stream
;
while
(
contains
({
"slice"
,
"transpose"
,
"contiguous"
,
"reshape"
},
input
->
name
()))
{
op_stream
.
push_back
(
input
->
get_operator
());
input
=
input
->
inputs
().
at
(
0
);
}
top_inputs
.
push_back
(
input
);
instruction_ref
prev_input
=
mm
->
add_parameter
(
"y"
+
std
::
to_string
(
input_cnt
++
),
input
->
get_shape
());
for
(
const
auto
&
op
:
reverse
(
op_stream
))
{
prev_input
=
mm
->
add_instruction
(
op
,
{
prev_input
});
}
imm_inputs
.
push_back
(
prev_input
);
}
instruction_ref
new_gemm_based_op
=
mm
->
add_instruction
(
gemm_based_op
->
get_operator
(),
imm_inputs
);
return
{
new_gemm_based_op
,
top_inputs
};
}
MIGRAPHX_PRED_MATCHER
(
is_mlir_conv
,
instruction_ref
ins
)
MIGRAPHX_PRED_MATCHER
(
is_mlir_conv
,
instruction_ref
ins
)
{
{
...
@@ -134,7 +161,7 @@ MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
...
@@ -134,7 +161,7 @@ MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
return
true
;
return
true
;
}
}
struct
find_mlir_op
struct
find_mlir_
fused_
op
s
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
...
@@ -163,34 +190,6 @@ struct find_mlir_op
...
@@ -163,34 +190,6 @@ struct find_mlir_op
return
ins_map
;
return
ins_map
;
}
}
std
::
tuple
<
instruction_ref
,
std
::
vector
<
instruction_ref
>>
fuse_input_ops_and_gemm_based_op
(
module_ref
mm
,
instruction_ref
gemm_based_op
)
const
{
std
::
vector
<
instruction_ref
>
top_inputs
;
std
::
vector
<
instruction_ref
>
imm_inputs
;
size_t
input_cnt
=
0
;
for
(
instruction_ref
input
:
gemm_based_op
->
inputs
())
{
std
::
vector
<
operation
>
op_stream
;
while
(
contains
({
"slice"
,
"transpose"
,
"contiguous"
,
"reshape"
},
input
->
name
()))
{
op_stream
.
push_back
(
input
->
get_operator
());
input
=
input
->
inputs
().
at
(
0
);
}
top_inputs
.
push_back
(
input
);
instruction_ref
prev_input
=
mm
->
add_parameter
(
"y"
+
std
::
to_string
(
input_cnt
++
),
input
->
get_shape
());
for
(
const
auto
&
op
:
reverse
(
op_stream
))
{
prev_input
=
mm
->
add_instruction
(
op
,
{
prev_input
});
}
imm_inputs
.
push_back
(
prev_input
);
}
instruction_ref
new_gemm_based_op
=
mm
->
add_instruction
(
gemm_based_op
->
get_operator
(),
imm_inputs
);
return
{
new_gemm_based_op
,
top_inputs
};
}
// Whitelist supported fusion options, including imposing type constraints
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
// on particular types.
...
@@ -301,14 +300,95 @@ struct find_mlir_op
...
@@ -301,14 +300,95 @@ struct find_mlir_op
}
}
};
};
struct
find_mlir_standalone_convolution_op
{
auto
matcher
()
const
{
return
match
::
name
(
"convolution"
);
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
conv_based_op
=
r
.
result
;
// enable only for fp32/fp16/i8 types
if
(
std
::
any_of
(
conv_based_op
->
inputs
().
begin
(),
conv_based_op
->
inputs
().
end
(),
[
&
](
auto
i
)
{
return
not
contains
(
{
shape
::
type_t
::
float_type
,
shape
::
type_t
::
half_type
,
shape
::
type_t
::
int8_type
},
i
->
get_shape
().
type
());
}))
return
;
static
size_t
counter
=
0
;
module_ref
mm
=
mpm
.
create_module
(
"mlir_"
+
std
::
to_string
(
counter
++
));
mm
->
set_bypass
();
auto
[
anchor_op
,
top_inputs
]
=
fuse_input_ops_and_gemm_based_op
(
mm
,
conv_based_op
);
mm
->
add_return
({
anchor_op
});
mpm
.
get_module
().
replace_instruction
(
conv_based_op
,
mlir_op
{
conv_based_op
->
get_operator
()},
top_inputs
,
{
mm
});
}
};
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
* only specific MLIR operations.
*
* The variable, if defined, forces MIGraphX to use only specific operations
* with MLIR regardless of the underlying GPU architecture. The variable accepts
* a list of operations separated by comma. The variable recognizes the following
* operations: "fused", "convolution". If the variable is not defined MIGraphX
* will decide by itself which operations to delegate to MLIR. The variable is
* intended to be primarily used by rocMLIR developers.
*/
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
);
bool
is_self_decide
()
{
return
string_value_of
(
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
{},
""
).
empty
();
}
bool
is_requested
(
std
::
string_view
option
)
{
assert
(
not
is_self_decide
());
auto
string_value
=
string_value_of
(
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
{},
""
);
const
auto
options
=
split_string
(
string_value
,
','
);
return
contains
(
options
,
option
);
}
bool
is_fusion_enabled
()
{
if
(
is_self_decide
())
{
return
true
;
}
return
is_requested
(
"fused"
);
}
bool
is_standalone_convs_enabled
(
context
*
ctx
)
{
if
(
is_self_decide
())
{
if
(
ctx
==
nullptr
)
{
return
false
;
}
else
{
const
auto
&
device
=
ctx
->
get_current_device
();
const
std
::
string
navi_family
{
"gfx110"
};
return
starts_with
(
device
.
get_gfx_name
(),
navi_family
);
}
}
return
is_requested
(
"convolution"
);
}
}
// namespace
}
// namespace
#endif
#endif
// MIGRAPHX_MLIR
void
fuse_mlir
::
apply
(
module_pass_manager
&
mpm
)
const
void
fuse_mlir
::
apply
(
module_pass_manager
&
mpm
)
const
{
{
#ifdef MIGRAPHX_MLIR
#ifdef MIGRAPHX_MLIR
match
::
find_matches
(
mpm
,
find_mlir_op
{});
if
(
is_fusion_enabled
())
{
match
::
find_matches
(
mpm
,
find_mlir_fused_ops
{});
}
if
(
is_standalone_convs_enabled
(
this
->
ctx
))
{
match
::
find_matches
(
mpm
,
find_mlir_standalone_convolution_op
{});
}
#else
#else
(
void
)
mpm
;
(
void
)
mpm
;
#endif
#endif
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment