Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
95a5ba16
Commit
95a5ba16
authored
Mar 27, 2019
by
Khalique
Browse files
fixed tests or pad_rewrite
parent
ce445cad
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
106 additions
and
22 deletions
+106
-22
src/include/migraphx/eliminate_identity.hpp
src/include/migraphx/eliminate_identity.hpp
+3
-1
src/include/migraphx/pad_rewrite.hpp
src/include/migraphx/pad_rewrite.hpp
+4
-2
src/pad_rewrite.cpp
src/pad_rewrite.cpp
+17
-18
test/eliminate_identity_test.cpp
test/eliminate_identity_test.cpp
+1
-1
test/pad_rewrite_test.cpp
test/pad_rewrite_test.cpp
+81
-0
No files found.
src/include/migraphx/eliminate_identity.hpp
View file @
95a5ba16
...
...
@@ -11,7 +11,9 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
program
;
/**
* Remove identity instructions.
* Remove identity instructions. Currently when used as the last pass, it will
* preserve the semantics of previous program state, therefore dead code elimination
* should not be used afterwards.
*/
struct
eliminate_identity
{
...
...
src/include/migraphx/pad_rewrite.hpp
View file @
95a5ba16
...
...
@@ -13,14 +13,16 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
program
;
/**
* Rewrite pads to use attribute from other instructions instead.
* Remove identity instructions. Currently when used as the last pass, it will
* preserve the semantics of previous program state, therefore dead code elimination
* should not be used afterwards.
*/
struct
pad_rewrite
{
std
::
string
name
()
const
{
return
"pad_rewrite"
;
}
void
apply
(
program
&
p
)
const
;
template
<
class
T
>
void
update_op
(
T
,
instruction_ref
in
s
,
instruction_ref
output
,
program
&
p
)
const
;
void
update_op
(
T
,
const
instruction_ref
&
in
put
,
const
instruction_ref
&
ins
,
program
&
p
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/pad_rewrite.cpp
View file @
95a5ba16
...
...
@@ -12,39 +12,38 @@ void pad_rewrite::apply(program& p) const
{
for
(
auto
ins
:
iterator_for
(
p
))
{
if
(
ins
->
name
()
!=
"pad"
)
const
std
::
string
&
op_name
=
ins
->
name
();
if
(
op_name
!=
"convolution"
and
op_name
!=
"im2col"
and
op_name
!=
"pooling"
)
continue
;
auto
input
=
ins
->
inputs
().
front
();
if
(
input
->
name
()
!=
"pad"
)
continue
;
for
(
auto
output
:
ins
->
outputs
())
{
auto
op_name
=
output
->
name
();
if
(
op_name
==
"convolution"
)
update_op
(
op
::
convolution
{},
in
s
,
output
,
p
);
update_op
(
op
::
convolution
{},
in
put
,
ins
,
p
);
else
if
(
op_name
==
"im2col"
)
update_op
(
op
::
im2col
{},
in
s
,
output
,
p
);
update_op
(
op
::
im2col
{},
in
put
,
ins
,
p
);
else
if
(
op_name
==
"pooling"
)
update_op
(
op
::
pooling
{},
ins
,
output
,
p
);
}
update_op
(
op
::
pooling
{},
input
,
ins
,
p
);
}
}
template
<
class
T
>
void
pad_rewrite
::
update_op
(
T
,
instruction_ref
in
s
,
instruction_ref
output
,
program
&
p
)
const
void
pad_rewrite
::
update_op
(
T
,
const
instruction_ref
&
in
put
,
const
instruction_ref
&
ins
,
program
&
p
)
const
{
auto
pad_op
=
any_cast
<
op
::
pad
>
(
in
s
->
get_operator
());
auto
pad_op
=
any_cast
<
op
::
pad
>
(
in
put
->
get_operator
());
if
(
!
pad_op
.
symmetric
())
return
;
std
::
vector
<
int64_t
>
pads
=
pad_op
.
pads
;
assert
(
pads
.
size
()
==
8
);
// ensure input being padded has 4 dims (*2 for font and back padding)
std
::
array
<
size_t
,
2
>
new_pads
{
static_cast
<
size_t
>
(
pads
[
2
]),
static_cast
<
size_t
>
(
pads
[
3
])};
T
op
=
any_cast
<
T
>
(
output
->
get_operator
());
T
op
=
any_cast
<
T
>
(
ins
->
get_operator
());
op
.
padding
=
new_pads
;
std
::
vector
<
instruction_ref
>
new_inputs
{
output
->
inputs
()};
new_inputs
.
front
()
=
in
s
->
inputs
().
front
();
std
::
vector
<
instruction_ref
>
new_inputs
{
ins
->
inputs
()};
new_inputs
.
front
()
=
in
put
->
inputs
().
front
();
p
.
replace_instruction
(
output
,
op
,
new_inputs
);
p
.
replace_instruction
(
ins
,
op
,
new_inputs
);
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
test/eliminate_identity_test.cpp
View file @
95a5ba16
...
...
@@ -59,7 +59,7 @@ TEST_CASE(simple_test_end_dependency)
p
.
add_instruction
(
sum_op
{},
ans
,
three
);
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
ans
);
p
.
compile
(
eliminate_identity_target
{});
EXPECT
(
!
std
::
none
_of
(
p
.
begin
(),
p
.
end
(),
[](
const
migraphx
::
instruction
&
ins
)
{
EXPECT
(
std
::
any
_of
(
p
.
begin
(),
p
.
end
(),
[](
const
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"identity"
;
}));
auto
result
=
p
.
eval
({});
...
...
test/pad_rewrite_test.cpp
0 → 100644
View file @
95a5ba16
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pad_rewrite.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <test.hpp>
struct
pad_rewrite_target
{
std
::
string
name
()
const
{
return
"pad_rewrite"
;
}
std
::
vector
<
migraphx
::
pass
>
get_passes
(
migraphx
::
context
&
)
const
{
return
{
migraphx
::
pad_rewrite
{},
migraphx
::
dead_code_elimination
{}};
}
migraphx
::
context
get_context
()
const
{
return
{};
}
};
migraphx
::
instruction_ref
create_im2col
(
migraphx
::
instruction_ref
&
l_img
,
size_t
channels
,
migraphx
::
program
&
p
)
{
size_t
f
[
2
]
=
{
1
,
1
};
std
::
vector
<
int32_t
>
weights
(
channels
*
f
[
0
]
*
f
[
1
]);
migraphx
::
shape
s_weights
{
migraphx
::
shape
::
int32_type
,
{
1
,
channels
,
f
[
0
],
f
[
1
]}};
auto
l_weights
=
p
.
add_literal
(
migraphx
::
literal
{
s_weights
,
weights
});
return
p
.
add_instruction
(
migraphx
::
op
::
im2col
{},
l_img
,
l_weights
);
}
migraphx
::
instruction_ref
create_conv
(
migraphx
::
instruction_ref
&
l_img
,
size_t
channels
,
migraphx
::
program
&
p
)
{
migraphx
::
shape
s_weights
{
migraphx
::
shape
::
int32_type
,
{
4
,
channels
,
3
,
3
}};
std
::
vector
<
int32_t
>
weights
(
4
*
channels
*
3
*
3
);
auto
l_weights
=
p
.
add_literal
(
migraphx
::
literal
{
s_weights
,
weights
});
return
p
.
add_instruction
(
migraphx
::
op
::
convolution
{},
l_img
,
l_weights
);
}
TEST_CASE
(
rewrite_test
)
{
migraphx
::
program
p
;
size_t
img_dim
[
2
]
=
{
2
,
2
};
size_t
channels
=
1
;
std
::
vector
<
int32_t
>
input
(
channels
*
img_dim
[
0
]
*
img_dim
[
1
]);
std
::
iota
(
input
.
begin
(),
input
.
end
(),
0
);
migraphx
::
shape
s_img
{
migraphx
::
shape
::
int32_type
,
{
1
,
channels
,
img_dim
[
0
],
img_dim
[
1
]}};
auto
l_img
=
p
.
add_literal
(
migraphx
::
literal
{
s_img
,
input
});
auto
padded_img
=
p
.
add_instruction
(
migraphx
::
op
::
pad
{{
0
,
0
,
1
,
1
,
0
,
0
,
1
,
1
}},
l_img
);
auto
l0
=
create_im2col
(
padded_img
,
channels
,
p
);
auto
l1
=
create_conv
(
padded_img
,
channels
,
p
);
auto
l2
=
p
.
add_instruction
(
migraphx
::
op
::
pooling
{},
padded_img
);
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
l0
,
l1
,
l2
);
p
.
compile
(
pad_rewrite_target
{});
EXPECT
(
std
::
none_of
(
p
.
begin
(),
p
.
end
(),
[](
const
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"pad"
;
}));
}
TEST_CASE
(
rewrite_test_asymmetric
)
{
migraphx
::
program
p
;
size_t
img_dim
[
2
]
=
{
2
,
2
};
size_t
channels
=
1
;
std
::
vector
<
int32_t
>
input
(
channels
*
img_dim
[
0
]
*
img_dim
[
1
]);
std
::
iota
(
input
.
begin
(),
input
.
end
(),
0
);
migraphx
::
shape
s_img
{
migraphx
::
shape
::
int32_type
,
{
1
,
channels
,
img_dim
[
0
],
img_dim
[
1
]}};
auto
l_img
=
p
.
add_literal
(
migraphx
::
literal
{
s_img
,
input
});
auto
padded_img
=
p
.
add_instruction
(
migraphx
::
op
::
pad
{{
0
,
0
,
0
,
0
,
0
,
0
,
2
,
2
}},
l_img
);
create_im2col
(
padded_img
,
channels
,
p
);
p
.
compile
(
pad_rewrite_target
{});
EXPECT
(
std
::
any_of
(
p
.
begin
(),
p
.
end
(),
[](
const
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"pad"
;
}));
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment