Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
aeb60bce
Commit
aeb60bce
authored
Jun 13, 2022
by
Paul
Browse files
Correctly add module
parent
2f268bc2
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
28 additions
and
20 deletions
+28
-20
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+8
-8
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+5
-0
src/include/migraphx/pass_manager.hpp
src/include/migraphx/pass_manager.hpp
+2
-0
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+11
-10
src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp
src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp
+2
-2
No files found.
src/include/migraphx/matcher.hpp
View file @
aeb60bce
...
@@ -326,8 +326,8 @@ match::matcher_result find_match(module& modl, M&& m)
...
@@ -326,8 +326,8 @@ match::matcher_result find_match(module& modl, M&& m)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_MATCHES
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_MATCHES
)
/// Find matches for an instruction in the module
/// Find matches for an instruction in the module
template
<
class
...
Ms
>
template
<
class
Mod
,
class
...
Ms
>
void
find_matches
(
m
od
ule
&
mod
,
instruction_ref
ins
,
Ms
&&
...
ms
)
void
find_matches
(
M
od
&
mod
,
instruction_ref
ins
,
Ms
&&
...
ms
)
{
{
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
const
...
@@ -340,13 +340,13 @@ void find_matches(module& mod, instruction_ref ins, Ms&&... ms)
...
@@ -340,13 +340,13 @@ void find_matches(module& mod, instruction_ref ins, Ms&&... ms)
return
;
return
;
if
(
trace
>
1
)
if
(
trace
>
1
)
std
::
cout
<<
"Match: "
<<
get_type_name
(
m
)
<<
std
::
endl
;
std
::
cout
<<
"Match: "
<<
get_type_name
(
m
)
<<
std
::
endl
;
auto
r
=
match_instruction
(
mod
,
ins
,
m
.
matcher
());
auto
r
=
match_instruction
(
get_module
(
mod
)
,
ins
,
m
.
matcher
());
if
(
r
.
result
==
mod
.
end
())
if
(
r
.
result
==
get_module
(
mod
)
.
end
())
return
;
return
;
if
(
trace
>
0
)
if
(
trace
>
0
)
{
{
std
::
cout
<<
"Matched by "
<<
get_type_name
(
m
)
<<
std
::
endl
;
std
::
cout
<<
"Matched by "
<<
get_type_name
(
m
)
<<
std
::
endl
;
mod
.
debug_print
(
ins
);
get_module
(
mod
)
.
debug_print
(
ins
);
}
}
m
.
apply
(
mod
,
r
);
m
.
apply
(
mod
,
r
);
match
=
true
;
match
=
true
;
...
@@ -355,10 +355,10 @@ void find_matches(module& mod, instruction_ref ins, Ms&&... ms)
...
@@ -355,10 +355,10 @@ void find_matches(module& mod, instruction_ref ins, Ms&&... ms)
}
}
/// Find matches in a module
/// Find matches in a module
template
<
class
...
Ms
>
template
<
class
Mod
,
class
...
Ms
>
void
find_matches
(
m
od
ule
&
mod
,
Ms
&&
...
ms
)
void
find_matches
(
M
od
&
mod
,
Ms
&&
...
ms
)
{
{
for
(
auto
ins
:
iterator_for
(
mod
))
for
(
auto
ins
:
iterator_for
(
get_module
(
mod
)
))
{
{
find_matches
(
mod
,
ins
,
ms
...);
find_matches
(
mod
,
ins
,
ms
...);
}
}
...
...
src/include/migraphx/module.hpp
View file @
aeb60bce
...
@@ -178,6 +178,11 @@ struct module
...
@@ -178,6 +178,11 @@ struct module
std
::
unique_ptr
<
module_impl
>
impl
;
std
::
unique_ptr
<
module_impl
>
impl
;
};
};
inline
module
&
get_module
(
module
&
m
)
{
return
m
;
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/pass_manager.hpp
View file @
aeb60bce
...
@@ -21,6 +21,8 @@ struct module_pass_manager
...
@@ -21,6 +21,8 @@ struct module_pass_manager
virtual
~
module_pass_manager
()
{}
virtual
~
module_pass_manager
()
{}
};
};
module
&
get_module
(
module_pass_manager
&
mpm
);
void
run_passes
(
module
&
mod
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
=
tracer
{});
void
run_passes
(
module
&
mod
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
=
tracer
{});
void
run_passes
(
program
&
prog
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
=
tracer
{});
void
run_passes
(
program
&
prog
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
=
tracer
{});
...
...
src/targets/gpu/fuse_mlir.cpp
View file @
aeb60bce
...
@@ -47,7 +47,7 @@ struct find_conv_pointwise
...
@@ -47,7 +47,7 @@ struct find_conv_pointwise
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
convolution
.
bind
(
"x"
)));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
convolution
.
bind
(
"x"
)));
}
}
void
apply
(
module
&
m
,
match
::
matcher_result
r
)
const
void
apply
(
module
_pass_manager
&
mp
m
,
match
::
matcher_result
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
conv_ins
=
r
.
instructions
[
"convolution"
];
auto
conv_ins
=
r
.
instructions
[
"convolution"
];
...
@@ -61,13 +61,14 @@ struct find_conv_pointwise
...
@@ -61,13 +61,14 @@ struct find_conv_pointwise
}))
}))
return
;
return
;
std
::
sort
(
names
.
begin
(),
names
.
end
());
std
::
sort
(
names
.
begin
(),
names
.
end
());
module
mm
{};
module_ref
mm
=
mpm
.
create_module
(
"mlir_"
+
pm
->
name
());
mm
->
set_bypass
();
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
param_map
;
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
param_map
;
auto
x
=
mm
.
add_parameter
(
"x"
+
std
::
to_string
(
names
.
size
()),
auto
x
=
mm
->
add_parameter
(
"x"
+
std
::
to_string
(
names
.
size
()),
conv_ins
->
inputs
().
at
(
0
)
->
get_shape
());
conv_ins
->
inputs
().
at
(
0
)
->
get_shape
());
auto
w
=
mm
.
add_parameter
(
"x"
+
std
::
to_string
(
names
.
size
()
+
1
),
auto
w
=
mm
->
add_parameter
(
"x"
+
std
::
to_string
(
names
.
size
()
+
1
),
conv_ins
->
inputs
().
at
(
1
)
->
get_shape
());
conv_ins
->
inputs
().
at
(
1
)
->
get_shape
());
auto
conv
=
mm
.
add_instruction
(
conv_ins
->
get_operator
(),
{
x
,
w
});
auto
conv
=
mm
->
add_instruction
(
conv_ins
->
get_operator
(),
{
x
,
w
});
std
::
transform
(
names
.
begin
(),
std
::
transform
(
names
.
begin
(),
names
.
end
(),
names
.
end
(),
ins
->
inputs
().
begin
(),
ins
->
inputs
().
begin
(),
...
@@ -76,9 +77,9 @@ struct find_conv_pointwise
...
@@ -76,9 +77,9 @@ struct find_conv_pointwise
if
(
input
==
x_ins
)
if
(
input
==
x_ins
)
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
conv
);
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
conv
);
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
mm
.
add_parameter
(
name
,
input
->
get_shape
()));
mm
->
add_parameter
(
name
,
input
->
get_shape
()));
});
});
mm
.
add_return
(
mm
.
insert_module_instructions
(
mm
.
end
(),
pm
,
param_map
));
mm
->
add_return
(
mm
->
insert_module_instructions
(
mm
->
end
(),
pm
,
param_map
));
std
::
vector
<
instruction_ref
>
inputs
;
std
::
vector
<
instruction_ref
>
inputs
;
std
::
copy_if
(
ins
->
inputs
().
begin
(),
std
::
copy_if
(
ins
->
inputs
().
begin
(),
...
@@ -86,13 +87,13 @@ struct find_conv_pointwise
...
@@ -86,13 +87,13 @@ struct find_conv_pointwise
std
::
back_inserter
(
inputs
),
std
::
back_inserter
(
inputs
),
[
&
](
auto
input
)
{
return
input
!=
conv_ins
;
});
[
&
](
auto
input
)
{
return
input
!=
conv_ins
;
});
inputs
.
insert
(
inputs
.
end
(),
conv_ins
->
inputs
().
begin
(),
conv_ins
->
inputs
().
end
());
inputs
.
insert
(
inputs
.
end
(),
conv_ins
->
inputs
().
begin
(),
conv_ins
->
inputs
().
end
());
m
.
replace_instruction
(
m
pm
.
get_module
()
.
replace_instruction
(
ins
,
mlir_conv
{
conv_ins
->
get_operator
()},
inputs
,
ins
->
module_inputs
()
);
ins
,
mlir_conv
{
conv_ins
->
get_operator
()},
inputs
,
{
mm
}
);
}
}
};
};
}
// namespace
}
// namespace
void
fuse_mlir
::
apply
(
module
&
m
)
const
{
match
::
find_matches
(
m
,
find_conv_pointwise
{});
}
void
fuse_mlir
::
apply
(
module
_pass_manager
&
mp
m
)
const
{
match
::
find_matches
(
m
pm
,
find_conv_pointwise
{});
}
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp
View file @
aeb60bce
...
@@ -7,7 +7,7 @@
...
@@ -7,7 +7,7 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module
;
struct
module
_pass_manager
;
namespace
gpu
{
namespace
gpu
{
...
@@ -15,7 +15,7 @@ struct fuse_mlir
...
@@ -15,7 +15,7 @@ struct fuse_mlir
{
{
context
*
ctx
=
nullptr
;
context
*
ctx
=
nullptr
;
std
::
string
name
()
const
{
return
"gpu::fuse_mlir"
;
}
std
::
string
name
()
const
{
return
"gpu::fuse_mlir"
;
}
void
apply
(
module
&
m
)
const
;
void
apply
(
module
_pass_manager
&
m
)
const
;
};
};
}
// namespace gpu
}
// namespace gpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment