Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5613e3a7
Commit
5613e3a7
authored
Oct 24, 2018
by
Paul
Browse files
Fix test in cse
parent
dd9ff577
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
75 additions
and
31 deletions
+75
-31
src/common_subexpression_elimination.cpp
src/common_subexpression_elimination.cpp
+2
-0
src/include/migraph/instruction.hpp
src/include/migraph/instruction.hpp
+5
-1
src/include/migraph/program.hpp
src/include/migraph/program.hpp
+4
-0
src/instruction.cpp
src/instruction.cpp
+1
-6
src/program.cpp
src/program.cpp
+63
-24
No files found.
src/common_subexpression_elimination.cpp
View file @
5613e3a7
...
@@ -5,6 +5,8 @@
...
@@ -5,6 +5,8 @@
#include <migraph/ranges.hpp>
#include <migraph/ranges.hpp>
#include <migraph/functional.hpp>
#include <migraph/functional.hpp>
#include <unordered_set>
namespace
migraph
{
namespace
migraph
{
template
<
class
Range
>
template
<
class
Range
>
...
...
src/include/migraph/instruction.hpp
View file @
5613e3a7
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include <migraph/shape.hpp>
#include <migraph/shape.hpp>
#include <migraph/instruction_ref.hpp>
#include <migraph/instruction_ref.hpp>
#include <migraph/operation.hpp>
#include <migraph/operation.hpp>
#include <migraph/erase.hpp>
#include <string>
#include <string>
#include <utility>
#include <utility>
...
@@ -56,7 +57,10 @@ struct instruction
...
@@ -56,7 +57,10 @@ struct instruction
void
add_output
(
instruction_ref
ins
);
void
add_output
(
instruction_ref
ins
);
template
<
class
T
>
template
<
class
T
>
void
remove_output
(
const
T
&
ins
);
void
remove_output
(
const
T
&
ins
)
{
migraph
::
erase
(
output
,
ins
);
}
static
void
backreference
(
instruction_ref
ref
);
static
void
backreference
(
instruction_ref
ref
);
...
...
src/include/migraph/program.hpp
View file @
5613e3a7
...
@@ -95,6 +95,10 @@ struct program
...
@@ -95,6 +95,10 @@ struct program
void
perf_report
(
std
::
ostream
&
os
,
std
::
size_t
n
,
parameter_map
params
)
const
;
void
perf_report
(
std
::
ostream
&
os
,
std
::
size_t
n
,
parameter_map
params
)
const
;
void
debug_print
();
void
debug_print
(
instruction_ref
ins
);
void
debug_print
(
const
std
::
vector
<
instruction_ref
>&
inss
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
friend
bool
operator
==
(
const
program
&
x
,
const
program
&
y
);
friend
bool
operator
==
(
const
program
&
x
,
const
program
&
y
);
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
!
(
x
==
y
);
}
...
...
src/instruction.cpp
View file @
5613e3a7
...
@@ -117,12 +117,6 @@ void instruction::add_output(instruction_ref ins)
...
@@ -117,12 +117,6 @@ void instruction::add_output(instruction_ref ins)
output
.
push_back
(
ins
);
output
.
push_back
(
ins
);
}
}
template
<
class
T
>
void
instruction
::
remove_output
(
const
T
&
ins
)
{
migraph
::
erase
(
output
,
ins
);
}
void
instruction
::
backreference
(
instruction_ref
ref
)
void
instruction
::
backreference
(
instruction_ref
ref
)
{
{
for
(
auto
&&
arg
:
ref
->
inputs
())
for
(
auto
&&
arg
:
ref
->
inputs
())
...
@@ -162,6 +156,7 @@ void instruction::replace(std::vector<instruction_ref> args)
...
@@ -162,6 +156,7 @@ void instruction::replace(std::vector<instruction_ref> args)
void
instruction
::
replace_argument
(
instruction_ref
old
,
instruction_ref
new_ins
)
void
instruction
::
replace_argument
(
instruction_ref
old
,
instruction_ref
new_ins
)
{
{
assert
(
std
::
any_of
(
arguments
.
begin
(),
arguments
.
end
(),
[
&
](
auto
i
)
{
return
i
==
old
;
}));
std
::
replace
(
arguments
.
begin
(),
arguments
.
end
(),
old
,
new_ins
);
std
::
replace
(
arguments
.
begin
(),
arguments
.
end
(),
old
,
new_ins
);
old
->
remove_output
(
*
this
);
old
->
remove_output
(
*
this
);
}
}
...
...
src/program.cpp
View file @
5613e3a7
...
@@ -23,21 +23,9 @@ struct program_impl
...
@@ -23,21 +23,9 @@ struct program_impl
const
operation
&
get_operation
(
instruction_ref
ins
)
{
return
ins
->
get_operator
();
}
const
operation
&
get_operation
(
instruction_ref
ins
)
{
return
ins
->
get_operator
();
}
template
<
class
F
>
static
void
print_instruction
(
std
::
ostream
&
os
,
instruction_ref
ins
,
const
std
::
unordered_map
<
instruction_ref
,
std
::
string
>&
names
)
static
void
print_program
(
std
::
ostream
&
os
,
const
program
&
p
,
F
annonate
)
{
{
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
;
os
<<
names
.
at
(
ins
)
<<
" = "
;
int
count
=
0
;
for
(
auto
ins
:
iterator_for
(
p
))
{
std
::
string
var_name
=
"@"
+
std
::
to_string
(
count
);
if
(
ins
->
name
()
==
"@param"
)
{
var_name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
}
os
<<
var_name
<<
" = "
;
os
<<
ins
->
get_operator
();
os
<<
ins
->
get_operator
();
...
@@ -54,7 +42,6 @@ static void print_program(std::ostream& os, const program& p, F annonate)
...
@@ -54,7 +42,6 @@ static void print_program(std::ostream& os, const program& p, F annonate)
char
delim
=
'('
;
char
delim
=
'('
;
for
(
auto
&&
arg
:
ins
->
inputs
())
for
(
auto
&&
arg
:
ins
->
inputs
())
{
{
assert
(
p
.
has_instruction
(
arg
)
&&
"Instruction not found"
);
os
<<
delim
<<
names
.
at
(
arg
);
os
<<
delim
<<
names
.
at
(
arg
);
delim
=
','
;
delim
=
','
;
}
}
...
@@ -62,12 +49,36 @@ static void print_program(std::ostream& os, const program& p, F annonate)
...
@@ -62,12 +49,36 @@ static void print_program(std::ostream& os, const program& p, F annonate)
}
}
os
<<
" -> "
<<
ins
->
get_shape
();
os
<<
" -> "
<<
ins
->
get_shape
();
}
template
<
class
F
>
static
void
print_program
(
std
::
ostream
&
os
,
const
program
&
p
,
F
annonate
)
{
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
;
int
count
=
0
;
for
(
auto
ins
:
iterator_for
(
p
))
{
std
::
string
var_name
=
"@"
+
std
::
to_string
(
count
);
if
(
ins
->
name
()
==
"@param"
)
{
var_name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
}
names
.
emplace
(
ins
,
var_name
);
// TODO: Use all_of
for
(
auto
&&
arg
:
ins
->
inputs
())
{
assert
(
p
.
has_instruction
(
arg
)
&&
"Instruction not found"
);
(
void
)
arg
;
}
print_instruction
(
os
,
ins
,
names
);
annonate
(
ins
,
names
);
annonate
(
ins
,
names
);
os
<<
std
::
endl
;
os
<<
std
::
endl
;
names
.
emplace
(
ins
,
var_name
);
count
++
;
count
++
;
}
}
}
}
...
@@ -124,7 +135,9 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
...
@@ -124,7 +135,9 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
{
{
return
rep
;
return
rep
;
}
}
for
(
auto
&&
out
:
ins
->
outputs
())
// Make a copy of outputs which can be changed when calling replace_argument
auto
outputs
=
ins
->
outputs
();
for
(
auto
out
:
outputs
)
{
{
// TODO: Check for possible cycles
// TODO: Check for possible cycles
if
(
out
!=
rep
)
if
(
out
!=
rep
)
...
@@ -135,6 +148,10 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
...
@@ -135,6 +148,10 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
}
}
// Replacement should not be dead code unless its the last instruction
// Replacement should not be dead code unless its the last instruction
assert
(
!
rep
->
outputs
().
empty
()
or
rep
==
std
::
prev
(
end
()));
assert
(
!
rep
->
outputs
().
empty
()
or
rep
==
std
::
prev
(
end
()));
// Output of the original instruction should only be the replacement or empty
assert
(
ins
->
outputs
().
empty
()
or
std
::
all_of
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
[
&
](
auto
i
)
{
return
i
==
rep
;
}));
assert
(
ins
->
valid
(
begin
()));
assert
(
ins
->
valid
(
begin
()));
assert
(
rep
->
valid
(
begin
()));
assert
(
rep
->
valid
(
begin
()));
return
rep
;
return
rep
;
...
@@ -449,6 +466,28 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
...
@@ -449,6 +466,28 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
<<
", "
<<
std
::
round
(
calculate_overhead_percent
)
<<
"%"
<<
std
::
endl
;
<<
", "
<<
std
::
round
(
calculate_overhead_percent
)
<<
"%"
<<
std
::
endl
;
}
}
void
program
::
debug_print
()
{
std
::
cout
<<
*
this
<<
std
::
endl
;
}
void
program
::
debug_print
(
instruction_ref
ins
)
{
std
::
stringstream
ss
;
print_program
(
ss
,
*
this
,
[
&
](
auto
x
,
auto
&&
names
)
{
if
(
x
==
ins
)
{
print_instruction
(
std
::
cout
,
x
,
names
);
std
::
cout
<<
std
::
endl
;
}
});
}
void
program
::
debug_print
(
const
std
::
vector
<
instruction_ref
>&
inss
)
{
for
(
auto
ins
:
inss
)
debug_print
(
ins
);
std
::
cout
<<
std
::
endl
;
}
bool
operator
==
(
const
program
&
x
,
const
program
&
y
)
{
return
to_string
(
x
)
==
to_string
(
y
);
}
bool
operator
==
(
const
program
&
x
,
const
program
&
y
)
{
return
to_string
(
x
)
==
to_string
(
y
);
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
)
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment