OpenDAS / apex · Commits · 443fa76e

Commit 443fa76e, authored Jan 14, 2019 by Jie
Parent: 3c7a0e44

[sync BN nhwc]
Added kernel to support sync BN for channel last tensor.

Showing 6 changed files with 929 additions and 134 deletions (+929 −134).
apex/parallel/optimized_sync_batchnorm.py            +21    −4
apex/parallel/optimized_sync_batchnorm_kernel.py     +36   −13
csrc/syncbn.cpp                                      +57   −16
csrc/welford.cu                                     +774   −87
tests/synced_batchnorm/single_gpu_unit_test.py       +34    −7
tests/synced_batchnorm/two_gpu_unit_test.py           +7    −7
apex/parallel/optimized_sync_batchnorm.py (view file @ 443fa76e)

@@ -38,26 +38,43 @@ class SyncBatchNorm(_BatchNorm):
         process_group: pass in a process group within which the stats of the
             mini-batch is being synchronized. ``None`` for using default process
             group
+        channel_last: a boolean value that when set to ``True``, this module
+            take the last dimension of the input tensor to be the channel
+            dimension. Default: False

     Examples::
+        >>> # channel first tensor
         >>> sbn = apex.parallel.SyncBatchNorm(100).cuda()
         >>> inp = torch.randn(10, 100, 14, 14).cuda()
         >>> out = sbn(inp)
         >>> inp = torch.randn(3, 100, 20).cuda()
         >>> out = sbn(inp)
+        >>> # channel last tensor
+        >>> sbn = apex.parallel.SyncBatchNorm(100, channel_last=True).cuda()
+        >>> inp = torch.randn(10, 14, 14, 100).cuda()
     """

-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False):
         super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
         self.process_group = process_group
+        self.channel_last = channel_last

     def _specify_process_group(self, process_group):
         self.process_group = process_group

+    def _specify_channel_last(self, channel_last):
+        self.channel_last = channel_last
+
     def forward(self, input):
-        if not self.training and self.track_running_stats:
+        if not self.training and self.track_running_stats and not self.channel_last:
             # fall back to pytorch implementation for inference
             return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
         else:
-            self.num_batches_tracked += 1
-            return SyncBatchnormFunction.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.track_running_stats, self.momentum, self.process_group)
+            exponential_average_factor = 0.0
+
+            if self.training and self.track_running_stats:
+                self.num_batches_tracked += 1
+                if self.momentum is None:
+                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
+                else:
+                    exponential_average_factor = self.momentum
+
+            return SyncBatchnormFunction.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, self.channel_last)
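With channel_last=True the module treats the trailing dimension as the channel dimension, so an NCHW workload can feed the layer a transposed, contiguous tensor instead. A quick equivalence check against the default channel-first path (a sketch only, not part of the commit; it assumes apex is installed with its CUDA extensions and a GPU is available):

import torch
import apex.parallel

# same parameters and buffers in both modules
sbn_nchw = apex.parallel.SyncBatchNorm(100).cuda()
sbn_nhwc = apex.parallel.SyncBatchNorm(100, channel_last=True).cuda()
sbn_nhwc.load_state_dict(sbn_nchw.state_dict())

x = torch.randn(10, 100, 14, 14).cuda()      # channel first (NCHW)
x_last = x.transpose(-1, 1).contiguous()     # channels moved to the last dimension

y = sbn_nchw(x)
y_last = sbn_nhwc(x_last)

# transposing the channel-last result back should match, up to kernel numerics
print(torch.allclose(y, y_last.transpose(-1, 1).contiguous(), atol=1e-5))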
apex/parallel/optimized_sync_batchnorm_kernel.py (view file @ 443fa76e)

@@ -7,26 +7,40 @@ from apex.parallel import ReduceOp
 class SyncBatchnormFunction(Function):

     @staticmethod
-    def forward(ctx, input, weight, bias, running_mean, running_variance, eps, track_running_stats=True, momentum=1.0, process_group=None):
+    def forward(ctx, input, weight, bias, running_mean, running_variance, eps, track_running_stats=True, momentum=1.0, process_group=None, channel_last=False):
         torch.cuda.nvtx.range_push("sync_BN_fw")
         input = input.contiguous()
         world_size = 0

+        mean = None
+        var_biased = None
+        inv_std = None
+        var = None
+        out = None
+        count = None
         if track_running_stats:
-            mean, var, var_biased = syncbn.welford_mean_var(input)
+            if channel_last:
+                count = int(input.numel() / input.size(-1))
+                mean, var_biased = syncbn.welford_mean_var_c_last(input)
+            else:
+                count = int(input.numel() / input.size(1))
+                mean, var_biased = syncbn.welford_mean_var(input)

             if torch.distributed.is_initialized():
                 if not process_group:
                     process_group = torch.distributed.group.WORLD
                 world_size = torch.distributed.get_world_size(process_group)
                 mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
-                var_all = torch.empty(world_size, var.size(0), dtype=var.dtype, device=var.device)
+                var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=var_biased.device)
                 mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
                 var_l = [var_all.narrow(0, i, 1) for i in range(world_size)]
                 torch.distributed.all_gather(mean_l, mean, process_group)
                 torch.distributed.all_gather(var_l, var_biased, process_group)
-                mean, var, var_biased = syncbn.welford_parallel(mean_all.transpose(1, 0).contiguous(), var_all.transpose(1, 0).contiguous(), int(input.numel()/input.size(1)))
+                mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count, eps)
                 # TODO(Jie): should do fp32 math instead!
+            else:
+                inv_std = 1.0 / torch.sqrt(var_biased + eps)
+                var = var_biased * (count) / (count - 1)

             r_m_inc = mean if running_mean.dtype != torch.float16 else mean.half()
             r_v_inc = var if running_variance.dtype != torch.float16 else var.half()
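The per-process statistics gathered above are merged by syncbn.welford_parallel into a global mean, an unbiased variance for the running stats, and inv_std = 1/sqrt(var + eps) for normalization. The CUDA implementation lives in csrc/welford.cu (collapsed further down this page); a plain-PyTorch sketch of the reduction it performs, assuming every process contributes the same element count as computed here, is:

import torch

def welford_parallel_reference(mean_all, var_all, count, eps):
    # mean_all, var_all: [world_size, C] per-process mean and *biased* variance
    # count: number of elements reduced per channel on each process
    world_size = mean_all.size(0)
    total = world_size * count

    mean = mean_all.mean(dim=0)
    # per-process E[x^2] is var_biased + mean^2; average across processes, then re-center
    var_biased = (var_all + mean_all.pow(2)).mean(dim=0) - mean.pow(2)

    inv_std = 1.0 / torch.sqrt(var_biased + eps)
    var = var_biased * total / (total - 1)   # unbiased variance for the running stats
    return mean, var, inv_std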
@@ -34,14 +48,17 @@ class SyncBatchnormFunction(Function):
             running_variance.data = running_variance.data * (1 - momentum) + momentum * r_v_inc
         else:
             mean = running_mean.data
-            var_biased = running_var.data
+            inv_std = 1.0 / torch.sqrt(running_var.data + eps)

-        ctx.save_for_backward(input, weight, mean, var_biased)
-        ctx.eps = eps
+        ctx.save_for_backward(input, weight, mean, inv_std)
         ctx.process_group = process_group
+        ctx.channel_last = channel_last
         ctx.world_size = world_size

-        out = syncbn.batchnorm_forward(input, mean, var_biased, weight, bias, eps)
+        if channel_last:
+            out = syncbn.batchnorm_forward_c_last(input, mean, inv_std, weight, bias)
+        else:
+            out = syncbn.batchnorm_forward(input, mean, inv_std, weight, bias)

         torch.cuda.nvtx.range_pop()
         return out
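Saving inv_std instead of var_biased lets the elementwise kernels skip the rsqrt per element; eps is folded in once when inv_std is formed, which is also why ctx.eps and the eps arguments disappear. What batchnorm_forward / batchnorm_forward_c_last compute elementwise mirrors the out_r reference in the unit tests below; as a pure-PyTorch sketch (not the CUDA kernel itself):

import torch

def batchnorm_forward_reference(x, mean, inv_std, weight, bias, channel_last=False):
    # mean, inv_std, weight, bias are per-channel vectors of length C
    if channel_last:
        # channels are the trailing dimension, so broadcasting already lines up
        return (x - mean) * inv_std * weight + bias
    # channel first (e.g. NCHW): reshape the per-channel vectors to (C, 1, ..., 1)
    shape = (-1,) + (1,) * (x.dim() - 2)
    return (x - mean.view(shape)) * inv_std.view(shape) * weight.view(shape) + bias.view(shape)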
@@ -53,14 +70,17 @@ class SyncBatchnormFunction(Function):
         # mini batch mean & var are calculated by forward path.
         # mu = 1./N*np.sum(h, axis = 0)
         # var = 1./N*np.sum((h-mu)**2, axis = 0)
-        saved_input, weight, running_mean, running_variance = ctx.saved_tensors
-        eps = ctx.eps
+        saved_input, weight, mean, inv_std = ctx.saved_tensors
         process_group = ctx.process_group
+        channel_last = ctx.channel_last
         world_size = ctx.world_size
         grad_input = grad_weight = grad_bias = None

         # TODO(jie): why do I have to clone here? life time of grad_output?
-        mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, running_mean, running_variance, weight, eps)
+        if channel_last:
+            mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)
+        else:
+            mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)

         # calculate grad_input
         if ctx.needs_input_grad[0]:
@@ -72,7 +92,10 @@ class SyncBatchnormFunction(Function):
                 torch.distributed.all_reduce(
                     mean_dy_xmu, ReduceOp.SUM, process_group)
                 mean_dy_xmu = mean_dy_xmu / world_size
-            grad_input = syncbn.batchnorm_backward(grad_output, saved_input, running_mean, running_variance, weight, mean_dy, mean_dy_xmu, eps)
+            if channel_last:
+                grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
+            else:
+                grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)

         if weight is None or not ctx.needs_input_grad[1]:
             grad_weight = None
@@ -81,4 +104,4 @@ class SyncBatchnormFunction(Function):
             grad_bias = None

         torch.cuda.nvtx.range_pop()
-        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None
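The extra trailing None accounts for the new channel_last argument to forward: autograd expects one gradient slot per forward input. The dgrad computed by batchnorm_backward / batchnorm_backward_c_last corresponds to the grad_input_r reference used in the unit tests below; written with inv_std rather than (var + eps), a sketch of the per-element formula is:

import torch

def batchnorm_backward_reference(grad_output, x, mean, inv_std, weight, mean_dy, mean_dy_xmu):
    # channel-last layout, with per-channel vectors broadcasting over the last dim
    # mean_dy     = E[dy]            over the reduced (non-channel) elements
    # mean_dy_xmu = E[dy * (x - mu)] over the reduced (non-channel) elements
    return (grad_output - mean_dy - (x - mean) * inv_std.pow(2) * mean_dy_xmu) * inv_std * weight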
csrc/syncbn.cpp (view file @ 443fa76e)

@@ -3,52 +3,93 @@
 #include <vector>

-// returns {mean, unbiased_var, biased_var}
+// returns {mean, biased_var}
 // implemented using welford
 std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input);

 // reduces array of mean/var across processes
-// returns global {mean, unbiased_var, biased_var}
+// returns global {mean, inv_std, biased_var}
 // implemented using welford
-std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes, const at::Tensor var_biased_feature_nodes, int numel);
+std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes, const at::Tensor var_biased_feature_nodes, int numel, const float eps);

 // elementwise BN operation, returns output
 // input/weight/shift should have identical data type;
-// mean/var have promoted data type (dtype==fp16?fp32:dtype)
+// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
 at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
                                   const at::Tensor mean,
-                                  const at::Tensor var,
+                                  const at::Tensor inv_std,
                                   const at::Tensor weight,
-                                  const at::Tensor shift,
-                                  const float eps);
+                                  const at::Tensor shift);

 // backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
 // grad_output/input should have identical data type;
-// mean/var have promoted data type (dtype==fp16?fp32:dtype)
+// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
 // implemented using kahan summation
 std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
                                        const at::Tensor input,
                                        const at::Tensor mean,
-                                       const at::Tensor var,
-                                       const at::Tensor weight,
-                                       const float eps);
+                                       const at::Tensor inv_std,
+                                       const at::Tensor weight);

 // elementwise backward BN operation, returns grad_input
 // grad_output/input/weight precision could be fp16/fp32;
-// mean/var/mean_dy/mean_dy_xmu precision is fp32
+// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
 at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
                                    const at::Tensor input,
                                    const at::Tensor mean,
-                                   const at::Tensor var,
+                                   const at::Tensor inv_std,
                                    const at::Tensor weight,
                                    const at::Tensor mean_dy,
-                                   const at::Tensor mean_dy_xmu,
-                                   const float eps);
+                                   const at::Tensor mean_dy_xmu);

+// returns {mean, biased_var}
+// implemented using welford
+// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
+std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);
+
+// elementwise BN operation, returns output
+// input/weight/shift should have identical data type;
+// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
+// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
+at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
+                                         const at::Tensor mean,
+                                         const at::Tensor inv_std,
+                                         const at::Tensor weight,
+                                         const at::Tensor shift);
+
+// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
+// grad_output/input should have identical data type;
+// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
+// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
+std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
+                                              const at::Tensor input,
+                                              const at::Tensor mean,
+                                              const at::Tensor inv_std,
+                                              const at::Tensor weight);
+
+// elementwise backward BN operation, returns grad_input
+// grad_output/input/weight precision could be fp16/fp32;
+// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
+// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
+at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
+                                          const at::Tensor input,
+                                          const at::Tensor mean,
+                                          const at::Tensor inv_std,
+                                          const at::Tensor weight,
+                                          const at::Tensor mean_dy,
+                                          const at::Tensor mean_dy_xmu);

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance");
   m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance");
   m.def("batchnorm_forward", &batchnorm_forward_CUDA, "batchnorm forward");
-  m.def("reduce_bn", &reduce_bn_CUDA, "batchnorm backward reduce grad sum and bias/weight gradient");
+  m.def("reduce_bn", &reduce_bn_CUDA, "batchnorm backward reduce grad sum and bias/weight grad");
   m.def("batchnorm_backward", &batchnorm_backward_CUDA, "batchnorm backward dgrad");
+  m.def("welford_mean_var_c_last", &welford_mean_var_c_last_CUDA, "welford mean variance nhwc");
+  m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc");
+  m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc");
+  m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc");
 }
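These declarations are implemented in csrc/welford.cu and exposed to Python through the pybind module above, where the Python kernel code calls them as syncbn.*. As a rough sketch of how such an extension gets compiled (the wiring in apex's real setup.py, which gates this behind its CUDA-extension option, may differ):

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Hypothetical stand-alone build of the syncbn extension touched by this commit.
setup(
    name='syncbn',
    ext_modules=[
        CUDAExtension(
            name='syncbn',
            sources=['csrc/syncbn.cpp', 'csrc/welford.cu'],
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)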
csrc/welford.cu (view file @ 443fa76e)

This diff is collapsed.
tests/synced_batchnorm/single_gpu_unit_test.py (view file @ 443fa76e)

@@ -54,7 +54,11 @@ m = inp_r.mean(1)
 b_v = inp_r.var(1, unbiased=False)
 unb_v = inp_r.var(1, unbiased=True)
-mean, var, var_biased = syncbn.welford_mean_var(inp_t)
+eps = 1e-5
+
+#mean, var, var_biased = syncbn.welford_mean_var(inp_t)
+mean, var_biased = syncbn.welford_mean_var(inp_t)
+inv_std = 1.0 / torch.sqrt(var_biased + eps)

 bn = torch.nn.BatchNorm2d(feature_size).cuda()
 bn.momentum = 1.0
@@ -74,16 +78,25 @@ grad_sbn = grad_output_t.clone().detach()
 out_sbn = sbn(inp_sbn)
 out_sbn.backward(grad_sbn)

+sbn_c_last = apex.parallel.SyncBatchNorm(feature_size, channel_last=True).cuda()
+sbn_c_last.momentum = 1.0
+sbn_c_last.weight.data = weight_t.clone()
+sbn_c_last.bias.data = bias_t.clone()
+inp_sbn_c_last = inp_t.clone().transpose(-1, 1).contiguous().requires_grad_()
+grad_sbn_c_last = grad_output_t.clone().transpose(-1, 1).contiguous().detach()
+out_sbn_c_last = sbn_c_last(inp_sbn_c_last)
+out_sbn_c_last.backward(grad_sbn_c_last)
+
 sbn_result = True
+sbn_result_c_last = True
 bn_result = True

 sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result
-sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result
+#sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result
 sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result

-eps = 1e-5
-
-out = syncbn.batchnorm_forward(inp_t, mean, var_biased, weight_t, bias_t, eps)
+out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)

 out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) + bias_r

 sbn_result = compare("comparing output: ", out, out_r, error) and sbn_result
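The channel-last branch of the test drives the new module with the same data as the NCHW branch; the input and incoming gradient are made truly channel-last with transpose(-1, 1).contiguous(), since a bare transpose only changes strides and the kernels expect densely packed channel-last memory. A small illustration of that distinction (a sketch, independent of the test itself):

import torch

x = torch.randn(10, 100, 14, 14)   # NCHW storage
x_last = x.transpose(-1, 1)        # logical channel-last view, but still NCHW storage
print(x_last.is_contiguous())      # False: strides still reflect the original layout
x_last = x_last.contiguous()       # materialize real channel-last storage
print(x_last.is_contiguous())      # True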
@@ -102,8 +115,8 @@ mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).co
 grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1, 1, 1) + eps) * mean_dy_xmu_r.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) * weight_r.view(-1, 1, 1)

-mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, var_biased, weight_t, eps)
-grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, var_biased, weight_t, mean_dy, mean_dy_xmu, eps)
+mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
+grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)

 sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
 sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
 sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result
@@ -112,7 +125,7 @@ sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error)
 compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)
 sbn_result = compare("comparing sbn input grad: ", inp_sbn.grad, grad_input_r, error) and sbn_result

-compare("comparing output: ", out_bn, out_sbn, error)
+compare("comparing bn/sbn output: ", out_bn, out_sbn, error)
 sbn_result = compare("comparing running_mean: ", bn.running_mean.data, sbn.running_mean.data, error) and sbn_result
 sbn_result = compare("comparing running_variance: ", bn.running_var.data, sbn.running_var.data, error) and sbn_result
 compare("comparing grad_input: ", inp_bn.grad, inp_sbn.grad, error)
@@ -123,7 +136,21 @@ compare("comparing grad_weight: ", bn.weight.grad, sbn.weight.grad, error)
 compare("comparing grad_weight bn to ref: ", bn.weight.grad, grad_weight_r, error)
 sbn_result = compare("comparing grad_weight sbn to ref: ", sbn.weight.grad, grad_weight_r, error) and sbn_result

+compare("comparing channel last bn/sbn output: ", out_bn, out_sbn_c_last.transpose(-1, 1).contiguous(), error)
+sbn_result_c_last = compare("comparing channel last running_mean: ", bn.running_mean.data, sbn_c_last.running_mean.data, error) and sbn_result_c_last
+sbn_result_c_last = compare("comparing channel last running_variance: ", bn.running_var.data, sbn_c_last.running_var.data, error) and sbn_result_c_last
+compare("comparing channel last grad_input: ", inp_bn.grad, inp_sbn_c_last.grad.transpose(-1, 1).contiguous(), error)
+compare("comparing channel last grad_bias: ", bn.bias.grad, sbn_c_last.bias.grad, error)
+sbn_result_c_last = compare("comparing channel last grad_bias sbn to ref: ", sbn_c_last.bias.grad, grad_bias_r, error) and sbn_result_c_last
+compare("comparing channel last grad_weight: ", bn.weight.grad, sbn_c_last.weight.grad, error)
+sbn_result_c_last = compare("comparing channel last grad_weight sbn to ref: ", sbn_c_last.weight.grad, grad_weight_r, error) and sbn_result_c_last
+
 if sbn_result:
     print("====SBN single gpu passed tests")
 else:
     print("*SBN single gpu failed*")
+
+if sbn_result_c_last:
+    print("====SBN channel last single gpu passed tests")
+else:
+    print("*SBN channel last single gpu failed*")
tests/synced_batchnorm/two_gpu_unit_test.py (view file @ 443fa76e)

@@ -75,7 +75,10 @@ m = inp_r.mean(1)
 b_v = inp_r.var(1, unbiased=False)
 unb_v = inp_r.var(1, unbiased=True)
-mean, var, var_biased = syncbn.welford_mean_var(inp_t)
+eps = 1e-5
+
+mean, var_biased = syncbn.welford_mean_var(inp_t)
+inv_std = 1.0 / torch.sqrt(var_biased + eps)

 bn = torch.nn.BatchNorm2d(feature_size).cuda()
 bn.momentum = 1.0
@@ -111,12 +114,9 @@ bn_result = True
 if args.local_rank == 0:
     sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result
-    sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result
     sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result

-eps = 1e-5
-
-out = syncbn.batchnorm_forward(inp_t, mean, var_biased, weight_t, bias_t, eps)
+out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)

 out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) + bias_r

 if args.local_rank == 0:
@@ -136,8 +136,8 @@ mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).co
 grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1, 1, 1) + eps) * mean_dy_xmu_r.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) * weight_r.view(-1, 1, 1)

-mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, var_biased, weight_t, eps)
-grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, var_biased, weight_t, mean_dy, mean_dy_xmu, eps)
+mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
+grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)

 if args.local_rank == 0:
     sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
     sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
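Because the script branches on args.local_rank, it is meant to be launched with one process per GPU; with the distributed launcher of this PyTorch generation that would typically look something like

python -m torch.distributed.launch --nproc_per_node=2 tests/synced_batchnorm/two_gpu_unit_test.py

though the exact invocation used by the apex test suite is not shown in this commit.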