Unverified Commit d240b748 authored by msbaines, committed by GitHub

[perf] nn.SyncBatchNorm: use autograd function to save memory (#680)

parent 5be4817d
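
The commit replaces the differentiable all_reduce path (which lets autograd retain the full-size intermediates of the normalization expression, such as input - mean, until the backward pass) with a single custom torch.autograd.Function whose forward saves only the tensors its backward actually needs. A minimal sketch of that pattern, using a toy centering op that is hypothetical and not part of fairscale or of this commit:

import torch

class _Center(torch.autograd.Function):
    """Illustrative only: subtract the per-feature mean while saving nothing for backward."""

    @staticmethod
    def forward(ctx, x):
        # No ctx.save_for_backward call: the gradient of x - mean(x) does not need x.
        return x - x.mean(dim=0, keepdim=True)

    @staticmethod
    def backward(ctx, grad_output):
        # d/dx of (x - mean(x)), applied to the incoming gradient.
        return grad_output - grad_output.mean(dim=0, keepdim=True)

x = torch.randn(8, 4, requires_grad=True)
y = _Center.apply(x)
y.sum().backward()  # backward runs without any retained forward intermediates

The real _SyncBatchNormFunction added below follows the same idea: it saves only input, weight, bias, mean, invstd and total_count, and recomputes (input - mean) during backward instead of keeping it alive.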
@@ -9,51 +9,29 @@ import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
-if torch.__version__.split(".")[:2] >= ["1", "8"]:
-    from torch.distributed.nn.functional import all_reduce as differentiable_all_reduce
-else:
-    # Copied from https://github.com/pytorch/pytorch/blob/v1.8.1/torch/distributed/nn/functional.py
-    class _AllReduce(torch.autograd.Function):
-        @staticmethod
-        def forward(ctx, op, group, tensor):  # type: ignore
-            ctx.group = group
-            ctx.op = op
-            tensor = tensor.clone()
-            dist.all_reduce(tensor, op=op, group=group)
-            return tensor
-
-        @staticmethod
-        def backward(ctx, grad_output):  # type: ignore
-            return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),)
-
-    def differentiable_all_reduce(tensor, op=dist.ReduceOp.SUM, group=dist.group.WORLD):  # type: ignore
-        return _AllReduce.apply(op, group, tensor)
 def _forward(
     input: torch.Tensor,
     affine: bool,
     track_running_stats: bool,
     mean: torch.Tensor,
-    meansqr: torch.Tensor,
+    var: torch.Tensor,
+    invstd: torch.Tensor,
     momentum: float,
-    eps: float,
     weight: torch.Tensor,
     bias: torch.Tensor,
     running_mean: torch.Tensor,
     running_var: torch.Tensor,
     total_count: torch.Tensor,
 ) -> torch.Tensor:
-    var = meansqr - mean * mean
     if track_running_stats:
         with torch.no_grad():
             unbiased_var = var * (total_count / (total_count - 1))
             running_mean += momentum * (mean.reshape(-1) - running_mean)
             running_var += momentum * (unbiased_var.reshape(-1) - running_var)
-    invstd = torch.rsqrt(var + eps)
     if affine:
-        return (input - mean) * invstd * weight.reshape(mean.shape) + bias.reshape(mean.shape)
+        return (input - mean) * (invstd * weight.reshape_as(mean)) + bias.reshape_as(mean)
     else:
         return (input - mean) * invstd
@@ -62,6 +40,92 @@ if torch.__version__.split(".")[:2] >= ["1", "7"]:
     _forward = torch.jit.script(_forward)  # type: ignore
+class _SyncBatchNormFunction(torch.autograd.Function):
+    @staticmethod
+    # type: ignore
+    def forward(
+        ctx, input, weight, bias, affine, track_running_stats, running_mean, running_var, eps, momentum, process_group
+    ):
+        dim = [d for d in range(input.ndim) if d != 1]
+        count = torch.full((1,), input.numel() // input.size(1), device=input.device, dtype=input.dtype)
+        total_count = count.clone()
+        all_reduce_handle = dist.all_reduce(total_count, group=process_group, async_op=True)
+        mean = torch.mean(input, dim=dim, keepdim=True)
+        meansqr = torch.mean(input * input, dim=dim, keepdim=True)
+        vec = torch.cat([mean, meansqr])
+        all_reduce_handle.wait()
+        vec = vec * (count / total_count)
+        dist.all_reduce(vec, group=process_group)
+        mean, meansqr = vec.chunk(2)
+        var = meansqr - mean * mean
+        invstd = torch.rsqrt(var + eps)
+        ctx.save_for_backward(input, weight, bias, mean, invstd, total_count)
+        ctx.process_group = process_group
+
+        return _forward(
+            input,
+            affine,
+            track_running_stats,
+            mean,
+            var,
+            invstd,
+            momentum,
+            weight,
+            bias,
+            running_mean,
+            running_var,
+            total_count,
+        )
+
+    @staticmethod
+    # type: ignore
+    def backward(ctx, grad_output):
+        needs_input_grad = ctx.needs_input_grad[0]
+        needs_weight_grad = ctx.needs_input_grad[1]
+
+        grad_input = None
+        grad_weight = None
+        grad_bias = None
+
+        input, weight, bias, mean, invstd, total_count = ctx.saved_tensors
+        process_group = ctx.process_group
+
+        dim = [d for d in range(input.ndim) if d != 1]
+        if needs_input_grad or needs_weight_grad:
+            grad_common = torch.sum(
+                (input - mean) * grad_output, dim=dim, keepdim=True
+            )  # common to grad_weight and grad_invstd
+
+        if needs_input_grad:
+            if weight is None:  # i.e. affine is False
+                grad_input = invstd * grad_output
+                grad_mean = -torch.sum(grad_input, dim=dim, keepdim=True)
+                grad_invstd = grad_common
+            else:
+                grad_input = (invstd * weight.reshape_as(mean)) * grad_output
+                grad_mean = -torch.sum(grad_input, dim=dim, keepdim=True)
+                grad_invstd = grad_common * weight.reshape_as(mean)
+            grad_var = -0.5 * invstd.pow(3) * grad_invstd
+            grad_mean += -2 * mean * grad_var
+            grad_meansqr = grad_var
+            vec = torch.cat([grad_mean, grad_meansqr])
+            all_reduce_handle = dist.all_reduce(vec, group=process_group, async_op=True)
+
+        if needs_weight_grad:
+            grad_weight = (grad_common * invstd).resize_as(weight)
+            grad_bias = torch.sum(grad_output, dim=dim)
+
+        if needs_input_grad:
+            all_reduce_handle.wait()
+            vec = vec / total_count  # NOTE(msb) removed '* count' here to avoid '/ count' below
+            grad_mean, grad_meansqr = vec.chunk(2)
+            grad_input += grad_mean  # removed '/ count'
+            grad_input += input * (2 * grad_meansqr)  # removed '/ count'
+
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None
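
For readers tracing the math, the backward above can be read as the chain rule through the pooled statistics. Per channel, write g for grad_output, n for total_count, \gamma for weight, \mu = E[x] and E[x^2] for the all-reduced moments, \sigma^2 = E[x^2] - \mu^2 and s = (\sigma^2 + \epsilon)^{-1/2}; the sums below run over each worker's elements and are then all-reduced:

\begin{aligned}
\frac{\partial L}{\partial \beta}   &= \sum_i g_i, &
\frac{\partial L}{\partial \gamma}  &= s \sum_i g_i (x_i - \mu), \\
\frac{\partial L}{\partial s}       &= \gamma \sum_i g_i (x_i - \mu), &
\frac{\partial L}{\partial \sigma^2} &= -\tfrac{1}{2}\, s^3\, \frac{\partial L}{\partial s}, \\
\frac{\partial L}{\partial \mu}     &= -\gamma s \sum_i g_i \;-\; 2\mu\, \frac{\partial L}{\partial \sigma^2}, &
\frac{\partial L}{\partial \mathrm{E}[x^2]} &= \frac{\partial L}{\partial \sigma^2}, \\
\frac{\partial L}{\partial x_i}     &= \gamma s\, g_i \;+\; \frac{1}{n}\, \frac{\partial L}{\partial \mu} \;+\; \frac{2 x_i}{n}\, \frac{\partial L}{\partial \mathrm{E}[x^2]}. &&
\end{aligned}

Here grad_common, grad_invstd and grad_var correspond to \sum_i g_i (x_i - \mu), \partial L/\partial s and \partial L/\partial \sigma^2; the 1/n factors come from \mu and E[x^2] being means over all n elements, which the code folds into the single division vec = vec / total_count after the all-reduce (the NOTE(msb) comment).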

 class SyncBatchNorm(torch.nn.BatchNorm2d):
     """
     Fast re-implementation of ``torch.nn.SyncBatchNorm`` that can achieve a speedup
@@ -79,30 +143,17 @@ class SyncBatchNorm(torch.nn.BatchNorm2d):
         if not dist.is_initialized() or not self.training:
             return super().forward(input)
-        dim = [d for d in range(input.ndim) if d != 1]
-        count = torch.full((1,), input.numel() // input.size(1), device=input.device, dtype=input.dtype)
-        total_count = count.clone()
-        handle = dist.all_reduce(total_count, group=self._process_group, async_op=True)
-        mean = torch.mean(input, dim=dim, keepdim=True)
-        meansqr = torch.mean(input * input, dim=dim, keepdim=True)
-        vec = torch.cat([mean, meansqr])
-        handle.wait()
-        vec = vec * (count / total_count)
-        mean, meansqr = differentiable_all_reduce(vec, group=self._process_group).chunk(2)  # type: ignore
-        return _forward(
-            input,
-            self.affine,
-            self.track_running_stats,
-            mean,
-            meansqr,
-            self.momentum,
-            self.eps,
-            self.weight,
-            self.bias,
-            self.running_mean,
-            self.running_var,
-            total_count,
-        )
+        return _SyncBatchNormFunction.apply(
+            input,
+            self.weight,
+            self.bias,
+            self.affine,
+            self.track_running_stats,
+            self.running_mean,
+            self.running_var,
+            self.eps,
+            self.momentum,
+            self._process_group,
+        )

     @classmethod
...
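
For context, SyncBatchNorm is meant as a drop-in replacement for torch.nn.SyncBatchNorm when torch.distributed is initialized; otherwise its forward (see above) falls back to the plain BatchNorm2d behaviour it inherits. A rough usage sketch follows; the import path is an assumption and may differ between fairscale versions:

import torch
from fairscale.experimental.nn import SyncBatchNorm  # assumed path; check your fairscale version

# assumes torch.distributed.init_process_group(...) has already been called on every rank
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, kernel_size=3, padding=1),
    SyncBatchNorm(64),  # in place of torch.nn.SyncBatchNorm(64)
    torch.nn.ReLU(),
).cuda()

x = torch.randn(8, 3, 32, 32, device="cuda")
y = model(x)  # in training mode, batch statistics are all-reduced across the process group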
@@ -36,40 +36,48 @@ def pg_test(world_size=torch.cuda.device_count()):
 def check_parity(torch_bn, fs_bn, x):
-    yh = torch.ones_like(x)
-    torch_y = torch_bn(x)
-    fs_y = fs_bn(x)
+    yh = torch.randn_like(x)
+    torch_x = x.detach()
+    torch_x.requires_grad = True
+    torch_y = torch_bn(torch_x)
     torch_y.backward(yh)
+    fs_x = x.detach()
+    fs_x.requires_grad = True
+    fs_y = fs_bn(fs_x)
     fs_y.backward(yh)
-    assert torch.allclose(torch_y, fs_y), f"{torch_y} != {fs_y}"
-    assert torch.allclose(torch_bn.running_mean, fs_bn.running_mean), f"{torch_bn.running_mean} != {fs_bn.running_mean}"
-    assert torch.allclose(torch_bn.running_var, fs_bn.running_var), f"{torch_bn.running_var} != {fs_bn.running_var}"
-    assert torch.allclose(torch_bn.weight, fs_bn.weight), f"{torch_bn.weight.grad} != {fs_bn.weight.grad}"
-    assert torch.allclose(torch_bn.bias, fs_bn.bias), f"{torch_bn.bias.grad} != {fs_bn.bias.grad}"
-    # TODO(msb) currently disabled due to PyTorch bug: https://github.com/pytorch/pytorch/issues/57796
-    # assert torch.allclose(torch_bn.weight.grad, fs_bn.weight.grad), f"{torch_bn.weight.grad} != {fs_bn.weight.grad}"
-    assert torch.allclose(torch_bn.bias.grad, fs_bn.bias.grad), f"{torch_bn.bias.grad} != {fs_bn.bias.grad}"
+    torch.testing.assert_allclose(torch_y, fs_y)
+    torch.testing.assert_allclose(torch_bn.running_mean, fs_bn.running_mean)
+    torch.testing.assert_allclose(torch_bn.running_var, fs_bn.running_var)
+    torch.testing.assert_allclose(torch_bn.weight, fs_bn.weight)
+    torch.testing.assert_allclose(torch_bn.bias, fs_bn.bias)
+    torch.testing.assert_allclose(torch_bn.weight.grad, fs_bn.weight.grad)
+    torch.testing.assert_allclose(torch_bn.bias.grad, fs_bn.bias.grad)
+    torch.testing.assert_allclose(torch_x.grad, fs_x.grad)

 def check_parity_ddp(torch_bn, fs_bn, x):
-    yh = torch.ones_like(x)
+    yh = torch.randn_like(x)
     rank = dist.get_rank()
     torch_ddp = DDP(torch_bn, device_ids=[rank])
-    fs_ddp = DDP(fs_bn, device_ids=[rank])
     torch_bn = torch_ddp.module
-    fs_bn = fs_ddp.module
-    torch_y = torch_ddp(x)
-    fs_y = fs_ddp(x)
+    torch_x = x.detach()
+    torch_x.requires_grad = True
+    torch_y = torch_ddp(torch_x)
     torch_y.backward(yh)
+    fs_ddp = DDP(fs_bn, device_ids=[rank])
+    fs_bn = fs_ddp.module
+    fs_x = x.detach()
+    fs_x.requires_grad = True
+    fs_y = fs_ddp(fs_x)
     fs_y.backward(yh)
-    assert torch.allclose(torch_y, fs_y), f"{torch_y} != {fs_y}"
-    assert torch.allclose(torch_bn.running_mean, fs_bn.running_mean), f"{torch_bn.running_mean} != {fs_bn.running_mean}"
-    assert torch.allclose(torch_bn.running_var, fs_bn.running_var), f"{torch_bn.running_var} != {fs_bn.running_var}"
-    assert torch.allclose(torch_bn.weight, fs_bn.weight), f"{torch_bn.weight.grad} != {fs_bn.weight.grad}"
-    assert torch.allclose(torch_bn.bias, fs_bn.bias), f"{torch_bn.bias.grad} != {fs_bn.bias.grad}"
-    # TODO(msb) currently disabled due to PyTorch bug: https://github.com/pytorch/pytorch/issues/57796
-    # assert torch.allclose(torch_bn.weight.grad, fs_bn.weight.grad), f"{torch_bn.weight.grad} != {fs_bn.weight.grad}"
-    assert torch.allclose(torch_bn.bias.grad, fs_bn.bias.grad), f"{torch_bn.bias.grad} != {fs_bn.bias.grad}"
+    torch.testing.assert_allclose(torch_y, fs_y)
+    torch.testing.assert_allclose(torch_bn.running_mean, fs_bn.running_mean)
+    torch.testing.assert_allclose(torch_bn.running_var, fs_bn.running_var)
+    torch.testing.assert_allclose(torch_bn.weight, fs_bn.weight)
+    torch.testing.assert_allclose(torch_bn.bias, fs_bn.bias)
+    torch.testing.assert_allclose(torch_bn.weight.grad, fs_bn.weight.grad)
+    torch.testing.assert_allclose(torch_bn.bias.grad, fs_bn.bias.grad)
+    torch.testing.assert_allclose(torch_x.grad, fs_x.grad)

 @pg_test(world_size=1)
@@ -142,3 +150,30 @@ def parity1d_syncbn():
     torch_bn = torch.nn.SyncBatchNorm(3).cuda()
     fs_bn = SyncBatchNorm(3).cuda()
     check_parity_ddp(torch_bn, fs_bn, x)
+
+
+@pg_test()
+def memory_allocated():
+    rank = dist.get_rank()
+    torch.cuda.set_device(rank)
+    x = torch.randn(50, 2048, 7, 7).to(rank)
+    torch_bn = torch.nn.SyncBatchNorm(2048).cuda()
+    torch_bn = DDP(torch_bn, device_ids=[rank])
+    fs_bn = SyncBatchNorm(2048).cuda()
+    fs_bn = DDP(fs_bn, device_ids=[rank])
+    torch_x = x.detach()
+    torch_x.requires_grad = True
+    fs_x = x.detach()
+    fs_x.requires_grad = True
+    torch.cuda.empty_cache()
+    mem_at_start = torch.cuda.memory_stats()["allocated_bytes.all.current"]
+    torch_y = torch_bn(torch_x)
+    torch.cuda.empty_cache()
+    mem_after_torch = torch.cuda.memory_stats()["allocated_bytes.all.current"]
+    fs_y = fs_bn(fs_x)
+    torch.cuda.empty_cache()
+    mem_final = torch.cuda.memory_stats()["allocated_bytes.all.current"]
+    torch_used = mem_after_torch - mem_at_start
+    fs_used = mem_final - mem_after_torch
+    assert fs_used < (torch_used * 1.01), f"{fs_used} < {torch_used * 1.01}"