Commit a233fceb authored by Myle Ott

Improve memory handling (recover from OOM and periodically empty caching allocator)

parent be274623
@@ -11,10 +11,11 @@ Train a network on multiple GPUs using multiprocessing.
 """
 from itertools import cycle, islice
+import math

 import torch
 from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau

-from fairseq import nccl, utils
+from fairseq import meters, nccl, utils
 from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
 from fairseq.nag import NAG
@@ -74,6 +75,8 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # initialize LR scheduler
         self.lr_scheduler = self._build_lr_scheduler()

+        self._max_bsz_seen = 0
+
     def _build_optimizer(self):
         if self.args.optimizer == 'adagrad':
             return torch.optim.Adagrad(self.model.parameters(), lr=self.args.lr,
@@ -161,14 +164,14 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)

         # forward pass
-        sample_sizes, logging_outputs = Future.gen_tuple_list([
+        sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
             self.call_async(rank, '_async_forward')
             for rank in range(self.num_replicas)
         ])

         # backward pass, all-reduce gradients and take an optimization step
         grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
-        grad_norms = Future.gen_list([
+        grad_norms, ooms_bwd = Future.gen_tuple_list([
             self.call_async(rank, '_async_backward_and_opt', grad_denom=grad_denom)
             for rank in range(self.num_replicas)
         ])
@@ -176,6 +179,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # aggregate logging output
         logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
         logging_output['gnorm'] = grad_norms[0]  # log the gradient norm
+        logging_output['oom'] = sum(ooms_fwd) + sum(ooms_bwd)

         return logging_output
@@ -186,34 +190,44 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             self.model.train()
             self.optimizer.zero_grad()

-        if self._sample is None:
-            return 0, {}
-
-        # calculate loss and sample size
-        self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
+        sample_size, logging_output, oom = 0, {}, False
+        if self._sample is not None:
+            try:
+                # calculate loss and sample size
+                self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
+            except RuntimeError as e:
+                if not eval and 'out of memory' in str(e):
+                    print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
+                    oom = True
+                    self.loss = None
+                    if hasattr(torch.cuda, 'empty_cache'):
+                        torch.cuda.empty_cache()
+                else:
+                    raise e

-        return sample_size, logging_output
+        return sample_size, logging_output, oom

     def _async_backward_and_opt(self, rank, device_id, grad_denom):
+        oom = False
         if self.loss is not None:
-            # backward pass
-            self.loss.backward()
-
-        # get model parameters as a flattened (contiguous) tensor
-        flat_grads = self._flat_model_grads()
-
-        # all-reduce grads
-        nccl.all_reduce(flat_grads)
-
-        # normalize grads
-        if grad_denom != 0:
-            flat_grads.div_(grad_denom)
+            try:
+                # backward pass
+                self.loss.backward()
+            except RuntimeError as e:
+                if 'out of memory' in str(e):
+                    print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
+                    oom = True
+                    if hasattr(torch.cuda, 'empty_cache'):
+                        torch.cuda.empty_cache()
+                    self.optimizer.zero_grad()
+                else:
+                    raise e
+
+        # all-reduce grads and rescale by grad_denom
+        self._all_reduce_and_rescale_grads(grad_denom)

         # clip grads
-        grad_norm = self._clip_grads_(flat_grads, self.args.clip_norm)
-
-        # copy reduced grads back
-        self._set_model_grads_(flat_grads)
+        grad_norm = torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm)

         # take an optimization step
         self.optimizer.step()
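For reference (not part of the commit): the OOM-recovery pattern added above, shown in isolation. This is a minimal sketch; `model`, `criterion`, `optimizer`, and `sample` are placeholder names and a single GPU is assumed. The point is the sequence wrapped around the forward/backward pass: catch RuntimeError, check its message for 'out of memory', drop partial state, empty the caching allocator, and report the skipped batch instead of crashing the worker.

import torch

def try_train_step(model, criterion, optimizer, sample, device_id=0):
    # Hypothetical standalone version of the handling in _async_forward /
    # _async_backward_and_opt: attempt a full step and, on CUDA OOM, skip
    # the batch, drop any partial gradients, and release cached blocks.
    optimizer.zero_grad()
    try:
        loss = criterion(model(sample['input']), sample['target'])
        loss.backward()
        optimizer.step()
        return False  # no OOM
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            raise
        print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
        optimizer.zero_grad()
        if hasattr(torch.cuda, 'empty_cache'):  # guarded as above: older PyTorch lacks empty_cache
            torch.cuda.empty_cache()
        return True  # OOM, batch skipped

As in the commit, the OOM flag is returned rather than swallowed, so the caller can aggregate it (here it feeds logging_output['oom']).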
@@ -221,41 +235,49 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # reset loss
         self.loss = None

-        return grad_norm
-
-    def _model_grads(self):
-        return [p.grad for p in self.model.parameters() if p.requires_grad]
-
-    def _flat_model_grads(self):
-        grads = self._model_grads()
-        if not hasattr(self, '_flat_grads'):
-            num_params = sum(g.data.numel() for g in grads)
-            self._flat_grads = grads[0].data.new(num_params)
-        offset = 0
-        for grad in grads:
-            grad = grad.data.view(-1)
-            numel = grad.numel()
-            self._flat_grads[offset:offset+numel].copy_(grad)
-            offset += numel
-        return self._flat_grads
-
-    def _set_model_grads_(self, flat_grads):
-        grads = self._model_grads()
-        offset = 0
-        for grad in grads:
-            grad = grad.data.view(-1)
-            numel = grad.numel()
-            grad.copy_(flat_grads[offset:offset+numel])
-            offset += numel
-        assert offset == flat_grads.numel()
-
-    def _clip_grads_(self, flat_grads, clipv):
-        """nn.utils.clip_grad_norm for flattened (contiguous) tensors."""
-        norm = flat_grads.norm()
-        if clipv > 0 and norm > clipv:
-            coef = max(norm, 1e-6) / clipv
-            flat_grads.div_(coef)
-        return norm
+        return grad_norm, oom
+
+    def _all_reduce_and_rescale_grads(self, grad_denom, buffer_size=10485760):
+        """All-reduce and rescale gradients in chunks of the specified size."""
+        grads = [p.grad.data for p in self.model.parameters() if p.requires_grad]
+        buffer_t = grads[0].new(math.ceil(buffer_size / grads[0].element_size())).zero_()
+        buffer = []
+
+        def all_reduce_buffer():
+            # copy grads into buffer_t
+            offset = 0
+            for g in buffer:
+                numel = g.numel()
+                buffer_t[offset:offset+numel].copy_(g.view(-1))
+                offset += numel
+            # all-reduce and rescale
+            nccl.all_reduce(buffer_t[:offset])
+            buffer_t.div_(grad_denom)
+            # copy all-reduced buffer back into grads
+            offset = 0
+            for g in buffer:
+                numel = g.numel()
+                g.view(-1).copy_(buffer_t[offset:offset+numel])
+                offset += numel
+
+        filled = 0
+        for g in grads:
+            sz = g.numel() * g.element_size()
+            if sz > buffer_size:
+                # grad is bigger than buffer, all-reduce and rescale directly
+                nccl.all_reduce(g)
+                g.div_(grad_denom)
+            elif filled + sz > buffer_size:
+                # buffer is full, all-reduce and replace buffer with grad
+                all_reduce_buffer()
+                buffer = [g]
+                filled = sz
+            else:
+                # add grad to buffer
+                buffer.append(g)
+                filled += sz
+        if len(buffer) > 0:
+            all_reduce_buffer()

     def valid_step(self, samples):
         """Do forward pass in parallel."""
@@ -263,10 +285,11 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self._scatter_samples(samples, volatile=True)

         # forward pass
-        _sample_sizes, logging_outputs = Future.gen_tuple_list([
+        _sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
             self.call_async(rank, '_async_forward', eval=True)
             for rank in range(self.num_replicas)
         ])
+        assert sum(ooms_fwd) == 0

         # aggregate logging output
         logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
@@ -314,4 +337,10 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         if sample is None:
             self._sample = None
         else:
+            if hasattr(torch.cuda, 'empty_cache'):
+                # clear the caching allocator if this is the largest sample we've seen
+                if sample['target'].size(0) > self._max_bsz_seen:
+                    self._max_bsz_seen = sample['target'].size(0)
+                    torch.cuda.empty_cache()
+
             self._sample = utils.prepare_sample(sample, volatile=volatile, cuda_device=device_id)
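For reference (not part of the commit): the _scatter_samples change only empties the caching allocator when a sample arrives with the largest target batch size seen so far. The likely rationale is that cached blocks sized for smaller batches are most prone to cause fragmentation exactly when a new peak allocation is about to happen, while emptying the cache on every step would be wasteful. A tiny sketch of the same heuristic outside the trainer (the max_bsz_seen counter is a stand-in for self._max_bsz_seen above):

import torch

max_bsz_seen = 0

def maybe_empty_cache(sample):
    # Free cached allocator blocks right before allocating for the largest
    # batch observed so far; smaller batches keep reusing the cache.
    global max_bsz_seen
    bsz = sample['target'].size(0)
    if bsz > max_bsz_seen and hasattr(torch.cuda, 'empty_cache'):
        max_bsz_seen = bsz
        torch.cuda.empty_cache()
    return sample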