Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
658a32ab
Commit
658a32ab
authored
Mar 15, 2019
by
Paul
Browse files
Use eval instead of literal for batchnorm rewrite
parent
efe3a9f0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
155 additions
and
19 deletions
+155
-19
src/fwd_conv_batchnorm_rewrite.cpp
src/fwd_conv_batchnorm_rewrite.cpp
+14
-15
src/include/migraphx/ranges.hpp
src/include/migraphx/ranges.hpp
+24
-0
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+4
-4
test/fwd_conv_batchnorm_rewrite_test.cpp
test/fwd_conv_batchnorm_rewrite_test.cpp
+113
-0
No files found.
src/fwd_conv_batchnorm_rewrite.cpp
View file @
658a32ab
...
...
@@ -3,6 +3,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dfor.hpp>
namespace
migraphx
{
...
...
@@ -14,32 +15,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
{
if
(
ins
->
name
()
!=
"batch_norm_inference"
)
continue
;
if
(
not
std
::
all_of
(
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
(),
[](
auto
arg
)
{
return
arg
->
name
()
==
"@literal"
;
}))
// Get scale, bias, mean, variance from inputs
const
auto
&
gamma
=
ins
->
inputs
()[
1
]
->
eval
();
const
auto
&
bias
=
ins
->
inputs
()[
2
]
->
eval
();
const
auto
&
mean
=
ins
->
inputs
()[
3
]
->
eval
();
const
auto
&
variance
=
ins
->
inputs
()[
4
]
->
eval
();
if
(
any_of
({
gamma
,
bias
,
mean
,
variance
},
[](
auto
arg
)
{
return
arg
.
empty
();
}))
continue
;
auto
conv_ins
=
ins
->
inputs
()[
0
];
if
(
conv_ins
->
name
()
!=
"convolution"
)
continue
;
if
(
conv_ins
->
inputs
()[
1
]
->
name
()
!=
"@literal"
)
// Get convolution weights
const
auto
&
weights
=
conv_ins
->
inputs
()[
1
]
->
eval
();
if
(
weights
.
empty
())
continue
;
// Get scale, bias, mean, variance from instruction_ref
const
auto
&
gamma
=
ins
->
inputs
()[
1
]
->
get_literal
();
const
auto
&
bias
=
ins
->
inputs
()[
2
]
->
get_literal
();
const
auto
&
mean
=
ins
->
inputs
()[
3
]
->
get_literal
();
const
auto
&
variance
=
ins
->
inputs
()[
4
]
->
get_literal
();
// Get epsilon
auto
bn_op
=
any_cast
<
op
::
batch_norm_inference
>
(
ins
->
get_operator
());
auto
epsilon
=
bn_op
.
epsilon
;
// Get convolution weights
const
auto
&
weights
=
conv_ins
->
inputs
()[
1
]
->
get_literal
();
// Get convolution op
auto
conv_op
=
conv_ins
->
get_operator
();
auto
weights_lens
=
weights
.
get_shape
().
lens
();
auto
conv_lens
=
conv_ins
->
get_shape
().
lens
();
argument
new_weights
{
weights
.
get_shape
()};
argument
new_bias
{
bias
.
get_shape
()};
argument
new_bias
{
{
bias
.
get_shape
()
.
type
(),
{
bias
.
get_shape
().
elements
()}}
};
visit_all
(
weights
,
gamma
,
bias
,
mean
,
variance
,
new_weights
,
new_bias
)(
[
&
](
auto
weights2
,
auto
gamma2
,
...
...
@@ -51,11 +50,11 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
dfor
(
weights_lens
[
0
],
weights_lens
[
1
],
weights_lens
[
2
],
weights_lens
[
3
])(
[
&
](
std
::
size_t
k
,
std
::
size_t
c
,
std
::
size_t
h
,
std
::
size_t
w
)
{
new_weights2
(
k
,
c
,
h
,
w
)
=
gamma2
(
k
)
/
std
::
sqrt
(
variance2
(
k
)
+
epsilon
)
*
weights2
(
k
,
c
,
h
,
w
);
gamma2
[
k
]
/
std
::
sqrt
(
variance2
[
k
]
+
epsilon
)
*
weights2
(
k
,
c
,
h
,
w
);
});
dfor
(
new_bias
.
get_shape
().
elements
())([
&
](
std
::
size_t
c
)
{
new_bias2
(
c
)
=
bias2
(
c
)
-
(
gamma2
(
c
)
*
mean2
(
c
)
/
std
::
sqrt
(
variance2
(
c
)
+
epsilon
));
new_bias2
[
c
]
=
bias2
[
c
]
-
(
gamma2
[
c
]
*
mean2
[
c
]
/
std
::
sqrt
(
variance2
[
c
]
+
epsilon
));
});
});
// Replace convolution instruction with updated weights
...
...
src/include/migraphx/ranges.hpp
View file @
658a32ab
...
...
@@ -71,6 +71,30 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p)
return
std
::
all_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
/// Returns true when predicate `pred` holds for at least one element of
/// range `r`. Generic counterpart of the initializer_list overload below.
template <class Range, class Pred>
bool any_of(const Range& r, const Pred& pred)
{
    // Equivalent to std::any_of: a matching element exists iff find_if
    // stops before the end of the range.
    return std::find_if(r.begin(), r.end(), pred) != r.end();
}
/// Overload of any_of for brace-enclosed initializer lists, e.g.
/// `any_of({a, b, c}, pred)`. Returns true if `pred` holds for any element.
template <class T, class Pred>
bool any_of(const std::initializer_list<T>& lst, const Pred& pred)
{
    // A matching element exists iff find_if stops before the end.
    return std::find_if(lst.begin(), lst.end(), pred) != lst.end();
}
/// Returns true when predicate `pred` holds for no element of range `r`.
template <class Range, class Pred>
bool none_of(const Range& r, const Pred& pred)
{
    // none_of is the negation of any_of over the same range and predicate.
    return not std::any_of(r.begin(), r.end(), pred);
}
/// Overload of none_of for brace-enclosed initializer lists, e.g.
/// `none_of({a, b, c}, pred)`.
template <class T, class Pred>
bool none_of(const std::initializer_list<T>& lst, const Pred& pred)
{
    // none_of is the negation of any_of over the same elements.
    return not std::any_of(lst.begin(), lst.end(), pred);
}
template
<
class
Range
,
class
Iterator
>
void
copy
(
Range
&&
r
,
Iterator
it
)
{
...
...
src/targets/cpu/lowering.cpp
View file @
658a32ab
...
...
@@ -75,10 +75,10 @@ struct cpu_batch_norm_inference
par_dfor
(
num_batch
,
num_channels
,
image_height
,
image_width
)(
[
&
](
std
::
size_t
n
,
std
::
size_t
c
,
std
::
size_t
h
,
std
::
size_t
w
)
{
assert
((
variance
(
c
)
+
epsilon
)
>
0
);
result
(
n
,
c
,
h
,
w
)
=
gamma
(
c
)
*
(
buffer
(
n
,
c
,
h
,
w
)
-
mean
(
c
)
)
/
std
::
sqrt
(
variance
(
c
)
+
epsilon
)
+
bias
(
c
)
;
assert
((
variance
[
c
]
+
epsilon
)
>
0
);
result
(
n
,
c
,
h
,
w
)
=
gamma
[
c
]
*
(
buffer
(
n
,
c
,
h
,
w
)
-
mean
[
c
]
)
/
std
::
sqrt
(
variance
[
c
]
+
epsilon
)
+
bias
[
c
]
;
});
});
}
...
...
test/fwd_conv_batchnorm_rewrite_test.cpp
View file @
658a32ab
...
...
@@ -3,9 +3,16 @@
#include <migraphx/cpu/target.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <test.hpp>
#include <migraphx/verify.hpp>
// Predicate used with any_of/none_of over a program: true when the
// instruction is a batch_norm_inference op.
bool is_batch_norm(migraphx::instruction& ins)
{
    const auto& op_name = ins.name();
    return op_name == "batch_norm_inference";
}
TEST_CASE
(
fwd_conv_batchnorm_rewrite_test
)
{
std
::
vector
<
float
>
xdata
=
{
...
...
@@ -65,4 +72,110 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
EXPECT
(
migraphx
::
verify_range
(
results_vector1
,
results_vector2
));
}
// When the convolution's inputs are runtime parameters (not literals), the
// rewrite cannot fold the batch norm into the weights, so the
// batch_norm_inference instruction must remain in the program.
TEST_CASE(non_literal)
{
    const migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8, 8}};
    const migraphx::shape ws{migraphx::shape::float_type, {4, 3, 1, 1}};
    const migraphx::shape vars{migraphx::shape::float_type, {4}};
    auto build_program = [&]() {
        migraphx::program prog;
        auto x        = prog.add_parameter("x", xs);
        auto w        = prog.add_parameter("w", ws);
        auto conv_ins = prog.add_instruction(migraphx::op::convolution{}, x, w);
        // abs() keeps the variance (and the other stats) positive.
        auto scale    = prog.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
        auto bias     = prog.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
        auto mean     = prog.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
        auto variance = prog.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
        prog.add_instruction(
            migraphx::op::batch_norm_inference{}, conv_ins, scale, bias, mean, variance);
        return prog;
    };
    migraphx::program p1 = build_program();
    migraphx::program p2 = build_program();
    migraphx::fwd_conv_batchnorm_rewrite opt;
    opt.apply(p2);
    // The pass must not fire: both programs still contain the batch-norm op.
    EXPECT(any_of(p1, &is_batch_norm));
    EXPECT(any_of(p2, &is_batch_norm));
}
// With every input supplied as a literal, the rewrite should remove the
// batch_norm_inference op, and the rewritten program must still compute the
// same values as the original.
TEST_CASE(as_literal)
{
    const migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8, 8}};
    const migraphx::shape ws{migraphx::shape::float_type, {4, 3, 1, 1}};
    const migraphx::shape vars{migraphx::shape::float_type, {4}};
    auto build_program = [&]() {
        migraphx::program prog;
        auto x        = prog.add_literal(migraphx::generate_literal(xs, 1));
        auto w        = prog.add_literal(migraphx::generate_literal(ws, 1));
        auto conv_ins = prog.add_instruction(migraphx::op::convolution{}, x, w);
        // abs() keeps the variance (and the other stats) positive.
        auto scale    = prog.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
        auto bias     = prog.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
        auto mean     = prog.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
        auto variance = prog.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
        prog.add_instruction(
            migraphx::op::batch_norm_inference{}, conv_ins, scale, bias, mean, variance);
        return prog;
    };
    migraphx::program p1 = build_program();
    migraphx::program p2 = build_program();
    migraphx::fwd_conv_batchnorm_rewrite opt;
    opt.apply(p2);
    // Only the rewritten program loses its batch-norm instruction.
    EXPECT(any_of(p1, &is_batch_norm));
    EXPECT(none_of(p2, &is_batch_norm));
    // Both programs must still evaluate to the same result on the CPU target.
    p1.compile(migraphx::cpu::target{});
    p2.compile(migraphx::cpu::target{});
    auto result1 = p1.eval({});
    auto result2 = p2.eval({});
    visit_all(result1, result2)(
        [&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
}
// Same as the as_literal case, but each per-channel statistic is passed
// through a reshape instruction first; the rewrite should still fire and
// preserve the computed values.
TEST_CASE(literal_reshape)
{
    const migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8, 8}};
    const migraphx::shape ws{migraphx::shape::float_type, {4, 3, 1, 1}};
    const migraphx::shape vars{migraphx::shape::float_type, {4}};
    auto build_program = [&]() {
        migraphx::program prog;
        // Reshape a {4} statistic to {1, 4, 1, 1} inside the program.
        auto reshape = [&](auto ins) {
            return prog.add_instruction(migraphx::op::reshape{{1, 4, 1, 1}}, ins);
        };
        auto x        = prog.add_literal(migraphx::generate_literal(xs, 1));
        auto w        = prog.add_literal(migraphx::generate_literal(ws, 1));
        auto conv_ins = prog.add_instruction(migraphx::op::convolution{}, x, w);
        // abs() keeps the variance (and the other stats) positive.
        auto scale    = reshape(prog.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))));
        auto bias     = reshape(prog.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))));
        auto mean     = reshape(prog.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))));
        auto variance = reshape(prog.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))));
        prog.add_instruction(
            migraphx::op::batch_norm_inference{}, conv_ins, scale, bias, mean, variance);
        return prog;
    };
    migraphx::program p1 = build_program();
    migraphx::program p2 = build_program();
    migraphx::fwd_conv_batchnorm_rewrite opt;
    opt.apply(p2);
    // Only the rewritten program loses its batch-norm instruction.
    EXPECT(any_of(p1, &is_batch_norm));
    EXPECT(none_of(p2, &is_batch_norm));
    // Both programs must still evaluate to the same result on the CPU target.
    p1.compile(migraphx::cpu::target{});
    p2.compile(migraphx::cpu::target{});
    auto result1 = p1.eval({});
    auto result2 = p2.eval({});
    visit_all(result1, result2)(
        [&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
}
// Test-harness entry point: dispatches to the registered TEST_CASEs.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment