OpenDAS / fairscale · Commits

Commit d240b748 (Unverified)
[perf] nn.SyncBatchNorm: use autograd function to save memory (#680)

Authored May 14, 2021 by msbaines, committed by GitHub on May 14, 2021
Parent: 5be4817d

Showing 2 changed files with 154 additions and 68 deletions:

  fairscale/experimental/nn/sync_batchnorm.py     +95  -44
  tests/experimental/nn/test_sync_batchnorm.py    +59  -24
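The commit title names the technique: replace the autograd graph built by a differentiable all_reduce with a single hand-written torch.autograd.Function, so that only the tensors explicitly passed to save_for_backward are kept alive for the backward pass. A minimal toy sketch of that pattern (not taken from this commit; the function and shapes are illustrative only):

    import torch

    class _Square(torch.autograd.Function):
        # Only what backward() actually needs is saved (here, just x);
        # autograd does not record the intermediate ops inside forward().
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x * x

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            return 2 * x * grad_output

    x = torch.randn(4, requires_grad=True)
    _Square.apply(x).sum().backward()
    assert torch.allclose(x.grad, 2 * x.detach())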
fairscale/experimental/nn/sync_batchnorm.py

@@ -9,51 +9,29 @@ import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup

-if torch.__version__.split(".")[:2] >= ["1", "8"]:
-    from torch.distributed.nn.functional import all_reduce as differentiable_all_reduce
-else:
-    # Copied from https://github.com/pytorch/pytorch/blob/v1.8.1/torch/distributed/nn/functional.py
-    class _AllReduce(torch.autograd.Function):
-        @staticmethod
-        def forward(ctx, op, group, tensor):  # type: ignore
-            ctx.group = group
-            ctx.op = op
-            tensor = tensor.clone()
-            dist.all_reduce(tensor, op=op, group=group)
-            return tensor
-
-        @staticmethod
-        def backward(ctx, grad_output):  # type: ignore
-            return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),)
-
-    def differentiable_all_reduce(tensor, op=dist.ReduceOp.SUM, group=dist.group.WORLD):  # type: ignore
-        return _AllReduce.apply(op, group, tensor)
-
 def _forward(
     input: torch.Tensor,
     affine: bool,
     track_running_stats: bool,
     mean: torch.Tensor,
-    meansqr: torch.Tensor,
+    var: torch.Tensor,
+    invstd: torch.Tensor,
     momentum: float,
-    eps: float,
     weight: torch.Tensor,
     bias: torch.Tensor,
     running_mean: torch.Tensor,
     running_var: torch.Tensor,
     total_count: torch.Tensor,
 ) -> torch.Tensor:
-    var = meansqr - mean * mean
     if track_running_stats:
         with torch.no_grad():
             unbiased_var = var * (total_count / (total_count - 1))
             running_mean += momentum * (mean.reshape(-1) - running_mean)
             running_var += momentum * (unbiased_var.reshape(-1) - running_var)
-    invstd = torch.rsqrt(var + eps)
     if affine:
-        return (input - mean) * invstd * weight.reshape(mean.shape) + bias.reshape(mean.shape)
+        return (input - mean) * (invstd * weight.reshape_as(mean)) + bias.reshape_as(mean)
     else:
         return (input - mean) * invstd
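For reference, the normalization that _forward implements is the ordinary batch-norm formula with a biased batch variance; mean, var, and invstd are simply supplied pre-reduced across ranks. A standalone single-process sketch (assuming a 4-D NCHW input and default BatchNorm2d initialization) that reproduces the same arithmetic:

    import torch

    x = torch.randn(8, 3, 5, 5)
    eps = 1e-5
    dim = [d for d in range(x.ndim) if d != 1]    # reduce over N, H, W
    mean = torch.mean(x, dim=dim, keepdim=True)
    var = torch.mean(x * x, dim=dim, keepdim=True) - mean * mean
    invstd = torch.rsqrt(var + eps)
    y = (x - mean) * invstd                        # affine weight=1, bias=0

    ref = torch.nn.BatchNorm2d(3, eps=eps)(x)      # training mode uses batch stats
    assert torch.allclose(y, ref, atol=1e-5)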
@@ -62,6 +40,92 @@ if torch.__version__.split(".")[:2] >= ["1", "7"]:
     _forward = torch.jit.script(_forward)  # type: ignore

+
+class _SyncBatchNormFunction(torch.autograd.Function):
+    @staticmethod
+    # type: ignore
+    def forward(ctx, input, weight, bias, affine, track_running_stats, running_mean, running_var, eps, momentum, process_group):
+        dim = [d for d in range(input.ndim) if d != 1]
+        count = torch.full((1,), input.numel() // input.size(1), device=input.device, dtype=input.dtype)
+        total_count = count.clone()
+        all_reduce_handle = dist.all_reduce(total_count, group=process_group, async_op=True)
+        mean = torch.mean(input, dim=dim, keepdim=True)
+        meansqr = torch.mean(input * input, dim=dim, keepdim=True)
+        vec = torch.cat([mean, meansqr])
+        all_reduce_handle.wait()
+        vec = vec * (count / total_count)
+        dist.all_reduce(vec, group=process_group)
+        mean, meansqr = vec.chunk(2)
+        var = meansqr - mean * mean
+        invstd = torch.rsqrt(var + eps)
+
+        ctx.save_for_backward(input, weight, bias, mean, invstd, total_count)
+        ctx.process_group = process_group
+
+        return _forward(
+            input,
+            affine,
+            track_running_stats,
+            mean,
+            var,
+            invstd,
+            momentum,
+            weight,
+            bias,
+            running_mean,
+            running_var,
+            total_count,
+        )
+
+    @staticmethod
+    # type: ignore
+    def backward(ctx, grad_output):
+        needs_input_grad = ctx.needs_input_grad[0]
+        needs_weight_grad = ctx.needs_input_grad[1]
+
+        grad_input = None
+        grad_weight = None
+        grad_bias = None
+
+        input, weight, bias, mean, invstd, total_count = ctx.saved_tensors
+        process_group = ctx.process_group
+
+        dim = [d for d in range(input.ndim) if d != 1]
+        if needs_input_grad or needs_weight_grad:
+            grad_common = torch.sum((input - mean) * grad_output, dim=dim, keepdim=True)  # common to grad_weight and grad_invstd
+
+        if needs_input_grad:
+            if weight is None:  # i.e. affine is False
+                grad_input = invstd * grad_output
+                grad_mean = -torch.sum(grad_input, dim=dim, keepdim=True)
+                grad_invstd = grad_common
+            else:
+                grad_input = (invstd * weight.reshape_as(mean)) * grad_output
+                grad_mean = -torch.sum(grad_input, dim=dim, keepdim=True)
+                grad_invstd = grad_common * weight.reshape_as(mean)
+            grad_var = -0.5 * invstd.pow(3) * grad_invstd
+            grad_mean += -2 * mean * grad_var
+            grad_meansqr = grad_var
+            vec = torch.cat([grad_mean, grad_meansqr])
+            all_reduce_handle = dist.all_reduce(vec, group=process_group, async_op=True)
+
+        if needs_weight_grad:
+            grad_weight = (grad_common * invstd).resize_as(weight)
+            grad_bias = torch.sum(grad_output, dim=dim)
+
+        if needs_input_grad:
+            all_reduce_handle.wait()
+            vec = vec / total_count  # NOTE(msb) removed '* count' here to avoid '/ count' below
+            grad_mean, grad_meansqr = vec.chunk(2)
+            grad_input += grad_mean  # removed '/ count'
+            grad_input += input * (2 * grad_meansqr)  # removed '/ count'
+
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None
+
+
 class SyncBatchNorm(torch.nn.BatchNorm2d):
     """
     Fast re-implementation of ``torch.nn.SyncBatchNorm`` that can achieve a speedup
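One detail worth noting in the new forward(): the all-reduce of total_count is issued asynchronously so the communication overlaps with computing the local mean and mean-square, and is only waited on once those are ready. A minimal sketch of that overlap pattern (hypothetical helper; assumes torch.distributed has already been initialized):

    import torch
    import torch.distributed as dist

    def mean_with_overlap(x: torch.Tensor, group=None) -> torch.Tensor:
        # Start the (small) count all-reduce in the background ...
        count = torch.tensor([float(x.numel())], device=x.device)
        handle = dist.all_reduce(count, group=group, async_op=True)
        # ... do independent local work while it is in flight ...
        local_sum = x.sum()
        dist.all_reduce(local_sum, group=group)
        # ... and only block when the result is actually needed.
        handle.wait()
        return local_sum / count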
@@ -79,30 +143,17 @@ class SyncBatchNorm(torch.nn.BatchNorm2d):
         if not dist.is_initialized() or not self.training:
             return super().forward(input)

-        dim = [d for d in range(input.ndim) if d != 1]
-        count = torch.full((1,), input.numel() // input.size(1), device=input.device, dtype=input.dtype)
-        total_count = count.clone()
-        handle = dist.all_reduce(total_count, group=self._process_group, async_op=True)
-        mean = torch.mean(input, dim=dim, keepdim=True)
-        meansqr = torch.mean(input * input, dim=dim, keepdim=True)
-        vec = torch.cat([mean, meansqr])
-        handle.wait()
-        vec = vec * (count / total_count)
-        mean, meansqr = differentiable_all_reduce(vec, group=self._process_group).chunk(2)  # type: ignore
-        return _forward(
-            input,
-            self.affine,
-            self.track_running_stats,
-            mean,
-            meansqr,
-            self.momentum,
-            self.eps,
-            self.weight,
-            self.bias,
-            self.running_mean,
-            self.running_var,
-            total_count,
-        )
+        return _SyncBatchNormFunction.apply(
+            input,
+            self.weight,
+            self.bias,
+            self.affine,
+            self.track_running_stats,
+            self.running_mean,
+            self.running_var,
+            self.eps,
+            self.momentum,
+            self._process_group,
+        )

     @classmethod
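After this change the module is still used like torch.nn.SyncBatchNorm; only the internals of forward() differ. A hypothetical usage sketch (it assumes SyncBatchNorm is importable from fairscale.experimental.nn and that a process group has already been set up, e.g. via torchrun and dist.init_process_group):

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    from fairscale.experimental.nn import SyncBatchNorm  # assumed import path

    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
        SyncBatchNorm(8),          # drop-in for torch.nn.SyncBatchNorm(8)
        torch.nn.ReLU(),
    ).cuda()
    model = DDP(model, device_ids=[rank])

    x = torch.randn(4, 3, 32, 32, device=rank)
    model(x).sum().backward()      # batch statistics are all-reduced across ranks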
tests/experimental/nn/test_sync_batchnorm.py
@@ -36,40 +36,48 @@ def pg_test(world_size=torch.cuda.device_count()):
 def check_parity(torch_bn, fs_bn, x):
-    yh = torch.ones_like(x)
+    yh = torch.randn_like(x)

-    torch_y = torch_bn(x)
-    fs_y = fs_bn(x)
+    torch_x = x.detach()
+    torch_x.requires_grad = True
+    torch_y = torch_bn(torch_x)
     torch_y.backward(yh)
+
+    fs_x = x.detach()
+    fs_x.requires_grad = True
+    fs_y = fs_bn(fs_x)
     fs_y.backward(yh)

-    assert torch.allclose(torch_y, fs_y), f"{torch_y} != {fs_y}"
-    assert torch.allclose(torch_bn.running_mean, fs_bn.running_mean), f"{torch_bn.running_mean} != {fs_bn.running_mean}"
-    assert torch.allclose(torch_bn.running_var, fs_bn.running_var), f"{torch_bn.running_var} != {fs_bn.running_var}"
-    assert torch.allclose(torch_bn.weight, fs_bn.weight), f"{torch_bn.weight.grad} != {fs_bn.weight.grad}"
-    assert torch.allclose(torch_bn.bias, fs_bn.bias), f"{torch_bn.bias.grad} != {fs_bn.bias.grad}"
-    # assert torch.allclose(torch_bn.weight.grad, fs_bn.weight.grad), f"{torch_bn.weight.grad} != {fs_bn.weight.grad}"
-    assert torch.allclose(torch_bn.bias.grad, fs_bn.bias.grad), f"{torch_bn.bias.grad} != {fs_bn.bias.grad}"
+    torch.testing.assert_allclose(torch_y, fs_y)
+    torch.testing.assert_allclose(torch_bn.running_mean, fs_bn.running_mean)
+    torch.testing.assert_allclose(torch_bn.running_var, fs_bn.running_var)
+    torch.testing.assert_allclose(torch_bn.weight, fs_bn.weight)
+    torch.testing.assert_allclose(torch_bn.bias, fs_bn.bias)
+    # TODO(msb) currently disabled due to PyTorch bug: https://github.com/pytorch/pytorch/issues/57796
+    # torch.testing.assert_allclose(torch_bn.weight.grad, fs_bn.weight.grad)
+    torch.testing.assert_allclose(torch_bn.bias.grad, fs_bn.bias.grad)
+    torch.testing.assert_allclose(torch_x.grad, fs_x.grad)


 def check_parity_ddp(torch_bn, fs_bn, x):
-    yh = torch.ones_like(x)
+    yh = torch.randn_like(x)
     rank = dist.get_rank()
+
     torch_ddp = DDP(torch_bn, device_ids=[rank])
-    fs_ddp = DDP(fs_bn, device_ids=[rank])
     torch_bn = torch_ddp.module
-    fs_bn = fs_ddp.module
-
-    torch_y = torch_ddp(x)
-    fs_y = fs_ddp(x)
+    torch_x = x.detach()
+    torch_x.requires_grad = True
+    torch_y = torch_ddp(torch_x)
     torch_y.backward(yh)
+
+    fs_ddp = DDP(fs_bn, device_ids=[rank])
+    fs_bn = fs_ddp.module
+    fs_x = x.detach()
+    fs_x.requires_grad = True
+    fs_y = fs_ddp(fs_x)
     fs_y.backward(yh)

-    assert torch.allclose(torch_y, fs_y), f"{torch_y} != {fs_y}"
-    assert torch.allclose(torch_bn.running_mean, fs_bn.running_mean), f"{torch_bn.running_mean} != {fs_bn.running_mean}"
-    assert torch.allclose(torch_bn.running_var, fs_bn.running_var), f"{torch_bn.running_var} != {fs_bn.running_var}"
-    assert torch.allclose(torch_bn.weight, fs_bn.weight), f"{torch_bn.weight.grad} != {fs_bn.weight.grad}"
-    assert torch.allclose(torch_bn.bias, fs_bn.bias), f"{torch_bn.bias.grad} != {fs_bn.bias.grad}"
-    # assert torch.allclose(torch_bn.weight.grad, fs_bn.weight.grad), f"{torch_bn.weight.grad} != {fs_bn.weight.grad}"
-    assert torch.allclose(torch_bn.bias.grad, fs_bn.bias.grad), f"{torch_bn.bias.grad} != {fs_bn.bias.grad}"
+    torch.testing.assert_allclose(torch_y, fs_y)
+    torch.testing.assert_allclose(torch_bn.running_mean, fs_bn.running_mean)
+    torch.testing.assert_allclose(torch_bn.running_var, fs_bn.running_var)
+    torch.testing.assert_allclose(torch_bn.weight, fs_bn.weight)
+    torch.testing.assert_allclose(torch_bn.bias, fs_bn.bias)
+    # TODO(msb) currently disabled due to PyTorch bug: https://github.com/pytorch/pytorch/issues/57796
+    # torch.testing.assert_allclose(torch_bn.weight.grad, fs_bn.weight.grad)
+    torch.testing.assert_allclose(torch_bn.bias.grad, fs_bn.bias.grad)
+    torch.testing.assert_allclose(torch_x.grad, fs_x.grad)


 @pg_test(world_size=1)
@@ -142,3 +150,30 @@ def parity1d_syncbn():
     torch_bn = torch.nn.SyncBatchNorm(3).cuda()
     fs_bn = SyncBatchNorm(3).cuda()
     check_parity_ddp(torch_bn, fs_bn, x)
+
+
+@pg_test()
+def memory_allocated():
+    rank = dist.get_rank()
+    torch.cuda.set_device(rank)
+
+    x = torch.randn(50, 2048, 7, 7).to(rank)
+    torch_bn = torch.nn.SyncBatchNorm(2048).cuda()
+    torch_bn = DDP(torch_bn, device_ids=[rank])
+    fs_bn = SyncBatchNorm(2048).cuda()
+    fs_bn = DDP(fs_bn, device_ids=[rank])
+
+    torch_x = x.detach()
+    torch_x.requires_grad = True
+    fs_x = x.detach()
+    fs_x.requires_grad = True
+
+    torch.cuda.empty_cache()
+    mem_at_start = torch.cuda.memory_stats()["allocated_bytes.all.current"]
+    torch_y = torch_bn(torch_x)
+    torch.cuda.empty_cache()
+    mem_after_torch = torch.cuda.memory_stats()["allocated_bytes.all.current"]
+    fs_y = fs_bn(fs_x)
+    torch.cuda.empty_cache()
+    mem_final = torch.cuda.memory_stats()["allocated_bytes.all.current"]
+
+    torch_used = mem_after_torch - mem_at_start
+    fs_used = mem_final - mem_after_torch
+    assert fs_used < (torch_used * 1.01), f"{fs_used} < {torch_used * 1.01}"