Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a851f699
Commit
a851f699
authored
Apr 30, 2021
by
Paul
Browse files
Hash modules for quicker lookup of modules
parent
56584fa2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
72 additions
and
64 deletions
+72
-64
src/include/migraphx/algorithm.hpp
src/include/migraphx/algorithm.hpp
+14
-0
src/program.cpp
src/program.cpp
+54
-60
test/eval_test.cpp
test/eval_test.cpp
+4
-4
No files found.
src/include/migraphx/algorithm.hpp
View file @
a851f699
...
...
@@ -7,6 +7,20 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Iterator
,
class
Output
,
class
Predicate
,
class
F
>
void
transform_if
(
Iterator
start
,
Iterator
last
,
Output
out
,
Predicate
pred
,
F
f
)
{
while
(
start
!=
last
)
{
if
(
pred
(
*
start
))
{
*
out
=
f
(
*
start
);
++
out
;
}
++
start
;
}
}
template
<
class
Iterator
,
class
Output
,
class
Predicate
>
void
group_by
(
Iterator
start
,
Iterator
last
,
Output
out
,
Predicate
pred
)
{
...
...
src/program.cpp
View file @
a851f699
...
...
@@ -9,6 +9,7 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/make_op.hpp>
#include <iostream>
#include <sstream>
...
...
@@ -26,13 +27,12 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
program_impl
{
// A map is used to keep references to modules of the program
// all the modules are store in the depth-first order
std
::
list
<
module
>
modules
;
std
::
unordered_map
<
std
::
string
,
module
>
modules
;
context
ctx
;
std
::
string
target_name
;
};
program
::
program
()
:
impl
(
std
::
make_unique
<
program_impl
>
())
{
impl
->
modules
.
push_back
({
"main"
}
);
}
program
::
program
()
:
impl
(
std
::
make_unique
<
program_impl
>
())
{
this
->
create_module
(
"main"
);
}
program
::
program
(
program
&&
)
noexcept
=
default
;
program
::~
program
()
noexcept
=
default
;
...
...
@@ -67,9 +67,8 @@ void program::assign(const program& p)
std
::
unordered_map
<
module_ref
,
module_ref
>
mod_map
;
std
::
transform
(
impl
->
modules
.
begin
(),
impl
->
modules
.
end
(),
p
.
impl
->
modules
.
begin
(),
std
::
inserter
(
mod_map
,
mod_map
.
begin
()),
[](
auto
&&
x
,
auto
&&
y
)
{
return
std
::
make_pair
(
&
y
,
&
x
);
});
[
&
](
auto
&&
x
p
)
{
return
std
::
make_pair
(
&
p
.
impl
->
modules
.
at
(
xp
.
first
),
&
xp
.
second
);
});
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
ins_map
;
for
(
auto
&&
pp
:
mod_map
)
...
...
@@ -86,7 +85,7 @@ void program::assign(const program& p)
// Update all references from all modules
for
(
auto
&&
mp
:
impl
->
modules
)
{
for
(
auto
ins
:
iterator_for
(
mp
))
for
(
auto
ins
:
iterator_for
(
mp
.
second
))
instruction
::
replace_refs
(
ins
,
ins_map
,
mod_map
);
}
}
...
...
@@ -316,14 +315,14 @@ value program::to_value() const
if
(
not
this
->
impl
->
target_name
.
empty
())
result
[
"context"
]
=
this
->
impl
->
ctx
.
to_value
();
value
module_vals
=
value
::
array
{};
value
module_vals
=
value
::
object
{};
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
;
for
(
auto
&
mod
:
this
->
impl
->
modules
)
for
(
auto
&
mod
:
this
->
get_
modules
()
)
{
value
mod_val
;
value
nodes
;
mod_val
[
"name"
]
=
mod
.
name
();
names
=
mod
.
print
(
mod_val
[
"name"
]
=
mod
->
name
();
names
=
mod
->
print
(
[
&
](
auto
ins
,
auto
ins_names
)
{
value
node
;
node
[
"output"
]
=
ins_names
.
at
(
ins
);
...
...
@@ -358,7 +357,7 @@ value program::to_value() const
names
);
mod_val
[
"nodes"
]
=
nodes
;
module_vals
.
push_back
(
mod_val
)
;
module_vals
[
mod
->
name
()]
=
mod_val
;
}
result
[
"modules"
]
=
module_vals
;
...
...
@@ -371,12 +370,7 @@ static void mod_from_val(module_ref mod,
std
::
unordered_map
<
std
::
string
,
instruction_ref
>&
instructions
,
const
std
::
unordered_map
<
std
::
string
,
module_ref
>&
map_mods
)
{
const
auto
*
it
=
std
::
find_if
(
v
.
begin
(),
v
.
end
(),
[
&
](
auto
&
mv
)
{
return
mv
.
at
(
"name"
).
template
to
<
std
::
string
>()
==
mod
->
name
();
});
assert
(
it
!=
v
.
end
());
const
auto
&
module_val
=
*
it
;
const
auto
&
module_val
=
v
.
at
(
mod
->
name
());
for
(
const
value
&
node
:
module_val
.
at
(
"nodes"
))
{
instruction_ref
output
;
...
...
@@ -455,15 +449,17 @@ void program::from_value(const value& v)
}
auto
module_vals
=
v
.
at
(
"modules"
);
std
::
unordered_map
<
std
::
string
,
module_ref
>
map_mods
;
for
(
const
auto
&
vv
:
module_vals
)
{
const
auto
&
name
=
vv
.
at
(
"name"
).
to
<
std
::
string
>
();
const
auto
&
name
=
vv
.
get_key
();
if
(
name
==
"main"
)
continue
;
impl
->
modules
.
push_back
({
name
});
map_mods
[
name
]
=
&
impl
->
modules
.
back
();
impl
->
modules
.
emplace
(
name
,
name
);
}
std
::
unordered_map
<
std
::
string
,
module_ref
>
map_mods
;
std
::
transform
(
impl
->
modules
.
begin
(),
impl
->
modules
.
end
(),
std
::
inserter
(
map_mods
,
map_mods
.
end
()),
[
&
](
auto
&&
pp
)
{
return
std
::
make_pair
(
pp
.
first
,
&
pp
.
second
);
});
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
map_insts
;
auto
*
mm
=
get_main_module
();
...
...
@@ -585,8 +581,8 @@ void program::debug_print() const { std::cout << *this << std::endl; }
void
program
::
debug_print
(
instruction_ref
ins
)
const
{
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
;
if
(
std
::
any_of
(
this
->
impl
->
modules
.
begin
(),
this
->
impl
->
modules
.
end
(),
[
&
](
const
auto
&
it
)
{
return
(
it
.
end
()
==
ins
);
if
(
std
::
any_of
(
this
->
impl
->
modules
.
begin
(),
this
->
impl
->
modules
.
end
(),
[
&
](
const
auto
&
pp
)
{
return
(
pp
.
second
.
end
()
==
ins
);
}))
{
std
::
cout
<<
"End instruction"
<<
std
::
endl
;
...
...
@@ -594,7 +590,7 @@ void program::debug_print(instruction_ref ins) const
}
else
if
(
std
::
none_of
(
this
->
impl
->
modules
.
begin
(),
this
->
impl
->
modules
.
end
(),
[
&
](
const
auto
&
it
)
{
return
it
.
has_instruction
(
ins
);
}))
[
&
](
const
auto
&
pp
)
{
return
pp
.
second
.
has_instruction
(
ins
);
}))
{
std
::
cout
<<
"Instruction not part of program"
<<
std
::
endl
;
return
;
...
...
@@ -615,9 +611,9 @@ void program::print(
const
std
::
function
<
void
(
instruction_ref
,
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
)
>&
print_func
)
const
{
for
(
const
auto
&
mod
:
this
->
impl
->
modules
)
for
(
const
auto
&
pp
:
this
->
impl
->
modules
)
{
names
=
mo
d
.
print
(
print_func
,
names
);
names
=
pp
.
secon
d
.
print
(
print_func
,
names
);
}
}
...
...
@@ -647,74 +643,72 @@ void program::dry_run(std::unordered_map<std::string, argument> params) const
void
program
::
annotate
(
std
::
ostream
&
os
,
const
std
::
function
<
void
(
instruction_ref
)
>&
a
)
const
{
for
(
auto
&
mod
:
this
->
impl
->
modules
)
for
(
auto
&
pp
:
this
->
impl
->
modules
)
{
std
::
cout
<<
mod
.
name
()
<<
":"
<<
std
::
endl
;
mo
d
.
annotate
(
os
,
a
);
std
::
cout
<<
pp
.
first
<<
":"
<<
std
::
endl
;
pp
.
secon
d
.
annotate
(
os
,
a
);
}
}
const
module
*
program
::
get_module
(
const
std
::
string
&
name
)
const
{
auto
it
=
std
::
find_if
(
impl
->
modules
.
begin
(),
impl
->
modules
.
end
(),
[
&
](
auto
&
m
)
{
return
(
m
.
name
()
==
name
);
});
if
(
it
==
impl
->
modules
.
end
())
{
return
nullptr
;
}
return
&
(
*
it
);
return
&
impl
->
modules
.
at
(
name
);
}
module
*
program
::
create_module
(
const
std
::
string
&
name
)
{
auto
it
=
impl
->
modules
.
insert
(
impl
->
modules
.
end
()
,
{
name
}
);
return
&
(
*
it
);
auto
r
=
impl
->
modules
.
emplace
(
name
,
name
);
return
&
(
r
.
first
->
second
);
}
module
*
program
::
get_module
(
const
std
::
string
&
name
)
{
auto
it
=
std
::
find_if
(
impl
->
modules
.
begin
(),
impl
->
modules
.
end
(),
[
&
](
auto
&
m
)
{
return
(
m
.
name
()
==
name
);
});
if
(
it
==
impl
->
modules
.
end
())
{
return
nullptr
;
}
return
&
(
*
it
);
return
&
impl
->
modules
.
at
(
name
);
}
module
*
program
::
get_main_module
()
{
return
get_module
(
"main"
);
}
const
module
*
program
::
get_main_module
()
const
{
return
get_module
(
"main"
);
}
std
::
vector
<
const
module
*>
program
::
get_modules
()
const
template
<
class
T
>
std
::
vector
<
T
*>
generic_get_modules
(
T
*
mm
)
{
const
module
*
mm
=
this
->
get_main_module
();
std
::
vector
<
const
module
*>
vec_modules
;
std
::
vector
<
T
*>
vec_modules
;
vec_modules
.
push_back
(
mm
);
auto
sub_modules
=
mm
->
get_sub_modules
();
vec_modules
.
insert
(
vec_modules
.
end
(),
sub_modules
.
begin
(),
sub_modules
.
end
());
return
vec_modules
;
}
std
::
vector
<
module
*>
program
::
get_modules
()
template
<
class
Map
,
class
T
,
class
OutputIterator
>
void
generic_get_unused_modules
(
Map
&
m
,
const
std
::
vector
<
T
*>&
mods
,
OutputIterator
out
)
{
module
*
mm
=
this
->
get_main_module
();
std
::
vector
<
module
*>
vec_modules
;
vec_modules
.
push_back
(
mm
);
auto
sub_modules
=
mm
->
get_sub_modules
();
vec_modules
.
insert
(
vec_modules
.
end
(),
sub_modules
.
begin
(),
sub_modules
.
end
());
std
::
unordered_set
<
std
::
string
>
used
;
std
::
transform
(
mods
.
begin
(),
mods
.
end
(),
std
::
inserter
(
used
,
used
.
end
()),
[](
auto
&&
mod
)
{
return
mod
->
name
();
});
transform_if
(
m
.
begin
(),
m
.
end
(),
out
,
[
&
](
auto
&&
pp
){
return
not
contains
(
used
,
pp
.
first
);
},
[](
auto
&&
pp
)
{
return
&
pp
.
second
;
});
}
return
vec_modules
;
std
::
vector
<
const
module
*>
program
::
get_modules
()
const
{
auto
result
=
generic_get_modules
(
this
->
get_main_module
());
generic_get_unused_modules
(
impl
->
modules
,
result
,
std
::
back_inserter
(
result
));
return
result
;
}
std
::
vector
<
module
*>
program
::
get_modules
()
{
auto
result
=
generic_get_modules
(
this
->
get_main_module
());
generic_get_unused_modules
(
impl
->
modules
,
result
,
std
::
back_inserter
(
result
));
return
result
;
}
program
&
program
::
sort
()
{
for
(
auto
&
mod
:
this
->
impl
->
modules
)
for
(
auto
&
pp
:
this
->
impl
->
modules
)
{
mo
d
.
sort
();
pp
.
secon
d
.
sort
();
}
return
*
this
;
...
...
test/eval_test.cpp
View file @
a851f699
...
...
@@ -89,17 +89,17 @@ struct invert_pass
{
std
::
string
name
()
const
{
return
"invert_pass"
;
}
void
apply
(
migraphx
::
module
&
p
)
const
void
apply
(
migraphx
::
module
&
m
)
const
{
for
(
auto
ins
:
migraphx
::
iterator_for
(
p
))
for
(
auto
ins
:
migraphx
::
iterator_for
(
m
))
{
if
(
ins
->
name
()
==
"sum"
)
{
p
.
replace_instruction
(
ins
,
minus_op
{},
ins
->
inputs
());
m
.
replace_instruction
(
ins
,
minus_op
{},
ins
->
inputs
());
}
else
if
(
ins
->
name
()
==
"minus"
)
{
p
.
replace_instruction
(
ins
,
sum_op
{},
ins
->
inputs
());
m
.
replace_instruction
(
ins
,
sum_op
{},
ins
->
inputs
());
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment