Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d2a38cd4
Commit
d2a38cd4
authored
Aug 12, 2018
by
Paul
Browse files
Add simplify reshapes pass
parent
fc8ff61f
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
227 additions
and
3 deletions
+227
-3
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/include/migraph/simplify_reshapes.hpp
src/include/migraph/simplify_reshapes.hpp
+19
-0
src/program.cpp
src/program.cpp
+0
-3
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+58
-0
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+4
-0
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+145
-0
No files found.
src/CMakeLists.txt
View file @
d2a38cd4
...
@@ -5,6 +5,7 @@ add_library(migraph
...
@@ -5,6 +5,7 @@ add_library(migraph
generate.cpp
generate.cpp
program.cpp
program.cpp
shape.cpp
shape.cpp
simplify_reshapes.cpp
)
)
rocm_clang_tidy_check
(
migraph
)
rocm_clang_tidy_check
(
migraph
)
target_include_directories
(
migraph PUBLIC $<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/include>
)
target_include_directories
(
migraph PUBLIC $<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/include>
)
...
...
src/include/migraph/simplify_reshapes.hpp
0 → 100644
View file @
d2a38cd4
#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace
migraph
{
struct
program
;
struct
simplify_reshapes
{
std
::
string
name
()
const
{
return
"simplify_reshapes"
;
}
void
apply
(
program
&
p
)
const
;
};
}
// namespace migraph
#endif
src/program.cpp
View file @
d2a38cd4
...
@@ -65,7 +65,6 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
...
@@ -65,7 +65,6 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
// TODO: Should it be an error if the output is empty?
// TODO: Should it be an error if the output is empty?
if
(
ins
->
output
.
empty
())
if
(
ins
->
output
.
empty
())
{
{
remove_instruction
(
ins
);
return
rep
;
return
rep
;
}
}
for
(
auto
&&
out
:
ins
->
output
)
for
(
auto
&&
out
:
ins
->
output
)
...
@@ -80,8 +79,6 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
...
@@ -80,8 +79,6 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
// Replacement should not be dead code unless its the last instruction
// Replacement should not be dead code unless its the last instruction
assert
(
!
rep
->
output
.
empty
()
or
rep
==
std
::
prev
(
end
()));
assert
(
!
rep
->
output
.
empty
()
or
rep
==
std
::
prev
(
end
()));
assert
(
ins
->
valid
(
begin
()));
assert
(
ins
->
valid
(
begin
()));
if
(
ins
->
output
.
empty
())
remove_instruction
(
ins
);
assert
(
rep
->
valid
(
begin
()));
assert
(
rep
->
valid
(
begin
()));
return
rep
;
return
rep
;
}
}
...
...
src/simplify_reshapes.cpp
0 → 100644
View file @
d2a38cd4
#include <migraph/simplify_reshapes.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/ranges.hpp>
#include <unordered_set>
namespace
migraph
{
bool
is_reshaper
(
const
std
::
string
&
name
)
{
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
"reshape"
,
"transpose"
,
// "broadcast",
"contiguous"
};
return
contains
(
names
,
name
);
}
void
simplify_reshapes
::
apply
(
program
&
p
)
const
{
for
(
auto
ins
:
iterator_for
(
p
))
{
if
(
not
is_reshaper
(
ins
->
op
.
name
()))
continue
;
if
(
ins
->
output
.
size
()
!=
1
)
continue
;
if
(
is_reshaper
(
ins
->
output
.
front
()
->
op
.
name
()))
continue
;
// Gather reshapes
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
while
(
is_reshaper
(
reshapes
.
back
()
->
op
.
name
()))
{
assert
(
!
reshapes
.
back
()
->
arguments
.
empty
());
assert
(
p
.
has_instruction
(
reshapes
.
back
()
->
arguments
.
front
()));
reshapes
.
push_back
(
reshapes
.
back
()
->
arguments
.
front
());
}
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
p
.
end
(),
p
.
end
()};
for
(
auto
start
:
iterator_for
(
reshapes
))
{
auto
last
=
std
::
find_if
(
reshapes
.
rbegin
(),
reshapes
.
rend
(),
[
&
](
auto
&&
i
)
{
return
i
->
result
==
(
*
start
)
->
result
and
i
!=
(
*
start
);
});
if
(
last
!=
reshapes
.
rend
())
{
r
=
std
::
make_pair
(
*
start
,
*
last
);
break
;
}
}
if
(
r
.
first
!=
r
.
second
)
{
p
.
replace_instruction
(
r
.
first
,
r
.
second
);
}
}
}
}
// namespace migraph
src/targets/gpu/target.cpp
View file @
d2a38cd4
...
@@ -4,6 +4,8 @@
...
@@ -4,6 +4,8 @@
#include <migraph/gpu/context.hpp>
#include <migraph/gpu/context.hpp>
#include <migraph/check_context.hpp>
#include <migraph/check_context.hpp>
#include <migraph/auto_contiguous.hpp>
#include <migraph/auto_contiguous.hpp>
#include <migraph/dead_code_elimination.hpp>
#include <migraph/simplify_reshapes.hpp>
namespace
migraph
{
namespace
migraph
{
namespace
gpu
{
namespace
gpu
{
...
@@ -14,8 +16,10 @@ std::vector<pass> target::get_passes(migraph::context&) const
...
@@ -14,8 +16,10 @@ std::vector<pass> target::get_passes(migraph::context&) const
return
return
{
{
auto_contiguous
{},
auto_contiguous
{},
simplify_reshapes
{},
lowering
{},
lowering
{},
write_literals
{},
write_literals
{},
dead_code_elimination
{},
check_context
<
context
>
{}
check_context
<
context
>
{}
};
};
// clang-format on
// clang-format on
...
...
test/simplify_reshapes_test.cpp
0 → 100644
View file @
d2a38cd4
#include <migraph/simplify_reshapes.hpp>
#include <migraph/dead_code_elimination.hpp>
#include <migraph/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct
simplify_reshapes_target
{
std
::
string
name
()
const
{
return
"simplify_reshapes"
;
}
std
::
vector
<
migraph
::
pass
>
get_passes
(
migraph
::
context
&
)
const
{
return
{
migraph
::
simplify_reshapes
{},
migraph
::
dead_code_elimination
{}};
}
migraph
::
context
get_context
()
const
{
return
{};
}
};
migraph
::
literal
get_2x2
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
,
2
}},
{
1
,
2
,
3
,
4
}};
}
migraph
::
literal
get_2x2_transposed
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
2
}},
{
1
,
2
,
3
,
4
}};
}
migraph
::
literal
get_2
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
}},
{
1
,
2
}};
}
migraph
::
literal
get_2_broadcasted
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
,
1
},
{
1
,
0
}},
{
1
,
2
}};
}
void
double_contig
()
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
auto
t1
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
l
);
auto
c1
=
p
.
add_instruction
(
migraph
::
contiguous
{},
t1
);
auto
c2
=
p
.
add_instruction
(
migraph
::
contiguous
{},
c1
);
p
.
add_instruction
(
pass_op
{},
c2
);
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
2
);
auto
result
=
p
.
eval
({});
EXPECT
(
result
==
get_2x2
());
}
void
double_transpose
()
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
auto
t1
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
l
);
auto
t2
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
t1
);
p
.
add_instruction
(
pass_op
{},
t2
);
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
2
);
auto
result
=
p
.
eval
({});
EXPECT
(
result
==
get_2x2
());
}
void
double_transpose_contig
()
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
auto
t1
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
l
);
auto
c1
=
p
.
add_instruction
(
migraph
::
contiguous
{},
t1
);
auto
t2
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
c1
);
auto
c2
=
p
.
add_instruction
(
migraph
::
contiguous
{},
t2
);
p
.
add_instruction
(
pass_op
{},
c2
);
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
2
);
auto
result
=
p
.
eval
({});
EXPECT
(
result
==
get_2x2
());
}
void
single_transpose
()
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
auto
t1
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
l
);
p
.
add_instruction
(
pass_op
{},
t1
);
EXPECT
(
not
p
.
get_shape
().
standard
());
EXPECT
(
p
.
get_shape
().
transposed
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
not
p
.
get_shape
().
standard
());
EXPECT
(
p
.
get_shape
().
transposed
());
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
3
);
auto
result
=
p
.
eval
({});
EXPECT
(
result
!=
get_2x2
());
}
void
double_transpose_sin_pass
()
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
auto
t1
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
l
);
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
t1
);
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
// std::cout << p << std::endl;
// TODO: Fix this
// EXPECT(std::distance(p.begin(), p.end()) == 1);
auto
result
=
p
.
eval
({});
EXPECT
(
result
==
get_2x2
());
}
void
single_transpose_sin_pass
()
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
l
);
EXPECT
(
not
p
.
get_shape
().
standard
());
EXPECT
(
p
.
get_shape
().
transposed
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
not
p
.
get_shape
().
standard
());
EXPECT
(
p
.
get_shape
().
transposed
());
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
2
);
auto
result
=
p
.
eval
({});
EXPECT
(
result
!=
get_2x2
());
}
int
main
()
{
double_contig
();
double_transpose
();
double_transpose_contig
();
single_transpose
();
double_transpose_sin_pass
();
single_transpose_sin_pass
();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment