Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d2a38cd4
Commit
d2a38cd4
authored
Aug 12, 2018
by
Paul
Browse files
Add simplify reshapes pass
parent
fc8ff61f
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
227 additions
and
3 deletions
+227
-3
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/include/migraph/simplify_reshapes.hpp
src/include/migraph/simplify_reshapes.hpp
+19
-0
src/program.cpp
src/program.cpp
+0
-3
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+58
-0
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+4
-0
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+145
-0
No files found.
src/CMakeLists.txt
View file @
d2a38cd4
...
@@ -5,6 +5,7 @@ add_library(migraph
...
@@ -5,6 +5,7 @@ add_library(migraph
generate.cpp
generate.cpp
program.cpp
program.cpp
shape.cpp
shape.cpp
simplify_reshapes.cpp
)
)
rocm_clang_tidy_check
(
migraph
)
rocm_clang_tidy_check
(
migraph
)
target_include_directories
(
migraph PUBLIC $<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/include>
)
target_include_directories
(
migraph PUBLIC $<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/include>
)
...
...
src/include/migraph/simplify_reshapes.hpp
0 → 100644
View file @
d2a38cd4
#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace
migraph
{
struct
program
;
struct
simplify_reshapes
{
std
::
string
name
()
const
{
return
"simplify_reshapes"
;
}
void
apply
(
program
&
p
)
const
;
};
}
// namespace migraph
#endif
src/program.cpp
View file @
d2a38cd4
...
@@ -65,7 +65,6 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
...
@@ -65,7 +65,6 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
// TODO: Should it be an error if the output is empty?
// TODO: Should it be an error if the output is empty?
if
(
ins
->
output
.
empty
())
if
(
ins
->
output
.
empty
())
{
{
remove_instruction
(
ins
);
return
rep
;
return
rep
;
}
}
for
(
auto
&&
out
:
ins
->
output
)
for
(
auto
&&
out
:
ins
->
output
)
...
@@ -80,8 +79,6 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
...
@@ -80,8 +79,6 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
// Replacement should not be dead code unless its the last instruction
// Replacement should not be dead code unless its the last instruction
assert
(
!
rep
->
output
.
empty
()
or
rep
==
std
::
prev
(
end
()));
assert
(
!
rep
->
output
.
empty
()
or
rep
==
std
::
prev
(
end
()));
assert
(
ins
->
valid
(
begin
()));
assert
(
ins
->
valid
(
begin
()));
if
(
ins
->
output
.
empty
())
remove_instruction
(
ins
);
assert
(
rep
->
valid
(
begin
()));
assert
(
rep
->
valid
(
begin
()));
return
rep
;
return
rep
;
}
}
...
...
src/simplify_reshapes.cpp
0 → 100644
View file @
d2a38cd4
#include <migraph/simplify_reshapes.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/ranges.hpp>
#include <unordered_set>
namespace
migraph
{
bool
is_reshaper
(
const
std
::
string
&
name
)
{
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
"reshape"
,
"transpose"
,
// "broadcast",
"contiguous"
};
return
contains
(
names
,
name
);
}
void
simplify_reshapes
::
apply
(
program
&
p
)
const
{
for
(
auto
ins
:
iterator_for
(
p
))
{
if
(
not
is_reshaper
(
ins
->
op
.
name
()))
continue
;
if
(
ins
->
output
.
size
()
!=
1
)
continue
;
if
(
is_reshaper
(
ins
->
output
.
front
()
->
op
.
name
()))
continue
;
// Gather reshapes
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
while
(
is_reshaper
(
reshapes
.
back
()
->
op
.
name
()))
{
assert
(
!
reshapes
.
back
()
->
arguments
.
empty
());
assert
(
p
.
has_instruction
(
reshapes
.
back
()
->
arguments
.
front
()));
reshapes
.
push_back
(
reshapes
.
back
()
->
arguments
.
front
());
}
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
p
.
end
(),
p
.
end
()};
for
(
auto
start
:
iterator_for
(
reshapes
))
{
auto
last
=
std
::
find_if
(
reshapes
.
rbegin
(),
reshapes
.
rend
(),
[
&
](
auto
&&
i
)
{
return
i
->
result
==
(
*
start
)
->
result
and
i
!=
(
*
start
);
});
if
(
last
!=
reshapes
.
rend
())
{
r
=
std
::
make_pair
(
*
start
,
*
last
);
break
;
}
}
if
(
r
.
first
!=
r
.
second
)
{
p
.
replace_instruction
(
r
.
first
,
r
.
second
);
}
}
}
}
// namespace migraph
src/targets/gpu/target.cpp
View file @
d2a38cd4
...
@@ -4,6 +4,8 @@
...
@@ -4,6 +4,8 @@
#include <migraph/gpu/context.hpp>
#include <migraph/gpu/context.hpp>
#include <migraph/check_context.hpp>
#include <migraph/check_context.hpp>
#include <migraph/auto_contiguous.hpp>
#include <migraph/auto_contiguous.hpp>
#include <migraph/dead_code_elimination.hpp>
#include <migraph/simplify_reshapes.hpp>
namespace
migraph
{
namespace
migraph
{
namespace
gpu
{
namespace
gpu
{
...
@@ -14,8 +16,10 @@ std::vector<pass> target::get_passes(migraph::context&) const
...
@@ -14,8 +16,10 @@ std::vector<pass> target::get_passes(migraph::context&) const
return
return
{
{
auto_contiguous
{},
auto_contiguous
{},
simplify_reshapes
{},
lowering
{},
lowering
{},
write_literals
{},
write_literals
{},
dead_code_elimination
{},
check_context
<
context
>
{}
check_context
<
context
>
{}
};
};
// clang-format on
// clang-format on
...
...
test/simplify_reshapes_test.cpp
0 → 100644
View file @
d2a38cd4
#include <migraph/simplify_reshapes.hpp>
#include <migraph/dead_code_elimination.hpp>
#include <migraph/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct
simplify_reshapes_target
{
std
::
string
name
()
const
{
return
"simplify_reshapes"
;
}
std
::
vector
<
migraph
::
pass
>
get_passes
(
migraph
::
context
&
)
const
{
return
{
migraph
::
simplify_reshapes
{},
migraph
::
dead_code_elimination
{}};
}
migraph
::
context
get_context
()
const
{
return
{};
}
};
migraph
::
literal
get_2x2
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
,
2
}},
{
1
,
2
,
3
,
4
}};
}
migraph
::
literal
get_2x2_transposed
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
2
}},
{
1
,
2
,
3
,
4
}};
}
migraph
::
literal
get_2
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
}},
{
1
,
2
}};
}
migraph
::
literal
get_2_broadcasted
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
,
1
},
{
1
,
0
}},
{
1
,
2
}};
}
void
double_contig
()
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
auto
t1
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
l
);
auto
c1
=
p
.
add_instruction
(
migraph
::
contiguous
{},
t1
);
auto
c2
=
p
.
add_instruction
(
migraph
::
contiguous
{},
c1
);
p
.
add_instruction
(
pass_op
{},
c2
);
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
2
);
auto
result
=
p
.
eval
({});
EXPECT
(
result
==
get_2x2
());
}
void
double_transpose
()
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
auto
t1
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
l
);
auto
t2
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
t1
);
p
.
add_instruction
(
pass_op
{},
t2
);
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
2
);
auto
result
=
p
.
eval
({});
EXPECT
(
result
==
get_2x2
());
}
void
double_transpose_contig
()
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
auto
t1
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
l
);
auto
c1
=
p
.
add_instruction
(
migraph
::
contiguous
{},
t1
);
auto
t2
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
c1
);
auto
c2
=
p
.
add_instruction
(
migraph
::
contiguous
{},
t2
);
p
.
add_instruction
(
pass_op
{},
c2
);
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
2
);
auto
result
=
p
.
eval
({});
EXPECT
(
result
==
get_2x2
());
}
void
single_transpose
()
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
auto
t1
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
l
);
p
.
add_instruction
(
pass_op
{},
t1
);
EXPECT
(
not
p
.
get_shape
().
standard
());
EXPECT
(
p
.
get_shape
().
transposed
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
not
p
.
get_shape
().
standard
());
EXPECT
(
p
.
get_shape
().
transposed
());
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
3
);
auto
result
=
p
.
eval
({});
EXPECT
(
result
!=
get_2x2
());
}
void
double_transpose_sin_pass
()
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
auto
t1
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
l
);
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
t1
);
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
// std::cout << p << std::endl;
// TODO: Fix this
// EXPECT(std::distance(p.begin(), p.end()) == 1);
auto
result
=
p
.
eval
({});
EXPECT
(
result
==
get_2x2
());
}
void
single_transpose_sin_pass
()
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
l
);
EXPECT
(
not
p
.
get_shape
().
standard
());
EXPECT
(
p
.
get_shape
().
transposed
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
not
p
.
get_shape
().
standard
());
EXPECT
(
p
.
get_shape
().
transposed
());
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
2
);
auto
result
=
p
.
eval
({});
EXPECT
(
result
!=
get_2x2
());
}
int
main
()
{
double_contig
();
double_transpose
();
double_transpose_contig
();
single_transpose
();
double_transpose_sin_pass
();
single_transpose_sin_pass
();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment