Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a96fae91
Commit
a96fae91
authored
Aug 14, 2018
by
Paul
Browse files
Add a pass to eliminate contiguous operators
parent
f0604d78
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
165 additions
and
34 deletions
+165
-34
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+48
-0
src/include/migraph/eliminate_contiguous.hpp
src/include/migraph/eliminate_contiguous.hpp
+19
-0
src/include/migraph/instruction.hpp
src/include/migraph/instruction.hpp
+3
-0
src/include/migraph/ranges.hpp
src/include/migraph/ranges.hpp
+6
-0
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+2
-0
test/auto_contiguous_test.cpp
test/auto_contiguous_test.cpp
+0
-17
test/eliminate_contiguous_test.cpp
test/eliminate_contiguous_test.cpp
+45
-0
test/include/basic_ops.hpp
test/include/basic_ops.hpp
+41
-0
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+0
-17
No files found.
src/CMakeLists.txt
View file @
a96fae91
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
add_library
(
migraph
add_library
(
migraph
auto_contiguous.cpp
auto_contiguous.cpp
dead_code_elimination.cpp
dead_code_elimination.cpp
eliminate_contiguous.cpp
generate.cpp
generate.cpp
program.cpp
program.cpp
shape.cpp
shape.cpp
...
...
src/eliminate_contiguous.cpp
0 → 100644
View file @
a96fae91
#include <migraph/eliminate_contiguous.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/ranges.hpp>
#include <migraph/stringutils.hpp>
namespace
migraph
{
bool
try_compute_shape
(
operation
op
,
std
::
vector
<
instruction_ref
>
args
)
{
try
{
compute_shape
(
op
,
args
);
}
catch
(...)
{
return
false
;
}
return
true
;
}
void
eliminate_contiguous
::
apply
(
program
&
p
)
const
{
for
(
auto
ins
:
iterator_for
(
p
))
{
// Make a copy so we can modify it while we iterate
auto
args
=
ins
->
arguments
;
for
(
auto
arg
:
ins
->
arguments
)
{
// TODO: Pass in names for the operator in the constructor instead
// of using ends_with
if
(
ends_with
(
arg
->
op
.
name
(),
"contiguous"
))
{
auto
new_args
=
args
;
auto
prev
=
arg
->
arguments
.
front
();
replace
(
new_args
,
arg
,
prev
);
if
(
try_compute_shape
(
ins
->
op
,
new_args
))
{
replace_argument
(
ins
,
arg
,
prev
);
}
}
}
}
}
}
// namespace migraph
src/include/migraph/eliminate_contiguous.hpp
0 → 100644
View file @
a96fae91
#ifndef MIGRAPH_GUARD_RTGLIB_ELIMINATE_CONTIGUOUS_HPP
#define MIGRAPH_GUARD_RTGLIB_ELIMINATE_CONTIGUOUS_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace
migraph
{
struct
program
;
struct
eliminate_contiguous
{
std
::
string
name
()
const
{
return
"eliminate_contiguous"
;
}
void
apply
(
program
&
p
)
const
;
};
}
// namespace migraph
#endif
src/include/migraph/instruction.hpp
View file @
a96fae91
...
@@ -24,6 +24,7 @@ struct instruction
...
@@ -24,6 +24,7 @@ struct instruction
instruction
(
literal
l
)
:
op
(
builtin
::
literal
{}),
result
(
l
.
get_shape
()),
lit
(
std
::
move
(
l
))
{}
instruction
(
literal
l
)
:
op
(
builtin
::
literal
{}),
result
(
l
.
get_shape
()),
lit
(
std
::
move
(
l
))
{}
// internal
void
replace
(
operation
o
,
shape
r
,
std
::
vector
<
instruction_ref
>
args
)
void
replace
(
operation
o
,
shape
r
,
std
::
vector
<
instruction_ref
>
args
)
{
{
op
=
o
;
op
=
o
;
...
@@ -46,12 +47,14 @@ struct instruction
...
@@ -46,12 +47,14 @@ struct instruction
void
recompute_shape
()
{
replace
(
compute_shape
(
op
,
arguments
));
}
void
recompute_shape
()
{
replace
(
compute_shape
(
op
,
arguments
));
}
// internal
void
replace
(
std
::
vector
<
instruction_ref
>
args
)
void
replace
(
std
::
vector
<
instruction_ref
>
args
)
{
{
clear_arguments
();
clear_arguments
();
arguments
=
std
::
move
(
args
);
arguments
=
std
::
move
(
args
);
}
}
// internal
void
replace_argument
(
instruction_ref
old
,
instruction_ref
new_ins
)
void
replace_argument
(
instruction_ref
old
,
instruction_ref
new_ins
)
{
{
std
::
replace
(
arguments
.
begin
(),
arguments
.
end
(),
old
,
new_ins
);
std
::
replace
(
arguments
.
begin
(),
arguments
.
end
(),
old
,
new_ins
);
...
...
src/include/migraph/ranges.hpp
View file @
a96fae91
...
@@ -17,6 +17,12 @@ void copy(Range&& r, Iterator it)
...
@@ -17,6 +17,12 @@ void copy(Range&& r, Iterator it)
std
::
copy
(
r
.
begin
(),
r
.
end
(),
it
);
std
::
copy
(
r
.
begin
(),
r
.
end
(),
it
);
}
}
template
<
class
Range
,
class
T
>
void
replace
(
Range
&&
r
,
const
T
&
old
,
const
T
&
new_x
)
{
std
::
replace
(
r
.
begin
(),
r
.
end
(),
old
,
new_x
);
}
template
<
class
Iterator
>
template
<
class
Iterator
>
struct
iterator_range
struct
iterator_range
{
{
...
...
src/targets/gpu/target.cpp
View file @
a96fae91
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include <migraph/auto_contiguous.hpp>
#include <migraph/auto_contiguous.hpp>
#include <migraph/dead_code_elimination.hpp>
#include <migraph/dead_code_elimination.hpp>
#include <migraph/simplify_reshapes.hpp>
#include <migraph/simplify_reshapes.hpp>
#include <migraph/eliminate_contiguous.hpp>
namespace
migraph
{
namespace
migraph
{
namespace
gpu
{
namespace
gpu
{
...
@@ -19,6 +20,7 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
...
@@ -19,6 +20,7 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
auto_contiguous
{},
auto_contiguous
{},
simplify_reshapes
{},
simplify_reshapes
{},
lowering
{
ctx
},
lowering
{
ctx
},
eliminate_contiguous
{},
write_literals
{},
write_literals
{},
check_context
<
context
>
{},
check_context
<
context
>
{},
dead_code_elimination
{}
dead_code_elimination
{}
...
...
test/auto_contiguous_test.cpp
View file @
a96fae91
...
@@ -13,23 +13,6 @@ struct contiguous_target
...
@@ -13,23 +13,6 @@ struct contiguous_target
migraph
::
context
get_context
()
const
{
return
{};
}
migraph
::
context
get_context
()
const
{
return
{};
}
};
};
migraph
::
literal
get_2x2
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
,
2
}},
{
1
,
2
,
3
,
4
}};
}
migraph
::
literal
get_2x2_transposed
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
2
}},
{
1
,
2
,
3
,
4
}};
}
migraph
::
literal
get_2
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
}},
{
1
,
2
}};
}
migraph
::
literal
get_2_broadcasted
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
,
1
},
{
1
,
0
}},
{
1
,
2
}};
}
void
literal_broadcast
()
void
literal_broadcast
()
{
{
migraph
::
program
p
;
migraph
::
program
p
;
...
...
test/eliminate_contiguous_test.cpp
0 → 100644
View file @
a96fae91
#include <migraph/eliminate_contiguous.hpp>
#include <migraph/dead_code_elimination.hpp>
#include <migraph/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct
eliminate_contiguous_target
{
std
::
string
name
()
const
{
return
"eliminate_contiguous"
;
}
std
::
vector
<
migraph
::
pass
>
get_passes
(
migraph
::
context
&
)
const
{
return
{
migraph
::
eliminate_contiguous
{},
migraph
::
dead_code_elimination
{}};
}
migraph
::
context
get_context
()
const
{
return
{};
}
};
void
standard_op
()
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
auto
t
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
l
);
auto
c
=
p
.
add_instruction
(
migraph
::
contiguous
{},
t
);
p
.
add_instruction
(
pass_standard_op
{},
c
);
auto
count
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
eliminate_contiguous_target
{});
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
count
);
}
void
non_standard_op
()
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
auto
t
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
l
);
auto
c
=
p
.
add_instruction
(
migraph
::
contiguous
{},
t
);
p
.
add_instruction
(
pass_op
{},
c
);
auto
count
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
eliminate_contiguous_target
{});
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
(
count
-
1
));
}
int
main
()
{
standard_op
();
non_standard_op
();
}
test/include/basic_ops.hpp
View file @
a96fae91
...
@@ -81,6 +81,30 @@ struct pass_op
...
@@ -81,6 +81,30 @@ struct pass_op
}
}
};
};
struct
pass_standard_op
{
std
::
string
name
()
const
{
return
"pass"
;
}
migraph
::
argument
compute
(
migraph
::
context
&
,
migraph
::
shape
,
std
::
vector
<
migraph
::
argument
>
args
)
const
{
if
(
args
.
empty
())
return
{};
return
args
.
front
();
}
migraph
::
shape
compute_shape
(
std
::
vector
<
migraph
::
shape
>
inputs
)
const
{
for
(
auto
&&
input
:
inputs
)
{
if
(
not
input
.
standard
())
throw
std
::
runtime_error
(
"Not standard shape"
);
}
if
(
inputs
.
empty
())
return
{};
return
inputs
.
front
();
}
};
struct
nop
struct
nop
{
{
std
::
string
name
()
const
{
return
"nop"
;
}
std
::
string
name
()
const
{
return
"nop"
;
}
...
@@ -92,3 +116,20 @@ struct nop
...
@@ -92,3 +116,20 @@ struct nop
migraph
::
shape
compute_shape
(
std
::
vector
<
migraph
::
shape
>
)
const
{
return
{};
}
migraph
::
shape
compute_shape
(
std
::
vector
<
migraph
::
shape
>
)
const
{
return
{};
}
};
};
migraph
::
literal
get_2x2
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
,
2
}},
{
1
,
2
,
3
,
4
}};
}
migraph
::
literal
get_2x2_transposed
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
2
}},
{
1
,
2
,
3
,
4
}};
}
migraph
::
literal
get_2
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
}},
{
1
,
2
}};
}
migraph
::
literal
get_2_broadcasted
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
,
1
},
{
1
,
0
}},
{
1
,
2
}};
}
test/simplify_reshapes_test.cpp
View file @
a96fae91
...
@@ -14,23 +14,6 @@ struct simplify_reshapes_target
...
@@ -14,23 +14,6 @@ struct simplify_reshapes_target
migraph
::
context
get_context
()
const
{
return
{};
}
migraph
::
context
get_context
()
const
{
return
{};
}
};
};
migraph
::
literal
get_2x2
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
,
2
}},
{
1
,
2
,
3
,
4
}};
}
migraph
::
literal
get_2x2_transposed
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
2
}},
{
1
,
2
,
3
,
4
}};
}
migraph
::
literal
get_2
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
}},
{
1
,
2
}};
}
migraph
::
literal
get_2_broadcasted
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
,
1
},
{
1
,
0
}},
{
1
,
2
}};
}
void
double_contig
()
void
double_contig
()
{
{
migraph
::
program
p
;
migraph
::
program
p
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment