Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
658a32ab
"vscode:/vscode.git/clone" did not exist on "e12e38fc50990b1d138166bfd4b05743872f7fee"
Commit
658a32ab
authored
Mar 15, 2019
by
Paul
Browse files
Use eval instead of literal for batchnorm rewrite
parent
efe3a9f0
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
155 additions
and
19 deletions
+155
-19
src/fwd_conv_batchnorm_rewrite.cpp
src/fwd_conv_batchnorm_rewrite.cpp
+14
-15
src/include/migraphx/ranges.hpp
src/include/migraphx/ranges.hpp
+24
-0
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+4
-4
test/fwd_conv_batchnorm_rewrite_test.cpp
test/fwd_conv_batchnorm_rewrite_test.cpp
+113
-0
No files found.
src/fwd_conv_batchnorm_rewrite.cpp
View file @
658a32ab
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/dfor.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -14,32 +15,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
...
@@ -14,32 +15,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
{
{
if
(
ins
->
name
()
!=
"batch_norm_inference"
)
if
(
ins
->
name
()
!=
"batch_norm_inference"
)
continue
;
continue
;
if
(
not
std
::
all_of
(
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
(),
[](
auto
arg
)
{
// Get scale, bias, mean, variance from inputs
return
arg
->
name
()
==
"@literal"
;
const
auto
&
gamma
=
ins
->
inputs
()[
1
]
->
eval
();
}))
const
auto
&
bias
=
ins
->
inputs
()[
2
]
->
eval
();
const
auto
&
mean
=
ins
->
inputs
()[
3
]
->
eval
();
const
auto
&
variance
=
ins
->
inputs
()[
4
]
->
eval
();
if
(
any_of
({
gamma
,
bias
,
mean
,
variance
},
[](
auto
arg
)
{
return
arg
.
empty
();
}))
continue
;
continue
;
auto
conv_ins
=
ins
->
inputs
()[
0
];
auto
conv_ins
=
ins
->
inputs
()[
0
];
if
(
conv_ins
->
name
()
!=
"convolution"
)
if
(
conv_ins
->
name
()
!=
"convolution"
)
continue
;
continue
;
if
(
conv_ins
->
inputs
()[
1
]
->
name
()
!=
"@literal"
)
// Get convolution weights
const
auto
&
weights
=
conv_ins
->
inputs
()[
1
]
->
eval
();
if
(
weights
.
empty
())
continue
;
continue
;
// Get scale, bias, mean, variance from instruction_ref
const
auto
&
gamma
=
ins
->
inputs
()[
1
]
->
get_literal
();
const
auto
&
bias
=
ins
->
inputs
()[
2
]
->
get_literal
();
const
auto
&
mean
=
ins
->
inputs
()[
3
]
->
get_literal
();
const
auto
&
variance
=
ins
->
inputs
()[
4
]
->
get_literal
();
// Get epsilon
// Get epsilon
auto
bn_op
=
any_cast
<
op
::
batch_norm_inference
>
(
ins
->
get_operator
());
auto
bn_op
=
any_cast
<
op
::
batch_norm_inference
>
(
ins
->
get_operator
());
auto
epsilon
=
bn_op
.
epsilon
;
auto
epsilon
=
bn_op
.
epsilon
;
// Get convolution weights
const
auto
&
weights
=
conv_ins
->
inputs
()[
1
]
->
get_literal
();
// Get convolution op
// Get convolution op
auto
conv_op
=
conv_ins
->
get_operator
();
auto
conv_op
=
conv_ins
->
get_operator
();
auto
weights_lens
=
weights
.
get_shape
().
lens
();
auto
weights_lens
=
weights
.
get_shape
().
lens
();
auto
conv_lens
=
conv_ins
->
get_shape
().
lens
();
auto
conv_lens
=
conv_ins
->
get_shape
().
lens
();
argument
new_weights
{
weights
.
get_shape
()};
argument
new_weights
{
weights
.
get_shape
()};
argument
new_bias
{
bias
.
get_shape
()};
argument
new_bias
{
{
bias
.
get_shape
()
.
type
(),
{
bias
.
get_shape
().
elements
()}}
};
visit_all
(
weights
,
gamma
,
bias
,
mean
,
variance
,
new_weights
,
new_bias
)(
visit_all
(
weights
,
gamma
,
bias
,
mean
,
variance
,
new_weights
,
new_bias
)(
[
&
](
auto
weights2
,
[
&
](
auto
weights2
,
auto
gamma2
,
auto
gamma2
,
...
@@ -51,11 +50,11 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
...
@@ -51,11 +50,11 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
dfor
(
weights_lens
[
0
],
weights_lens
[
1
],
weights_lens
[
2
],
weights_lens
[
3
])(
dfor
(
weights_lens
[
0
],
weights_lens
[
1
],
weights_lens
[
2
],
weights_lens
[
3
])(
[
&
](
std
::
size_t
k
,
std
::
size_t
c
,
std
::
size_t
h
,
std
::
size_t
w
)
{
[
&
](
std
::
size_t
k
,
std
::
size_t
c
,
std
::
size_t
h
,
std
::
size_t
w
)
{
new_weights2
(
k
,
c
,
h
,
w
)
=
new_weights2
(
k
,
c
,
h
,
w
)
=
gamma2
(
k
)
/
std
::
sqrt
(
variance2
(
k
)
+
epsilon
)
*
weights2
(
k
,
c
,
h
,
w
);
gamma2
[
k
]
/
std
::
sqrt
(
variance2
[
k
]
+
epsilon
)
*
weights2
(
k
,
c
,
h
,
w
);
});
});
dfor
(
new_bias
.
get_shape
().
elements
())([
&
](
std
::
size_t
c
)
{
dfor
(
new_bias
.
get_shape
().
elements
())([
&
](
std
::
size_t
c
)
{
new_bias2
(
c
)
=
new_bias2
[
c
]
=
bias2
(
c
)
-
(
gamma2
(
c
)
*
mean2
(
c
)
/
std
::
sqrt
(
variance2
(
c
)
+
epsilon
));
bias2
[
c
]
-
(
gamma2
[
c
]
*
mean2
[
c
]
/
std
::
sqrt
(
variance2
[
c
]
+
epsilon
));
});
});
});
});
// Replace convolution instruction with updated weights
// Replace convolution instruction with updated weights
...
...
src/include/migraphx/ranges.hpp
View file @
658a32ab
...
@@ -71,6 +71,30 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p)
...
@@ -71,6 +71,30 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p)
return
std
::
all_of
(
c
.
begin
(),
c
.
end
(),
p
);
return
std
::
all_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
}
template
<
class
C
,
class
Predicate
>
bool
any_of
(
const
C
&
c
,
const
Predicate
&
p
)
{
return
std
::
any_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
T
,
class
Predicate
>
bool
any_of
(
const
std
::
initializer_list
<
T
>&
c
,
const
Predicate
&
p
)
{
return
std
::
any_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
C
,
class
Predicate
>
bool
none_of
(
const
C
&
c
,
const
Predicate
&
p
)
{
return
std
::
none_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
T
,
class
Predicate
>
bool
none_of
(
const
std
::
initializer_list
<
T
>&
c
,
const
Predicate
&
p
)
{
return
std
::
none_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
Range
,
class
Iterator
>
template
<
class
Range
,
class
Iterator
>
void
copy
(
Range
&&
r
,
Iterator
it
)
void
copy
(
Range
&&
r
,
Iterator
it
)
{
{
...
...
src/targets/cpu/lowering.cpp
View file @
658a32ab
...
@@ -75,10 +75,10 @@ struct cpu_batch_norm_inference
...
@@ -75,10 +75,10 @@ struct cpu_batch_norm_inference
par_dfor
(
num_batch
,
num_channels
,
image_height
,
image_width
)(
par_dfor
(
num_batch
,
num_channels
,
image_height
,
image_width
)(
[
&
](
std
::
size_t
n
,
std
::
size_t
c
,
std
::
size_t
h
,
std
::
size_t
w
)
{
[
&
](
std
::
size_t
n
,
std
::
size_t
c
,
std
::
size_t
h
,
std
::
size_t
w
)
{
assert
((
variance
(
c
)
+
epsilon
)
>
0
);
assert
((
variance
[
c
]
+
epsilon
)
>
0
);
result
(
n
,
c
,
h
,
w
)
=
gamma
(
c
)
*
(
buffer
(
n
,
c
,
h
,
w
)
-
mean
(
c
)
)
/
result
(
n
,
c
,
h
,
w
)
=
gamma
[
c
]
*
(
buffer
(
n
,
c
,
h
,
w
)
-
mean
[
c
]
)
/
std
::
sqrt
(
variance
(
c
)
+
epsilon
)
+
std
::
sqrt
(
variance
[
c
]
+
epsilon
)
+
bias
(
c
)
;
bias
[
c
]
;
});
});
});
});
}
}
...
...
test/fwd_conv_batchnorm_rewrite_test.cpp
View file @
658a32ab
...
@@ -3,9 +3,16 @@
...
@@ -3,9 +3,16 @@
#include <migraphx/cpu/target.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <test.hpp>
#include <test.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
bool
is_batch_norm
(
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"batch_norm_inference"
;
}
TEST_CASE
(
fwd_conv_batchnorm_rewrite_test
)
TEST_CASE
(
fwd_conv_batchnorm_rewrite_test
)
{
{
std
::
vector
<
float
>
xdata
=
{
std
::
vector
<
float
>
xdata
=
{
...
@@ -65,4 +72,110 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
...
@@ -65,4 +72,110 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
EXPECT
(
migraphx
::
verify_range
(
results_vector1
,
results_vector2
));
EXPECT
(
migraphx
::
verify_range
(
results_vector1
,
results_vector2
));
}
}
TEST_CASE
(
non_literal
)
{
migraphx
::
shape
xs
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
8
,
8
}};
migraphx
::
shape
ws
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
1
,
1
}};
migraphx
::
shape
vars
{
migraphx
::
shape
::
float_type
,
{
4
}};
auto
create_program
=
[
&
]()
{
migraphx
::
program
p
;
auto
x
=
p
.
add_parameter
(
"x"
,
xs
);
auto
w
=
p
.
add_parameter
(
"w"
,
ws
);
auto
conv
=
p
.
add_instruction
(
migraphx
::
op
::
convolution
{},
x
,
w
);
auto
scale
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
1
)));
auto
bias
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
2
)));
auto
mean
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
3
)));
auto
variance
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
4
)));
p
.
add_instruction
(
migraphx
::
op
::
batch_norm_inference
{},
conv
,
scale
,
bias
,
mean
,
variance
);
return
p
;
};
migraphx
::
program
p1
=
create_program
();
migraphx
::
program
p2
=
create_program
();
migraphx
::
fwd_conv_batchnorm_rewrite
opt
;
opt
.
apply
(
p2
);
EXPECT
(
any_of
(
p1
,
&
is_batch_norm
));
EXPECT
(
any_of
(
p2
,
&
is_batch_norm
));
}
TEST_CASE
(
as_literal
)
{
migraphx
::
shape
xs
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
8
,
8
}};
migraphx
::
shape
ws
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
1
,
1
}};
migraphx
::
shape
vars
{
migraphx
::
shape
::
float_type
,
{
4
}};
auto
create_program
=
[
&
]()
{
migraphx
::
program
p
;
auto
x
=
p
.
add_literal
(
migraphx
::
generate_literal
(
xs
,
1
));
auto
w
=
p
.
add_literal
(
migraphx
::
generate_literal
(
ws
,
1
));
auto
conv
=
p
.
add_instruction
(
migraphx
::
op
::
convolution
{},
x
,
w
);
auto
scale
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
1
)));
auto
bias
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
2
)));
auto
mean
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
3
)));
auto
variance
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
4
)));
p
.
add_instruction
(
migraphx
::
op
::
batch_norm_inference
{},
conv
,
scale
,
bias
,
mean
,
variance
);
return
p
;
};
migraphx
::
program
p1
=
create_program
();
migraphx
::
program
p2
=
create_program
();
migraphx
::
fwd_conv_batchnorm_rewrite
opt
;
opt
.
apply
(
p2
);
EXPECT
(
any_of
(
p1
,
&
is_batch_norm
));
EXPECT
(
none_of
(
p2
,
&
is_batch_norm
));
p1
.
compile
(
migraphx
::
cpu
::
target
{});
p2
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result1
=
p1
.
eval
({});
auto
result2
=
p2
.
eval
({});
visit_all
(
result1
,
result2
)([
&
](
auto
r1
,
auto
r2
)
{
EXPECT
(
migraphx
::
verify_range
(
r1
,
r2
));
});
}
TEST_CASE
(
literal_reshape
)
{
migraphx
::
shape
xs
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
8
,
8
}};
migraphx
::
shape
ws
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
1
,
1
}};
migraphx
::
shape
vars
{
migraphx
::
shape
::
float_type
,
{
4
}};
auto
create_program
=
[
&
]()
{
migraphx
::
program
p
;
auto
reshape
=
[
&
](
auto
ins
){
return
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
1
,
4
,
1
,
1
}},
ins
);
};
auto
x
=
p
.
add_literal
(
migraphx
::
generate_literal
(
xs
,
1
));
auto
w
=
p
.
add_literal
(
migraphx
::
generate_literal
(
ws
,
1
));
auto
conv
=
p
.
add_instruction
(
migraphx
::
op
::
convolution
{},
x
,
w
);
auto
scale
=
reshape
(
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
1
))));
auto
bias
=
reshape
(
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
2
))));
auto
mean
=
reshape
(
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
3
))));
auto
variance
=
reshape
(
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
4
))));
p
.
add_instruction
(
migraphx
::
op
::
batch_norm_inference
{},
conv
,
scale
,
bias
,
mean
,
variance
);
return
p
;
};
migraphx
::
program
p1
=
create_program
();
migraphx
::
program
p2
=
create_program
();
migraphx
::
fwd_conv_batchnorm_rewrite
opt
;
opt
.
apply
(
p2
);
EXPECT
(
any_of
(
p1
,
&
is_batch_norm
));
EXPECT
(
none_of
(
p2
,
&
is_batch_norm
));
p1
.
compile
(
migraphx
::
cpu
::
target
{});
p2
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result1
=
p1
.
eval
({});
auto
result2
=
p2
.
eval
({});
visit_all
(
result1
,
result2
)([
&
](
auto
r1
,
auto
r2
)
{
EXPECT
(
migraphx
::
verify_range
(
r1
,
r2
));
});
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment