Commit a233fceb authored by Myle Ott

Improve memory handling (recover from OOM and periodically empty caching allocator)

parent be274623
@@ -11,10 +11,11 @@ Train a network on multiple GPUs using multiprocessing.
 """
 from itertools import cycle, islice
+import math
 import torch
 from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
-from fairseq import nccl, utils
+from fairseq import meters, nccl, utils
 from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
 from fairseq.nag import NAG
@@ -74,6 +75,8 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # initialize LR scheduler
         self.lr_scheduler = self._build_lr_scheduler()
 
+        self._max_bsz_seen = 0
+
     def _build_optimizer(self):
         if self.args.optimizer == 'adagrad':
             return torch.optim.Adagrad(self.model.parameters(), lr=self.args.lr,
@@ -161,14 +164,14 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)
 
         # forward pass
-        sample_sizes, logging_outputs = Future.gen_tuple_list([
+        sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
             self.call_async(rank, '_async_forward')
             for rank in range(self.num_replicas)
         ])
 
         # backward pass, all-reduce gradients and take an optimization step
         grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
-        grad_norms = Future.gen_list([
+        grad_norms, ooms_bwd = Future.gen_tuple_list([
             self.call_async(rank, '_async_backward_and_opt', grad_denom=grad_denom)
             for rank in range(self.num_replicas)
         ])
@@ -176,6 +179,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # aggregate logging output
         logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
         logging_output['gnorm'] = grad_norms[0]  # log the gradient norm
+        logging_output['oom'] = sum(ooms_fwd) + sum(ooms_bwd)
 
         return logging_output
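
Note: each replica's `_async_forward` now returns a third element, an out-of-memory flag, and `Future.gen_tuple_list` regroups the per-replica result tuples into one sequence per field so the flags can simply be summed into the `oom` log entry. A small illustration of that regrouping with made-up values (plain Python, no fairseq APIs):

# Hypothetical per-replica results: (sample_size, logging_output, oom_flag).
results = [(1520, {'loss': 2.31}, False), (1480, {'loss': 2.40}, True)]
sample_sizes, logging_outputs, ooms_fwd = (list(x) for x in zip(*results))
assert sum(ooms_fwd) == 1  # number of replicas that hit OOM and skipped the batch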
@@ -186,34 +190,44 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self.model.train()
         self.optimizer.zero_grad()
 
-        if self._sample is None:
-            return 0, {}
-
-        # calculate loss and sample size
-        self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
-
-        return sample_size, logging_output
+        sample_size, logging_output, oom = 0, {}, False
+        if self._sample is not None:
+            try:
+                # calculate loss and sample size
+                self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
+            except RuntimeError as e:
+                if not eval and 'out of memory' in str(e):
+                    print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
+                    oom = True
+                    self.loss = None
+                    if hasattr(torch.cuda, 'empty_cache'):
+                        torch.cuda.empty_cache()
+                else:
+                    raise e
+
+        return sample_size, logging_output, oom
 
     def _async_backward_and_opt(self, rank, device_id, grad_denom):
+        oom = False
         if self.loss is not None:
-            # backward pass
-            self.loss.backward()
-
-            # get model parameters as a flattened (contiguous) tensor
-            flat_grads = self._flat_model_grads()
-
-            # all-reduce grads
-            nccl.all_reduce(flat_grads)
-
-            # normalize grads
-            if grad_denom != 0:
-                flat_grads.div_(grad_denom)
-
-            # clip grads
-            grad_norm = self._clip_grads_(flat_grads, self.args.clip_norm)
-
-            # copy reduced grads back
-            self._set_model_grads_(flat_grads)
-
-            # take an optimization step
-            self.optimizer.step()
+            try:
+                # backward pass
+                self.loss.backward()
+            except RuntimeError as e:
+                if 'out of memory' in str(e):
+                    print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
+                    oom = True
+                    if hasattr(torch.cuda, 'empty_cache'):
+                        torch.cuda.empty_cache()
+                    self.optimizer.zero_grad()
+                else:
+                    raise e
+
+        # all-reduce grads and rescale by grad_denom
+        self._all_reduce_and_rescale_grads(grad_denom)
+
+        # clip grads
+        grad_norm = torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm)
+
+        # take an optimization step
+        self.optimizer.step()
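
Note: both the forward and backward passes are now wrapped in the same recovery pattern: catch `RuntimeError`, check the message for "out of memory", warn, drop the half-finished step (clear `self.loss`, zero the gradients), and hand cached blocks back to the driver with `torch.cuda.empty_cache()` where available. Below is a minimal single-GPU sketch of that pattern, assuming a recent PyTorch API; `train_step` and its `model`/`criterion`/`optimizer` arguments are generic placeholders, not names from this file.

import torch

def train_step(model, criterion, optimizer, inputs, targets):
    optimizer.zero_grad()
    try:
        loss = criterion(model(inputs), targets)
        loss.backward()
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            raise
        # skip this batch: drop the reference to the loss/graph, clear any
        # partial gradients and return cached blocks to the CUDA driver
        print('| WARNING: ran out of memory, skipping batch')
        loss = None
        optimizer.zero_grad()
        if hasattr(torch.cuda, 'empty_cache'):
            torch.cuda.empty_cache()
        return None
    optimizer.step()
    return float(loss)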
@@ -221,41 +235,49 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # reset loss
         self.loss = None
 
-        return grad_norm
+        return grad_norm, oom
 
-    def _model_grads(self):
-        return [p.grad for p in self.model.parameters() if p.requires_grad]
-
-    def _flat_model_grads(self):
-        grads = self._model_grads()
-        if not hasattr(self, '_flat_grads'):
-            num_params = sum(g.data.numel() for g in grads)
-            self._flat_grads = grads[0].data.new(num_params)
-        offset = 0
-        for grad in grads:
-            grad = grad.data.view(-1)
-            numel = grad.numel()
-            self._flat_grads[offset:offset+numel].copy_(grad)
-            offset += numel
-        return self._flat_grads
-
-    def _set_model_grads_(self, flat_grads):
-        grads = self._model_grads()
-        offset = 0
-        for grad in grads:
-            grad = grad.data.view(-1)
-            numel = grad.numel()
-            grad.copy_(flat_grads[offset:offset+numel])
-            offset += numel
-        assert offset == flat_grads.numel()
-
-    def _clip_grads_(self, flat_grads, clipv):
-        """nn.utils.clip_grad_norm for flattened (contiguous) tensors."""
-        norm = flat_grads.norm()
-        if clipv > 0 and norm > clipv:
-            coef = max(norm, 1e-6) / clipv
-            flat_grads.div_(coef)
-        return norm
+    def _all_reduce_and_rescale_grads(self, grad_denom, buffer_size=10485760):
+        """All-reduce and rescale gradients in chunks of the specified size."""
+        grads = [p.grad.data for p in self.model.parameters() if p.requires_grad]
+        buffer_t = grads[0].new(math.ceil(buffer_size / grads[0].element_size())).zero_()
+        buffer = []
+
+        def all_reduce_buffer():
+            # copy grads into buffer_t
+            offset = 0
+            for g in buffer:
+                numel = g.numel()
+                buffer_t[offset:offset+numel].copy_(g.view(-1))
+                offset += numel
+            # all-reduce and rescale
+            nccl.all_reduce(buffer_t[:offset])
+            buffer_t.div_(grad_denom)
+            # copy all-reduced buffer back into grads
+            offset = 0
+            for g in buffer:
+                numel = g.numel()
+                g.view(-1).copy_(buffer_t[offset:offset+numel])
+                offset += numel
+
+        filled = 0
+        for g in grads:
+            sz = g.numel() * g.element_size()
+            if sz > buffer_size:
+                # grad is bigger than buffer, all-reduce and rescale directly
+                nccl.all_reduce(g)
+                g.div_(grad_denom)
+            elif filled + sz > buffer_size:
+                # buffer is full, all-reduce and replace buffer with grad
+                all_reduce_buffer()
+                buffer = [g]
+                filled = sz
+            else:
+                # add grad to buffer
+                buffer.append(g)
+                filled += sz
+        if len(buffer) > 0:
+            all_reduce_buffer()
 
     def valid_step(self, samples):
         """Do forward pass in parallel."""
@@ -263,10 +285,11 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self._scatter_samples(samples, volatile=True)
 
         # forward pass
-        _sample_sizes, logging_outputs = Future.gen_tuple_list([
+        _sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
             self.call_async(rank, '_async_forward', eval=True)
             for rank in range(self.num_replicas)
         ])
+        assert sum(ooms_fwd) == 0
 
         # aggregate logging output
         logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
@@ -314,4 +337,10 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         if sample is None:
             self._sample = None
         else:
+            if hasattr(torch.cuda, 'empty_cache'):
+                # clear the caching allocator if this is the largest sample we've seen
+                if sample['target'].size(0) > self._max_bsz_seen:
+                    self._max_bsz_seen = sample['target'].size(0)
+                    torch.cuda.empty_cache()
+
             self._sample = utils.prepare_sample(sample, volatile=volatile, cuda_device=device_id)
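
Note: the caching allocator is only emptied when a sample arrives whose batch dimension exceeds anything seen so far (`self._max_bsz_seen`), so the relatively expensive `torch.cuda.empty_cache()` call fires a few times early in training, when a larger batch would otherwise have to be carved out of an already fragmented cache, rather than on every step. A tiny free-standing sketch of the same heuristic; the `maybe_empty_cache` helper is illustrative and not part of fairseq.

import torch

_max_bsz_seen = 0  # largest batch size observed so far

def maybe_empty_cache(batch_size):
    """Release cached CUDA blocks only when the largest batch so far grows."""
    global _max_bsz_seen
    if batch_size > _max_bsz_seen and hasattr(torch.cuda, 'empty_cache'):
        _max_bsz_seen = batch_size
        torch.cuda.empty_cache()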