Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a851f699
Commit
a851f699
authored
Apr 30, 2021
by
Paul
Browse files
Hash modules for quicker lookup of modules
parent
56584fa2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
72 additions
and
64 deletions
+72
-64
src/include/migraphx/algorithm.hpp
src/include/migraphx/algorithm.hpp
+14
-0
src/program.cpp
src/program.cpp
+54
-60
test/eval_test.cpp
test/eval_test.cpp
+4
-4
No files found.
src/include/migraphx/algorithm.hpp
View file @
a851f699
...
...
@@ -7,6 +7,20 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Iterator
,
class
Output
,
class
Predicate
,
class
F
>
void
transform_if
(
Iterator
start
,
Iterator
last
,
Output
out
,
Predicate
pred
,
F
f
)
{
while
(
start
!=
last
)
{
if
(
pred
(
*
start
))
{
*
out
=
f
(
*
start
);
++
out
;
}
++
start
;
}
}
template
<
class
Iterator
,
class
Output
,
class
Predicate
>
void
group_by
(
Iterator
start
,
Iterator
last
,
Output
out
,
Predicate
pred
)
{
...
...
src/program.cpp
View file @
a851f699
...
...
@@ -9,6 +9,7 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/make_op.hpp>
#include <iostream>
#include <sstream>
...
...
@@ -26,13 +27,12 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
program_impl
{
// A map is used to keep references to modules of the program
// all the modules are store in the depth-first order
std
::
list
<
module
>
modules
;
std
::
unordered_map
<
std
::
string
,
module
>
modules
;
context
ctx
;
std
::
string
target_name
;
};
program
::
program
()
:
impl
(
std
::
make_unique
<
program_impl
>
())
{
impl
->
modules
.
push_back
({
"main"
}
);
}
program
::
program
()
:
impl
(
std
::
make_unique
<
program_impl
>
())
{
this
->
create_module
(
"main"
);
}
program
::
program
(
program
&&
)
noexcept
=
default
;
program
::~
program
()
noexcept
=
default
;
...
...
@@ -67,9 +67,8 @@ void program::assign(const program& p)
std
::
unordered_map
<
module_ref
,
module_ref
>
mod_map
;
std
::
transform
(
impl
->
modules
.
begin
(),
impl
->
modules
.
end
(),
p
.
impl
->
modules
.
begin
(),
std
::
inserter
(
mod_map
,
mod_map
.
begin
()),
[](
auto
&&
x
,
auto
&&
y
)
{
return
std
::
make_pair
(
&
y
,
&
x
);
});
[
&
](
auto
&&
x
p
)
{
return
std
::
make_pair
(
&
p
.
impl
->
modules
.
at
(
xp
.
first
),
&
xp
.
second
);
});
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
ins_map
;
for
(
auto
&&
pp
:
mod_map
)
...
...
@@ -86,7 +85,7 @@ void program::assign(const program& p)
// Update all references from all modules
for
(
auto
&&
mp
:
impl
->
modules
)
{
for
(
auto
ins
:
iterator_for
(
mp
))
for
(
auto
ins
:
iterator_for
(
mp
.
second
))
instruction
::
replace_refs
(
ins
,
ins_map
,
mod_map
);
}
}
...
...
@@ -316,14 +315,14 @@ value program::to_value() const
if
(
not
this
->
impl
->
target_name
.
empty
())
result
[
"context"
]
=
this
->
impl
->
ctx
.
to_value
();
value
module_vals
=
value
::
array
{};
value
module_vals
=
value
::
object
{};
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
;
for
(
auto
&
mod
:
this
->
impl
->
modules
)
for
(
auto
&
mod
:
this
->
get_
modules
()
)
{
value
mod_val
;
value
nodes
;
mod_val
[
"name"
]
=
mod
.
name
();
names
=
mod
.
print
(
mod_val
[
"name"
]
=
mod
->
name
();
names
=
mod
->
print
(
[
&
](
auto
ins
,
auto
ins_names
)
{
value
node
;
node
[
"output"
]
=
ins_names
.
at
(
ins
);
...
...
@@ -358,7 +357,7 @@ value program::to_value() const
names
);
mod_val
[
"nodes"
]
=
nodes
;
module_vals
.
push_back
(
mod_val
)
;
module_vals
[
mod
->
name
()]
=
mod_val
;
}
result
[
"modules"
]
=
module_vals
;
...
...
@@ -371,12 +370,7 @@ static void mod_from_val(module_ref mod,
std
::
unordered_map
<
std
::
string
,
instruction_ref
>&
instructions
,
const
std
::
unordered_map
<
std
::
string
,
module_ref
>&
map_mods
)
{
const
auto
*
it
=
std
::
find_if
(
v
.
begin
(),
v
.
end
(),
[
&
](
auto
&
mv
)
{
return
mv
.
at
(
"name"
).
template
to
<
std
::
string
>()
==
mod
->
name
();
});
assert
(
it
!=
v
.
end
());
const
auto
&
module_val
=
*
it
;
const
auto
&
module_val
=
v
.
at
(
mod
->
name
());
for
(
const
value
&
node
:
module_val
.
at
(
"nodes"
))
{
instruction_ref
output
;
...
...
@@ -455,15 +449,17 @@ void program::from_value(const value& v)
}
auto
module_vals
=
v
.
at
(
"modules"
);
std
::
unordered_map
<
std
::
string
,
module_ref
>
map_mods
;
for
(
const
auto
&
vv
:
module_vals
)
{
const
auto
&
name
=
vv
.
at
(
"name"
).
to
<
std
::
string
>
();
const
auto
&
name
=
vv
.
get_key
();
if
(
name
==
"main"
)
continue
;
impl
->
modules
.
push_back
({
name
});
map_mods
[
name
]
=
&
impl
->
modules
.
back
();
impl
->
modules
.
emplace
(
name
,
name
);
}
std
::
unordered_map
<
std
::
string
,
module_ref
>
map_mods
;
std
::
transform
(
impl
->
modules
.
begin
(),
impl
->
modules
.
end
(),
std
::
inserter
(
map_mods
,
map_mods
.
end
()),
[
&
](
auto
&&
pp
)
{
return
std
::
make_pair
(
pp
.
first
,
&
pp
.
second
);
});
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
map_insts
;
auto
*
mm
=
get_main_module
();
...
...
@@ -585,8 +581,8 @@ void program::debug_print() const { std::cout << *this << std::endl; }
void
program
::
debug_print
(
instruction_ref
ins
)
const
{
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
;
if
(
std
::
any_of
(
this
->
impl
->
modules
.
begin
(),
this
->
impl
->
modules
.
end
(),
[
&
](
const
auto
&
it
)
{
return
(
it
.
end
()
==
ins
);
if
(
std
::
any_of
(
this
->
impl
->
modules
.
begin
(),
this
->
impl
->
modules
.
end
(),
[
&
](
const
auto
&
pp
)
{
return
(
pp
.
second
.
end
()
==
ins
);
}))
{
std
::
cout
<<
"End instruction"
<<
std
::
endl
;
...
...
@@ -594,7 +590,7 @@ void program::debug_print(instruction_ref ins) const
}
else
if
(
std
::
none_of
(
this
->
impl
->
modules
.
begin
(),
this
->
impl
->
modules
.
end
(),
[
&
](
const
auto
&
it
)
{
return
it
.
has_instruction
(
ins
);
}))
[
&
](
const
auto
&
pp
)
{
return
pp
.
second
.
has_instruction
(
ins
);
}))
{
std
::
cout
<<
"Instruction not part of program"
<<
std
::
endl
;
return
;
...
...
@@ -615,9 +611,9 @@ void program::print(
const
std
::
function
<
void
(
instruction_ref
,
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
)
>&
print_func
)
const
{
for
(
const
auto
&
mod
:
this
->
impl
->
modules
)
for
(
const
auto
&
pp
:
this
->
impl
->
modules
)
{
names
=
mo
d
.
print
(
print_func
,
names
);
names
=
pp
.
secon
d
.
print
(
print_func
,
names
);
}
}
...
...
@@ -647,74 +643,72 @@ void program::dry_run(std::unordered_map<std::string, argument> params) const
void
program
::
annotate
(
std
::
ostream
&
os
,
const
std
::
function
<
void
(
instruction_ref
)
>&
a
)
const
{
for
(
auto
&
mod
:
this
->
impl
->
modules
)
for
(
auto
&
pp
:
this
->
impl
->
modules
)
{
std
::
cout
<<
mod
.
name
()
<<
":"
<<
std
::
endl
;
mo
d
.
annotate
(
os
,
a
);
std
::
cout
<<
pp
.
first
<<
":"
<<
std
::
endl
;
pp
.
secon
d
.
annotate
(
os
,
a
);
}
}
const
module
*
program
::
get_module
(
const
std
::
string
&
name
)
const
{
auto
it
=
std
::
find_if
(
impl
->
modules
.
begin
(),
impl
->
modules
.
end
(),
[
&
](
auto
&
m
)
{
return
(
m
.
name
()
==
name
);
});
if
(
it
==
impl
->
modules
.
end
())
{
return
nullptr
;
}
return
&
(
*
it
);
return
&
impl
->
modules
.
at
(
name
);
}
module
*
program
::
create_module
(
const
std
::
string
&
name
)
{
auto
it
=
impl
->
modules
.
insert
(
impl
->
modules
.
end
()
,
{
name
}
);
return
&
(
*
it
);
auto
r
=
impl
->
modules
.
emplace
(
name
,
name
);
return
&
(
r
.
first
->
second
);
}
module
*
program
::
get_module
(
const
std
::
string
&
name
)
{
auto
it
=
std
::
find_if
(
impl
->
modules
.
begin
(),
impl
->
modules
.
end
(),
[
&
](
auto
&
m
)
{
return
(
m
.
name
()
==
name
);
});
if
(
it
==
impl
->
modules
.
end
())
{
return
nullptr
;
}
return
&
(
*
it
);
return
&
impl
->
modules
.
at
(
name
);
}
module
*
program
::
get_main_module
()
{
return
get_module
(
"main"
);
}
const
module
*
program
::
get_main_module
()
const
{
return
get_module
(
"main"
);
}
std
::
vector
<
const
module
*>
program
::
get_modules
()
const
template
<
class
T
>
std
::
vector
<
T
*>
generic_get_modules
(
T
*
mm
)
{
const
module
*
mm
=
this
->
get_main_module
();
std
::
vector
<
const
module
*>
vec_modules
;
std
::
vector
<
T
*>
vec_modules
;
vec_modules
.
push_back
(
mm
);
auto
sub_modules
=
mm
->
get_sub_modules
();
vec_modules
.
insert
(
vec_modules
.
end
(),
sub_modules
.
begin
(),
sub_modules
.
end
());
return
vec_modules
;
}
std
::
vector
<
module
*>
program
::
get_modules
()
template
<
class
Map
,
class
T
,
class
OutputIterator
>
void
generic_get_unused_modules
(
Map
&
m
,
const
std
::
vector
<
T
*>&
mods
,
OutputIterator
out
)
{
module
*
mm
=
this
->
get_main_module
();
std
::
vector
<
module
*>
vec_modules
;
vec_modules
.
push_back
(
mm
);
auto
sub_modules
=
mm
->
get_sub_modules
();
vec_modules
.
insert
(
vec_modules
.
end
(),
sub_modules
.
begin
(),
sub_modules
.
end
());
std
::
unordered_set
<
std
::
string
>
used
;
std
::
transform
(
mods
.
begin
(),
mods
.
end
(),
std
::
inserter
(
used
,
used
.
end
()),
[](
auto
&&
mod
)
{
return
mod
->
name
();
});
transform_if
(
m
.
begin
(),
m
.
end
(),
out
,
[
&
](
auto
&&
pp
){
return
not
contains
(
used
,
pp
.
first
);
},
[](
auto
&&
pp
)
{
return
&
pp
.
second
;
});
}
return
vec_modules
;
std
::
vector
<
const
module
*>
program
::
get_modules
()
const
{
auto
result
=
generic_get_modules
(
this
->
get_main_module
());
generic_get_unused_modules
(
impl
->
modules
,
result
,
std
::
back_inserter
(
result
));
return
result
;
}
std
::
vector
<
module
*>
program
::
get_modules
()
{
auto
result
=
generic_get_modules
(
this
->
get_main_module
());
generic_get_unused_modules
(
impl
->
modules
,
result
,
std
::
back_inserter
(
result
));
return
result
;
}
program
&
program
::
sort
()
{
for
(
auto
&
mod
:
this
->
impl
->
modules
)
for
(
auto
&
pp
:
this
->
impl
->
modules
)
{
mo
d
.
sort
();
pp
.
secon
d
.
sort
();
}
return
*
this
;
...
...
test/eval_test.cpp
View file @
a851f699
...
...
@@ -89,17 +89,17 @@ struct invert_pass
{
std
::
string
name
()
const
{
return
"invert_pass"
;
}
void
apply
(
migraphx
::
module
&
p
)
const
void
apply
(
migraphx
::
module
&
m
)
const
{
for
(
auto
ins
:
migraphx
::
iterator_for
(
p
))
for
(
auto
ins
:
migraphx
::
iterator_for
(
m
))
{
if
(
ins
->
name
()
==
"sum"
)
{
p
.
replace_instruction
(
ins
,
minus_op
{},
ins
->
inputs
());
m
.
replace_instruction
(
ins
,
minus_op
{},
ins
->
inputs
());
}
else
if
(
ins
->
name
()
==
"minus"
)
{
p
.
replace_instruction
(
ins
,
sum_op
{},
ins
->
inputs
());
m
.
replace_instruction
(
ins
,
sum_op
{},
ins
->
inputs
());
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment