"src/vscode:/vscode.git/clone" did not exist on "fa736e321d85a49cd761fccc6dd70a66b562aa1c"
Commit a233fceb authored by Myle Ott

Improve memory handling (recover from OOM and periodically empty caching allocator)

parent be274623
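
The OOM-recovery half of the commit, reduced to a minimal sketch: wrap the forward/backward computation in a try/except, treat a CUDA "out of memory" RuntimeError as a skipped batch, clear any partially accumulated gradients, and release cached blocks back to the driver. Everything below (the `train_step_with_oom_recovery` helper, the `sample` layout, the criterion call) is illustrative and assumed, not the committed fairseq code.

```python
import torch

def train_step_with_oom_recovery(model, criterion, optimizer, sample):
    """Illustrative sketch: skip a batch that triggers a CUDA OOM error
    instead of crashing the whole multiprocessing trainer."""
    optimizer.zero_grad()
    try:
        loss = criterion(model(sample['input']), sample['target'])
        loss.backward()
    except RuntimeError as e:
        if 'out of memory' in str(e):
            print('| WARNING: ran out of memory, skipping batch')
            optimizer.zero_grad()  # drop any partially accumulated gradients
            if hasattr(torch.cuda, 'empty_cache'):
                torch.cuda.empty_cache()  # release cached blocks back to the driver
            return None  # signal "batch skipped" to the caller
        raise
    optimizer.step()
    return loss
```
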
@@ -11,10 +11,11 @@ Train a network on multiple GPUs using multiprocessing.
"""
from itertools import cycle, islice
import math
import torch
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
from fairseq import nccl, utils
from fairseq import meters, nccl, utils
from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
from fairseq.nag import NAG
@@ -74,6 +75,8 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
# initialize LR scheduler
self.lr_scheduler = self._build_lr_scheduler()
self._max_bsz_seen = 0
def _build_optimizer(self):
if self.args.optimizer == 'adagrad':
return torch.optim.Adagrad(self.model.parameters(), lr=self.args.lr,
@@ -161,14 +164,14 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)
# forward pass
sample_sizes, logging_outputs = Future.gen_tuple_list([
sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
self.call_async(rank, '_async_forward')
for rank in range(self.num_replicas)
])
# backward pass, all-reduce gradients and take an optimization step
grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
grad_norms = Future.gen_list([
grad_norms, ooms_bwd = Future.gen_tuple_list([
self.call_async(rank, '_async_backward_and_opt', grad_denom=grad_denom)
for rank in range(self.num_replicas)
])
@@ -176,6 +179,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
# aggregate logging output
logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
logging_output['gnorm'] = grad_norms[0] # log the gradient norm
logging_output['oom'] = sum(ooms_fwd) + sum(ooms_bwd)
return logging_output
@@ -186,34 +190,44 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
self.model.train()
self.optimizer.zero_grad()
if self._sample is None:
return 0, {}
# calculate loss and sample size
self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
sample_size, logging_output, oom = 0, {}, False
if self._sample is not None:
try:
# calculate loss and sample size
self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
except RuntimeError as e:
if not eval and 'out of memory' in str(e):
print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
oom = True
self.loss = None
if hasattr(torch.cuda, 'empty_cache'):
torch.cuda.empty_cache()
else:
raise e
return sample_size, logging_output
return sample_size, logging_output, oom
def _async_backward_and_opt(self, rank, device_id, grad_denom):
oom = False
if self.loss is not None:
# backward pass
self.loss.backward()
# get model parameters as a flattened (contiguous) tensor
flat_grads = self._flat_model_grads()
# all-reduce grads
nccl.all_reduce(flat_grads)
try:
# backward pass
self.loss.backward()
except RuntimeError as e:
if 'out of memory' in str(e):
print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
oom = True
if hasattr(torch.cuda, 'empty_cache'):
torch.cuda.empty_cache()
self.optimizer.zero_grad()
else:
raise e
# normalize grads
if grad_denom != 0:
flat_grads.div_(grad_denom)
# all-reduce grads and rescale by grad_denom
self._all_reduce_and_rescale_grads(grad_denom)
# clip grads
grad_norm = self._clip_grads_(flat_grads, self.args.clip_norm)
# copy reduced grads back
self._set_model_grads_(flat_grads)
grad_norm = torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm)
# take an optimization step
self.optimizer.step()
@@ -221,41 +235,49 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
# reset loss
self.loss = None
return grad_norm
def _model_grads(self):
return [p.grad for p in self.model.parameters() if p.requires_grad]
def _flat_model_grads(self):
grads = self._model_grads()
if not hasattr(self, '_flat_grads'):
num_params = sum(g.data.numel() for g in grads)
self._flat_grads = grads[0].data.new(num_params)
offset = 0
for grad in grads:
grad = grad.data.view(-1)
numel = grad.numel()
self._flat_grads[offset:offset+numel].copy_(grad)
offset += numel
return self._flat_grads
def _set_model_grads_(self, flat_grads):
grads = self._model_grads()
offset = 0
for grad in grads:
grad = grad.data.view(-1)
numel = grad.numel()
grad.copy_(flat_grads[offset:offset+numel])
offset += numel
assert offset == flat_grads.numel()
def _clip_grads_(self, flat_grads, clipv):
"""nn.utils.clip_grad_norm for flattened (contiguous) tensors."""
norm = flat_grads.norm()
if clipv > 0 and norm > clipv:
coef = max(norm, 1e-6) / clipv
flat_grads.div_(coef)
return norm
return grad_norm, oom
def _all_reduce_and_rescale_grads(self, grad_denom, buffer_size=10485760):
"""All-reduce and rescale gradients in chunks of the specified size."""
grads = [p.grad.data for p in self.model.parameters() if p.requires_grad]
buffer_t = grads[0].new(math.ceil(buffer_size / grads[0].element_size())).zero_()
buffer = []
def all_reduce_buffer():
# copy grads into buffer_t
offset = 0
for g in buffer:
numel = g.numel()
buffer_t[offset:offset+numel].copy_(g.view(-1))
offset += numel
# all-reduce and rescale
nccl.all_reduce(buffer_t[:offset])
buffer_t.div_(grad_denom)
# copy all-reduced buffer back into grads
offset = 0
for g in buffer:
numel = g.numel()
g.view(-1).copy_(buffer_t[offset:offset+numel])
offset += numel
filled = 0
for g in grads:
sz = g.numel() * g.element_size()
if sz > buffer_size:
# grad is bigger than buffer, all-reduce and rescale directly
nccl.all_reduce(g)
g.div_(grad_denom)
elif filled + sz > buffer_size:
# buffer is full, all-reduce and replace buffer with grad
all_reduce_buffer()
buffer = [g]
filled = sz
else:
# add grad to buffer
buffer.append(g)
filled += sz
if len(buffer) > 0:
all_reduce_buffer()
def valid_step(self, samples):
"""Do forward pass in parallel."""
@@ -263,10 +285,11 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
self._scatter_samples(samples, volatile=True)
# forward pass
_sample_sizes, logging_outputs = Future.gen_tuple_list([
_sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
self.call_async(rank, '_async_forward', eval=True)
for rank in range(self.num_replicas)
])
assert sum(ooms_fwd) == 0
# aggregate logging output
logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
@@ -314,4 +337,10 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
if sample is None:
self._sample = None
else:
if hasattr(torch.cuda, 'empty_cache'):
# clear the caching allocator if this is the largest sample we've seen
if sample['target'].size(0) > self._max_bsz_seen:
self._max_bsz_seen = sample['target'].size(0)
torch.cuda.empty_cache()
self._sample = utils.prepare_sample(sample, volatile=volatile, cuda_device=device_id)
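
The "periodically empty caching allocator" half of the commit amounts to clearing PyTorch's CUDA cache whenever a batch larger than anything seen so far is about to be loaded, so the bigger allocation is less likely to fail on a fragmented cache, while keeping the relatively expensive `empty_cache` call rare. A minimal sketch of that bookkeeping, with a hypothetical `SampleLoader` wrapper standing in for the trainer's sample-preparation path:

```python
import torch

class SampleLoader:
    """Illustrative sketch (hypothetical class, not the committed code):
    empty the CUDA caching allocator whenever a new largest batch arrives."""

    def __init__(self):
        self._max_bsz_seen = 0

    def prepare(self, sample):
        bsz = sample['target'].size(0)
        if bsz > self._max_bsz_seen:
            self._max_bsz_seen = bsz
            if hasattr(torch.cuda, 'empty_cache'):
                # release cached blocks before the largest-so-far batch
                # is moved onto the GPU
                torch.cuda.empty_cache()
        return sample  # the real trainer then moves the sample to the GPU
```
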