Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b889693c
Commit
b889693c
authored
Jul 08, 2019
by
Paul
Browse files
Convert batchnorm to multiply and add
parent
d5ade1e7
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
36 deletions
+30
-36
src/fwd_conv_batchnorm_rewrite.cpp
src/fwd_conv_batchnorm_rewrite.cpp
+29
-35
test/fwd_conv_batchnorm_rewrite_test.cpp
test/fwd_conv_batchnorm_rewrite_test.cpp
+1
-1
No files found.
src/fwd_conv_batchnorm_rewrite.cpp
View file @
b889693c
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/dfor.hpp>
...
@@ -25,46 +26,39 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
...
@@ -25,46 +26,39 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
if
(
any_of
({
gamma
,
bias
,
mean
,
variance
},
[](
auto
arg
)
{
return
arg
.
empty
();
}))
if
(
any_of
({
gamma
,
bias
,
mean
,
variance
},
[](
auto
arg
)
{
return
arg
.
empty
();
}))
continue
;
continue
;
auto
conv_ins
=
ins
->
inputs
()[
0
];
auto
s
=
shape
{
ins
->
get_shape
().
type
(),
{
ins
->
get_shape
().
lens
()[
1
]}};
if
(
conv_ins
->
name
()
!=
"convolution"
)
continue
;
// Get convolution weights
auto
weights
=
conv_ins
->
inputs
()[
1
]
->
eval
();
if
(
weights
.
empty
())
continue
;
// Get epsilon
// Get epsilon
auto
bn_op
=
any_cast
<
op
::
batch_norm_inference
>
(
ins
->
get_operator
());
auto
bn_op
=
any_cast
<
op
::
batch_norm_inference
>
(
ins
->
get_operator
());
auto
epsilon
=
bn_op
.
epsilon
;
auto
epsilon
=
bn_op
.
epsilon
;
// Get convolution op
auto
conv_op
=
conv_ins
->
get_operator
();
argument
a
{
s
};
auto
weights_lens
=
weights
.
get_shape
().
lens
();
argument
b
{
s
};
auto
conv_lens
=
conv_ins
->
get_shape
().
lens
();
visit_all
(
gamma
,
bias
,
mean
,
variance
,
a
,
b
)(
argument
new_weights
{
weights
.
get_shape
()};
[
&
](
auto
gamma2
,
argument
new_bias
{{
bias
.
get_shape
().
type
(),
{
bias
.
get_shape
().
elements
()}}};
visit_all
(
weights
,
gamma
,
bias
,
mean
,
variance
,
new_weights
,
new_bias
)(
[
&
](
auto
weights2
,
auto
gamma2
,
auto
bias2
,
auto
bias2
,
auto
mean2
,
auto
mean2
,
auto
variance2
,
auto
variance2
,
auto
new_weights
2
,
auto
a
2
,
auto
new_bias
2
)
{
auto
b
2
)
{
dfor
(
weights_lens
[
0
],
weights_lens
[
1
],
weights_lens
[
2
],
weights_lens
[
3
]
)(
dfor
(
a
.
get_shape
().
elements
()
)(
[
&
](
std
::
size_t
k
,
std
::
size_t
c
,
std
::
size_t
h
,
std
::
size_t
w
)
{
[
&
](
std
::
size_t
c
)
{
new_weights2
(
k
,
c
,
h
,
w
)
=
a2
[
c
]
=
gamma2
[
k
]
/
std
::
sqrt
(
variance2
[
k
]
+
epsilon
)
*
weights2
(
k
,
c
,
h
,
w
)
;
gamma2
[
c
]
/
std
::
sqrt
(
variance2
[
c
]
+
epsilon
);
});
});
dfor
(
new_bias
.
get_shape
().
elements
())([
&
](
std
::
size_t
c
)
{
dfor
(
b
.
get_shape
().
elements
())([
&
](
std
::
size_t
c
)
{
new_bias
2
[
c
]
=
b
2
[
c
]
=
bias2
[
c
]
-
(
gamma2
[
c
]
*
mean2
[
c
]
/
std
::
sqrt
(
variance2
[
c
]
+
epsilon
));
bias2
[
c
]
-
(
gamma2
[
c
]
*
mean2
[
c
]
/
std
::
sqrt
(
variance2
[
c
]
+
epsilon
));
});
});
});
});
// Replace convolution instruction with updated weights
auto
l_weights
=
p
.
add_literal
({
weights
.
get_shape
(),
new_weights
.
data
()});
auto
broadcast
=
op
::
broadcast
{
1
,
ins
->
get_shape
().
lens
()};
auto
l_bias
=
p
.
add_literal
({
new_bias
.
get_shape
(),
new_bias
.
data
()});
auto
a_ins
=
p
.
add_literal
({
a
.
get_shape
(),
a
.
data
()});
auto
c
=
p
.
replace_instruction
(
conv_ins
,
conv_op
,
{
conv_ins
->
inputs
()[
0
],
l_weights
});
auto
a_broadcast
=
p
.
insert_instruction
(
ins
,
broadcast
,
a_ins
);
auto
b
=
p
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
c
->
get_shape
().
lens
()},
l_bias
);
auto
mul
=
p
.
insert_instruction
(
ins
,
op
::
mul
{},
ins
->
inputs
().
front
(),
a_broadcast
);
p
.
replace_instruction
(
ins
,
op
::
add
{},
{
c
,
b
});
auto
b_ins
=
p
.
add_literal
({
b
.
get_shape
(),
b
.
data
()});
auto
b_broadcast
=
p
.
insert_instruction
(
ins
,
broadcast
,
b_ins
);
auto
add
=
p
.
insert_instruction
(
ins
,
op
::
add
{},
mul
,
b_broadcast
);
p
.
replace_instruction
(
ins
,
add
);
}
}
}
}
...
...
test/fwd_conv_batchnorm_rewrite_test.cpp
View file @
b889693c
...
@@ -96,7 +96,7 @@ TEST_CASE(non_literal)
...
@@ -96,7 +96,7 @@ TEST_CASE(non_literal)
migraphx
::
fwd_conv_batchnorm_rewrite
opt
;
migraphx
::
fwd_conv_batchnorm_rewrite
opt
;
opt
.
apply
(
p2
);
opt
.
apply
(
p2
);
EXPECT
(
any_of
(
p1
,
&
is_batch_norm
));
EXPECT
(
any_of
(
p1
,
&
is_batch_norm
));
EXPECT
(
any
_of
(
p2
,
&
is_batch_norm
));
EXPECT
(
none
_of
(
p2
,
&
is_batch_norm
));
}
}
TEST_CASE
(
as_literal
)
TEST_CASE
(
as_literal
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment