Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
15ba8a36
Unverified
Commit
15ba8a36
authored
Mar 29, 2019
by
mvermeulen
Committed by
GitHub
Mar 29, 2019
Browse files
Merge pull request #214 from ROCmSoftwarePlatform/batchnorm-rewrite
Improve batchnorm rewrite
parents
da782877
22286f71
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
164 additions
and
19 deletions
+164
-19
src/fwd_conv_batchnorm_rewrite.cpp
src/fwd_conv_batchnorm_rewrite.cpp
+14
-15
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+12
-0
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+3
-0
src/include/migraphx/ranges.hpp
src/include/migraphx/ranges.hpp
+24
-0
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+4
-4
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+2
-0
test/fwd_conv_batchnorm_rewrite_test.cpp
test/fwd_conv_batchnorm_rewrite_test.cpp
+105
-0
No files found.
src/fwd_conv_batchnorm_rewrite.cpp
View file @
15ba8a36
...
...
@@ -3,6 +3,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dfor.hpp>
namespace
migraphx
{
...
...
@@ -14,32 +15,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
{
if
(
ins
->
name
()
!=
"batch_norm_inference"
)
continue
;
if
(
not
std
::
all_of
(
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
(),
[](
auto
arg
)
{
return
arg
->
name
()
==
"@literal"
;
}))
// Get scale, bias, mean, variance from inputs
auto
gamma
=
ins
->
inputs
()[
1
]
->
eval
();
auto
bias
=
ins
->
inputs
()[
2
]
->
eval
();
auto
mean
=
ins
->
inputs
()[
3
]
->
eval
();
auto
variance
=
ins
->
inputs
()[
4
]
->
eval
();
if
(
any_of
({
gamma
,
bias
,
mean
,
variance
},
[](
auto
arg
)
{
return
arg
.
empty
();
}))
continue
;
auto
conv_ins
=
ins
->
inputs
()[
0
];
if
(
conv_ins
->
name
()
!=
"convolution"
)
continue
;
if
(
conv_ins
->
inputs
()[
1
]
->
name
()
!=
"@literal"
)
// Get convolution weights
auto
weights
=
conv_ins
->
inputs
()[
1
]
->
eval
();
if
(
weights
.
empty
())
continue
;
// Get scale, bias, mean, variance from instruction_ref
const
auto
&
gamma
=
ins
->
inputs
()[
1
]
->
get_literal
();
const
auto
&
bias
=
ins
->
inputs
()[
2
]
->
get_literal
();
const
auto
&
mean
=
ins
->
inputs
()[
3
]
->
get_literal
();
const
auto
&
variance
=
ins
->
inputs
()[
4
]
->
get_literal
();
// Get epsilon
auto
bn_op
=
any_cast
<
op
::
batch_norm_inference
>
(
ins
->
get_operator
());
auto
epsilon
=
bn_op
.
epsilon
;
// Get convolution weights
const
auto
&
weights
=
conv_ins
->
inputs
()[
1
]
->
get_literal
();
// Get convolution op
auto
conv_op
=
conv_ins
->
get_operator
();
auto
weights_lens
=
weights
.
get_shape
().
lens
();
auto
conv_lens
=
conv_ins
->
get_shape
().
lens
();
argument
new_weights
{
weights
.
get_shape
()};
argument
new_bias
{
bias
.
get_shape
()};
argument
new_bias
{
{
bias
.
get_shape
()
.
type
(),
{
bias
.
get_shape
().
elements
()}}
};
visit_all
(
weights
,
gamma
,
bias
,
mean
,
variance
,
new_weights
,
new_bias
)(
[
&
](
auto
weights2
,
auto
gamma2
,
...
...
@@ -51,11 +50,11 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
dfor
(
weights_lens
[
0
],
weights_lens
[
1
],
weights_lens
[
2
],
weights_lens
[
3
])(
[
&
](
std
::
size_t
k
,
std
::
size_t
c
,
std
::
size_t
h
,
std
::
size_t
w
)
{
new_weights2
(
k
,
c
,
h
,
w
)
=
gamma2
(
k
)
/
std
::
sqrt
(
variance2
(
k
)
+
epsilon
)
*
weights2
(
k
,
c
,
h
,
w
);
gamma2
[
k
]
/
std
::
sqrt
(
variance2
[
k
]
+
epsilon
)
*
weights2
(
k
,
c
,
h
,
w
);
});
dfor
(
new_bias
.
get_shape
().
elements
())([
&
](
std
::
size_t
c
)
{
new_bias2
(
c
)
=
bias2
(
c
)
-
(
gamma2
(
c
)
*
mean2
(
c
)
/
std
::
sqrt
(
variance2
(
c
)
+
epsilon
));
new_bias2
[
c
]
=
bias2
[
c
]
-
(
gamma2
[
c
]
*
mean2
[
c
]
/
std
::
sqrt
(
variance2
[
c
]
+
epsilon
));
});
});
// Replace convolution instruction with updated weights
...
...
src/include/migraphx/check_shapes.hpp
View file @
15ba8a36
...
...
@@ -18,6 +18,11 @@ struct check_shapes
{
}
template
<
class
Op
>
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
Op
&
op
)
:
begin
(
b
),
end
(
e
),
name
(
op
.
name
())
{
}
check_shapes
(
const
std
::
vector
<
shape
>&
s
)
:
begin
(
s
.
data
()),
end
(
s
.
data
()
+
s
.
size
())
{}
template
<
class
Op
>
...
...
@@ -119,6 +124,13 @@ struct check_shapes
return
*
this
;
}
const
check_shapes
&
elements
(
std
::
size_t
n
)
const
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Wrong number of elements"
);
return
*
this
;
}
template
<
class
F
>
bool
same
(
F
f
)
const
{
...
...
src/include/migraphx/operators.hpp
View file @
15ba8a36
...
...
@@ -56,6 +56,9 @@ struct batch_norm_inference
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
5
);
check_shapes
{
inputs
.
data
(),
inputs
.
data
()
+
1
,
*
this
}.
only_dims
(
4
);
check_shapes
{
inputs
.
data
()
+
1
,
inputs
.
data
()
+
inputs
.
size
(),
*
this
}.
same_shape
().
elements
(
inputs
.
front
().
lens
()[
1
]);
return
inputs
.
front
();
}
};
...
...
src/include/migraphx/ranges.hpp
View file @
15ba8a36
...
...
@@ -71,6 +71,30 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p)
return
std
::
all_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
C
,
class
Predicate
>
bool
any_of
(
const
C
&
c
,
const
Predicate
&
p
)
{
return
std
::
any_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
T
,
class
Predicate
>
bool
any_of
(
const
std
::
initializer_list
<
T
>&
c
,
const
Predicate
&
p
)
{
return
std
::
any_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
C
,
class
Predicate
>
bool
none_of
(
const
C
&
c
,
const
Predicate
&
p
)
{
return
std
::
none_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
T
,
class
Predicate
>
bool
none_of
(
const
std
::
initializer_list
<
T
>&
c
,
const
Predicate
&
p
)
{
return
std
::
none_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
Range
,
class
Iterator
>
void
copy
(
Range
&&
r
,
Iterator
it
)
{
...
...
src/targets/cpu/lowering.cpp
View file @
15ba8a36
...
...
@@ -75,10 +75,10 @@ struct cpu_batch_norm_inference
par_dfor
(
num_batch
,
num_channels
,
image_height
,
image_width
)(
[
&
](
std
::
size_t
n
,
std
::
size_t
c
,
std
::
size_t
h
,
std
::
size_t
w
)
{
assert
((
variance
(
c
)
+
epsilon
)
>
0
);
result
(
n
,
c
,
h
,
w
)
=
gamma
(
c
)
*
(
buffer
(
n
,
c
,
h
,
w
)
-
mean
(
c
)
)
/
std
::
sqrt
(
variance
(
c
)
+
epsilon
)
+
bias
(
c
)
;
assert
((
variance
[
c
]
+
epsilon
)
>
0
);
result
(
n
,
c
,
h
,
w
)
=
gamma
[
c
]
*
(
buffer
(
n
,
c
,
h
,
w
)
-
mean
[
c
]
)
/
std
::
sqrt
(
variance
[
c
]
+
epsilon
)
+
bias
[
c
]
;
});
});
}
...
...
src/targets/gpu/fuse_ops.cpp
View file @
15ba8a36
...
...
@@ -140,6 +140,8 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
auto
conv
=
any_cast
<
miopen_convolution
>
(
ins
->
get_operator
());
if
(
conv
.
op
.
group
>
1
)
return
false
;
if
(
conv
.
op
.
padding_mode
!=
op
::
padding_mode_t
::
default_
)
return
false
;
if
(
wei
.
lens
()[
1
]
>
512
and
conv
.
algo
!=
miopenConvolutionFwdAlgoWinograd
)
return
false
;
auto
op
=
conv
.
op
;
...
...
test/fwd_conv_batchnorm_rewrite_test.cpp
View file @
15ba8a36
...
...
@@ -3,9 +3,13 @@
#include <migraphx/cpu/target.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <test.hpp>
#include <migraphx/verify.hpp>
bool
is_batch_norm
(
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"batch_norm_inference"
;
}
TEST_CASE
(
fwd_conv_batchnorm_rewrite_test
)
{
std
::
vector
<
float
>
xdata
=
{
...
...
@@ -65,4 +69,105 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
EXPECT
(
migraphx
::
verify_range
(
results_vector1
,
results_vector2
));
}
TEST_CASE
(
non_literal
)
{
migraphx
::
shape
xs
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
8
,
8
}};
migraphx
::
shape
ws
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
1
,
1
}};
migraphx
::
shape
vars
{
migraphx
::
shape
::
float_type
,
{
4
}};
auto
create_program
=
[
&
]()
{
migraphx
::
program
p
;
auto
x
=
p
.
add_parameter
(
"x"
,
xs
);
auto
w
=
p
.
add_parameter
(
"w"
,
ws
);
auto
conv
=
p
.
add_instruction
(
migraphx
::
op
::
convolution
{},
x
,
w
);
auto
scale
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
1
)));
auto
bias
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
2
)));
auto
mean
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
3
)));
auto
variance
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
4
)));
p
.
add_instruction
(
migraphx
::
op
::
batch_norm_inference
{},
conv
,
scale
,
bias
,
mean
,
variance
);
return
p
;
};
migraphx
::
program
p1
=
create_program
();
migraphx
::
program
p2
=
create_program
();
migraphx
::
fwd_conv_batchnorm_rewrite
opt
;
opt
.
apply
(
p2
);
EXPECT
(
any_of
(
p1
,
&
is_batch_norm
));
EXPECT
(
any_of
(
p2
,
&
is_batch_norm
));
}
TEST_CASE
(
as_literal
)
{
migraphx
::
shape
xs
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
8
,
8
}};
migraphx
::
shape
ws
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
1
,
1
}};
migraphx
::
shape
vars
{
migraphx
::
shape
::
float_type
,
{
4
}};
auto
create_program
=
[
&
]()
{
migraphx
::
program
p
;
auto
x
=
p
.
add_literal
(
migraphx
::
generate_literal
(
xs
,
1
));
auto
w
=
p
.
add_literal
(
migraphx
::
generate_literal
(
ws
,
1
));
auto
conv
=
p
.
add_instruction
(
migraphx
::
op
::
convolution
{},
x
,
w
);
auto
scale
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
1
)));
auto
bias
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
2
)));
auto
mean
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
3
)));
auto
variance
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
4
)));
p
.
add_instruction
(
migraphx
::
op
::
batch_norm_inference
{},
conv
,
scale
,
bias
,
mean
,
variance
);
return
p
;
};
migraphx
::
program
p1
=
create_program
();
migraphx
::
program
p2
=
create_program
();
migraphx
::
fwd_conv_batchnorm_rewrite
opt
;
opt
.
apply
(
p2
);
EXPECT
(
any_of
(
p1
,
&
is_batch_norm
));
EXPECT
(
none_of
(
p2
,
&
is_batch_norm
));
p1
.
compile
(
migraphx
::
cpu
::
target
{});
p2
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result1
=
p1
.
eval
({});
auto
result2
=
p2
.
eval
({});
visit_all
(
result1
,
result2
)([
&
](
auto
r1
,
auto
r2
)
{
EXPECT
(
migraphx
::
verify_range
(
r1
,
r2
));
});
}
TEST_CASE
(
literal_reshape
)
{
migraphx
::
shape
xs
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
8
,
8
}};
migraphx
::
shape
ws
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
1
,
1
}};
migraphx
::
shape
vars
{
migraphx
::
shape
::
float_type
,
{
4
}};
auto
create_program
=
[
&
]()
{
migraphx
::
program
p
;
auto
reshape
=
[
&
](
auto
ins
)
{
return
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
1
,
4
,
1
,
1
}},
ins
);
};
auto
x
=
p
.
add_literal
(
migraphx
::
generate_literal
(
xs
,
1
));
auto
w
=
p
.
add_literal
(
migraphx
::
generate_literal
(
ws
,
1
));
auto
conv
=
p
.
add_instruction
(
migraphx
::
op
::
convolution
{},
x
,
w
);
auto
scale
=
reshape
(
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
1
))));
auto
bias
=
reshape
(
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
2
))));
auto
mean
=
reshape
(
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
3
))));
auto
variance
=
reshape
(
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
4
))));
p
.
add_instruction
(
migraphx
::
op
::
batch_norm_inference
{},
conv
,
scale
,
bias
,
mean
,
variance
);
return
p
;
};
migraphx
::
program
p1
=
create_program
();
migraphx
::
program
p2
=
create_program
();
migraphx
::
fwd_conv_batchnorm_rewrite
opt
;
opt
.
apply
(
p2
);
EXPECT
(
any_of
(
p1
,
&
is_batch_norm
));
EXPECT
(
none_of
(
p2
,
&
is_batch_norm
));
p1
.
compile
(
migraphx
::
cpu
::
target
{});
p2
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result1
=
p1
.
eval
({});
auto
result2
=
p2
.
eval
({});
visit_all
(
result1
,
result2
)([
&
](
auto
r1
,
auto
r2
)
{
EXPECT
(
migraphx
::
verify_range
(
r1
,
r2
));
});
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment