Unverified commit 94dae690 authored by Myle Ott, committed by GitHub

Merge pull request #77 from facebookresearch/oss-merge-internal

parents d74f200a 0a836276
@@ -18,3 +18,6 @@ class FairseqDecoder(nn.Module):
     def max_positions(self):
         """Maximum input length supported by the decoder."""
         raise NotImplementedError
+
+    def upgrade_state_dict(self, state_dict):
+        return state_dict
@@ -18,3 +18,6 @@ class FairseqEncoder(nn.Module):
     def max_positions(self):
         """Maximum input length supported by the encoder."""
         raise NotImplementedError
+
+    def upgrade_state_dict(self, state_dict):
+        return state_dict
@@ -43,6 +43,11 @@ class FairseqModel(nn.Module):
         """Maximum output length supported by the decoder."""
         return self.decoder.max_positions()
 
+    def upgrade_state_dict(self, state_dict):
+        state_dict = self.encoder.upgrade_state_dict(state_dict)
+        state_dict = self.decoder.upgrade_state_dict(state_dict)
+        return state_dict
+
     def make_generation_fast_(self, **kwargs):
         """Optimize model for faster generation."""
         if self._is_generation_fast:
...
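The new `upgrade_state_dict` hook gives each encoder and decoder a place to rewrite checkpoints written by older code before `load_state_dict` runs, and `FairseqModel` simply chains the two. A minimal sketch of how a subclass might use the hook; the decoder class and the renamed keys below are hypothetical, not part of this change:

```python
import torch.nn as nn


class LegacyAwareDecoder(nn.Module):
    """Hypothetical decoder that renames a key saved by an older version."""

    def __init__(self, dim=8):
        super().__init__()
        self.out_proj = nn.Linear(dim, dim)

    def upgrade_state_dict(self, state_dict):
        # suppose older checkpoints stored this projection under 'decoder.proj.*'
        for suffix in ('weight', 'bias'):
            old_key = 'decoder.proj.{}'.format(suffix)
            if old_key in state_dict:
                state_dict['decoder.out_proj.{}'.format(suffix)] = state_dict.pop(old_key)
        return state_dict
```

A composite model passes the full state dict through both hooks, exactly as `FairseqModel.upgrade_state_dict` does above, and only then calls `load_state_dict`.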
@@ -58,7 +58,7 @@ class FConvEncoder(FairseqEncoder):
         self.projections = nn.ModuleList()
         self.convolutions = nn.ModuleList()
         for (out_channels, kernel_size) in convolutions:
-            pad = (kernel_size - 1) // 2
+            pad = (kernel_size - 1) / 2
             self.projections.append(Linear(in_channels, out_channels)
                                     if in_channels != out_channels else None)
             self.convolutions.append(
@@ -154,6 +154,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
                  max_positions=1024, convolutions=((512, 3),) * 20,
                  attention=True, dropout=0.1):
         super().__init__()
+        self.register_buffer('version', torch.Tensor([2]))
         self.dictionary = dictionary
         self.dropout = dropout
@@ -265,6 +266,16 @@ class FConvDecoder(FairseqIncrementalDecoder):
         """Maximum output length supported by the decoder."""
         return self.embed_positions.num_embeddings - self.dictionary.pad() - 1
 
+    def upgrade_state_dict(self, state_dict):
+        if state_dict.get('decoder.version', torch.Tensor([1]))[0] < 2:
+            # old models use incorrect weight norm dimension
+            for i, conv in enumerate(self.convolutions):
+                # reconfigure weight norm
+                nn.utils.remove_weight_norm(conv)
+                self.convolutions[i] = nn.utils.weight_norm(conv, dim=0)
+            state_dict['decoder.version'] = torch.Tensor([1])
+        return state_dict
+
     def _split_encoder_out(self, encoder_out):
         """Split and transpose encoder outputs.
@@ -307,7 +318,7 @@ def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs
     std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
     m.weight.data.normal_(mean=0, std=std)
     m.bias.data.zero_()
-    return nn.utils.weight_norm(m)
+    return nn.utils.weight_norm(m, dim=2)
 
 def ConvTBC(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
...
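`weight_norm(..., dim=d)` keeps one learned magnitude per slice along dimension `d` of the weight. The TBC/linearized convolutions used here store their weight as `(kernel_size, in_channels, out_channels)`, so per-output-channel normalization needs `dim=2`, whereas a standard `nn.Conv1d` weight is `(out_channels, in_channels, kernel_size)` and uses the default `dim=0`. A quick shape check with a plain `Conv1d` for illustration:

```python
import torch.nn as nn
from torch.nn.utils import weight_norm

# nn.Conv1d weight layout: (out_channels, in_channels, kernel_size)
print(nn.Conv1d(16, 32, 3).weight.shape)                          # torch.Size([32, 16, 3])

# dim=0 -> one magnitude per output channel for this layout
print(weight_norm(nn.Conv1d(16, 32, 3), dim=0).weight_g.shape)    # torch.Size([32, 1, 1])

# dim=2 -> one magnitude per slice of the last dimension; that is only
# per-output-channel when the weight is stored as (kernel, in, out)
print(weight_norm(nn.Conv1d(16, 32, 3), dim=2).weight_g.shape)    # torch.Size([1, 1, 3])
```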
@@ -59,7 +59,7 @@ class ConvTBCFunction(Function):
         kernel_size = weight_size[0]
 
         output = input.new(
-            input_size[0] - kernel_size + 1 + pad * 2,
+            input_size[0] - kernel_size + 1 + int(pad * 2),
             input_size[1],
             weight_size[2])
...
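The `int(pad * 2)` cast pairs with the encoder change above: once `pad = (kernel_size - 1) / 2` may be fractional (even kernel sizes), `pad * 2` still equals the integer `kernel_size - 1`, so the convolution output keeps the input length. A worked check of that arithmetic:

```python
def conv_tbc_output_len(input_len, kernel_size):
    # same formula as in the diff: out = in - k + 1 + int(2 * pad)
    pad = (kernel_size - 1) / 2            # may be fractional, e.g. 1.5 for k=4
    return input_len - kernel_size + 1 + int(pad * 2)

for k in (3, 4, 5):
    print(k, conv_tbc_output_len(100, k))  # 100 in every case
```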
@@ -11,10 +11,11 @@ Train a network on multiple GPUs using multiprocessing.
 """
 
 from itertools import cycle, islice
+import math
 
 import torch
 from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
 
-from fairseq import nccl, utils
+from fairseq import meters, nccl, utils
 from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
 from fairseq.nag import NAG
@@ -67,39 +68,61 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self.model = model.cuda()
         self.criterion = criterion.cuda()
 
-        # initialize optimizer
+        # initialize optimizer and LR scheduler
+        self.args.lr = list(map(float, self.args.lr.split(',')))
         self.optimizer = self._build_optimizer()
-        self.loss = None
-
-        # initialize LR scheduler
         self.lr_scheduler = self._build_lr_scheduler()
 
+        self.loss = None
+        self._max_bsz_seen = 0
+
     def _build_optimizer(self):
+        # When resuming training from a checkpoint, we load the old optimizer
+        # state that includes things like learning rate, momentum factor, etc.
+        # We use this dictionary to override values stored in the checkpoint,
+        # e.g., we might prefer the values specified on the command line.
+        self._override_optim_state = {}
         if self.args.optimizer == 'adagrad':
-            return torch.optim.Adagrad(self.model.parameters(), lr=self.args.lr,
-                                       weight_decay=self.args.weight_decay)
+            self._override_optim_state = {
+                'lr': self.args.lr[0],
+                'weight_decay': self.args.weight_decay,
+            }
+            return torch.optim.Adagrad(self.model.parameters(), **self._override_optim_state)
         elif self.args.optimizer == 'adam':
-            return torch.optim.Adam(self.model.parameters(), lr=self.args.lr,
-                                    betas=eval(self.args.adam_betas),
-                                    weight_decay=self.args.weight_decay)
+            self._override_optim_state = {
+                'lr': self.args.lr[0],
+                'betas': eval(self.args.adam_betas),
+                'weight_decay': self.args.weight_decay,
+            }
+            return torch.optim.Adam(self.model.parameters(), **self._override_optim_state)
         elif self.args.optimizer == 'nag':
-            return NAG(self.model.parameters(), lr=self.args.lr,
-                       momentum=self.args.momentum,
-                       weight_decay=self.args.weight_decay)
+            self._override_optim_state = {
+                'lr': self.args.lr[0],
+                'momentum': self.args.momentum,
+                'weight_decay': self.args.weight_decay,
+            }
+            return NAG(self.model.parameters(), **self._override_optim_state)
         elif self.args.optimizer == 'sgd':
-            return torch.optim.SGD(self.model.parameters(), lr=self.args.lr,
-                                   momentum=self.args.momentum,
-                                   weight_decay=self.args.weight_decay)
+            self._override_optim_state = {
+                'lr': self.args.lr[0],
+                'momentum': self.args.momentum,
+                'weight_decay': self.args.weight_decay,
+            }
+            return torch.optim.SGD(self.model.parameters(), **self._override_optim_state)
         else:
             raise ValueError('Unknown optimizer: {}'.format(self.args.optimizer))
 
     def _build_lr_scheduler(self):
-        if self.args.force_anneal > 0:
+        if len(self.args.lr) > 1 or self.args.force_anneal > 0:
+            lrs = self.args.lr
             def anneal(e):
                 if e < self.args.force_anneal:
-                    return 1
+                    # use fixed LR schedule
+                    next_lr = lrs[min(e, len(lrs) - 1)]
                 else:
-                    return self.args.lrshrink ** (e + 1 - self.args.force_anneal)
+                    next_lr = lrs[-1] * self.args.lrshrink ** (e + 1 - self.args.force_anneal)
+                return next_lr / lrs[0]  # correct for scaling from LambdaLR
             lr_scheduler = LambdaLR(self.optimizer, anneal)
             lr_scheduler.best = None
         else:
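`LambdaLR` multiplies the optimizer's base learning rate (here `lrs[0]`) by whatever the lambda returns, hence the final division by `lrs[0]`. A standalone recomputation of the resulting schedule with illustrative settings (assuming something like `--lr 0.25,0.1,0.05 --force-anneal 5 --lrshrink 0.5`):

```python
lrs = [0.25, 0.1, 0.05]   # parsed from --lr 0.25,0.1,0.05
force_anneal = 5
lrshrink = 0.5

def anneal(e):
    if e < force_anneal:
        # fixed schedule: epoch 0 -> 0.25, epoch 1 -> 0.1, epochs >= 2 -> 0.05
        next_lr = lrs[min(e, len(lrs) - 1)]
    else:
        # exponential decay of the last listed LR after force_anneal epochs
        next_lr = lrs[-1] * lrshrink ** (e + 1 - force_anneal)
    return next_lr / lrs[0]   # LambdaLR multiplies the base LR by this factor

# effective LR per epoch = base LR (lrs[0]) * anneal(epoch)
for epoch in range(8):
    print(epoch, lrs[0] * anneal(epoch))
# 0 -> 0.25, 1 -> 0.1, 2..4 -> 0.05, then 0.025, 0.0125, 0.00625
```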
@@ -134,9 +157,24 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         return extra_state
 
     def _async_load_checkpoint(self, rank, device_id, filename):
-        extra_state, self._optim_history = utils.load_state(
-            filename, self.model, self.criterion, self.optimizer,
-            self.lr_scheduler, cuda_device=device_id)
+        extra_state, self._optim_history, last_optim_state = utils.load_model_state(
+            filename, self.model, cuda_device=device_id)
+
+        if last_optim_state is not None:
+            # rebuild optimizer after loading model, since params may have changed
+            self.optimizer = self._build_optimizer()
+            self.lr_scheduler = self._build_lr_scheduler()
+
+            # only load optimizer and lr_scheduler if they match the checkpoint
+            last_optim = self._optim_history[-1]
+            if last_optim['criterion_name'] == self.criterion.__class__.__name__:
+                self.optimizer.load_state_dict(last_optim_state)
+                self.lr_scheduler.best = last_optim['best_loss']
+
+            # override learning rate, momentum, etc. with latest values
+            for group in self.optimizer.param_groups:
+                group.update(self._override_optim_state)
+
         return extra_state
 
     def set_seed(self, seed):
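Loading `last_optim_state` also restores the hyper-parameters that were in effect when the checkpoint was written (learning rate, momentum, weight decay), so each param group is updated with `_override_optim_state` afterwards to put the values from the current command line back. A minimal sketch of that resume pattern outside the trainer; the numbers are made up:

```python
import torch

model = torch.nn.Linear(4, 4)

# values the user passed on the command line for this run (assumed)
override_state = {'lr': 0.1, 'momentum': 0.99, 'weight_decay': 0.0}
optimizer = torch.optim.SGD(model.parameters(), **override_state)

# pretend this came from a checkpoint saved with lr=0.25
old_state = torch.optim.SGD(model.parameters(), lr=0.25, momentum=0.9).state_dict()

optimizer.load_state_dict(old_state)       # restores momentum buffers, but also lr=0.25
for group in optimizer.param_groups:
    group.update(override_state)           # put the command-line values back
print(optimizer.param_groups[0]['lr'])     # 0.1
```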
@@ -161,14 +199,14 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)
 
         # forward pass
-        sample_sizes, logging_outputs = Future.gen_tuple_list([
+        sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
             self.call_async(rank, '_async_forward')
             for rank in range(self.num_replicas)
         ])
 
         # backward pass, all-reduce gradients and take an optimization step
         grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
-        grad_norms = Future.gen_list([
+        grad_norms, ooms_bwd = Future.gen_tuple_list([
             self.call_async(rank, '_async_backward_and_opt', grad_denom=grad_denom)
             for rank in range(self.num_replicas)
         ])
@@ -176,6 +214,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # aggregate logging output
         logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
         logging_output['gnorm'] = grad_norms[0]  # log the gradient norm
+        logging_output['oom'] = sum(ooms_fwd) + sum(ooms_bwd)
 
         return logging_output
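Each replica now returns a `(sample_size, logging_output, oom)` triple, so the futures are gathered with `gen_tuple_list` instead of `gen_list`. Conceptually this is just transposing per-rank tuples into per-field lists, which plain `zip` illustrates (a sketch of the idea, not the fairseq implementation):

```python
# per-rank results of _async_forward: (sample_size, logging_output, oom)
results = [
    (3200, {'loss': 2.1}, False),
    (3100, {'loss': 2.3}, True),
]

sample_sizes, logging_outputs, ooms_fwd = map(list, zip(*results))
print(sample_sizes)   # [3200, 3100]
print(sum(ooms_fwd))  # 1 -> reported as logging_output['oom']
```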
@@ -186,34 +225,44 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self.model.train()
         self.optimizer.zero_grad()
 
-        if self._sample is None:
-            return 0, {}
-
-        # calculate loss and sample size
-        self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
+        sample_size, logging_output, oom = 0, {}, False
+        if self._sample is not None:
+            try:
+                # calculate loss and sample size
+                self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
+            except RuntimeError as e:
+                if not eval and 'out of memory' in str(e):
+                    print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
+                    oom = True
+                    self.loss = None
+                    if hasattr(torch.cuda, 'empty_cache'):
+                        torch.cuda.empty_cache()
+                else:
+                    raise e
 
-        return sample_size, logging_output
+        return sample_size, logging_output, oom
 
     def _async_backward_and_opt(self, rank, device_id, grad_denom):
+        oom = False
         if self.loss is not None:
-            # backward pass
-            self.loss.backward()
-
-            # get model parameters as a flattened (contiguous) tensor
-            flat_grads = self._flat_model_grads()
-
-            # all-reduce grads
-            nccl.all_reduce(flat_grads)
-
-            # normalize grads
-            if grad_denom != 0:
-                flat_grads.div_(grad_denom)
+            try:
+                # backward pass
+                self.loss.backward()
+            except RuntimeError as e:
+                if 'out of memory' in str(e):
+                    print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
+                    oom = True
+                    if hasattr(torch.cuda, 'empty_cache'):
+                        torch.cuda.empty_cache()
+                    self.optimizer.zero_grad()
+                else:
+                    raise e
+
+            # all-reduce grads and rescale by grad_denom
+            self._all_reduce_and_rescale_grads(grad_denom)
 
             # clip grads
-            grad_norm = self._clip_grads_(flat_grads, self.args.clip_norm)
-
-            # copy reduced grads back
-            self._set_model_grads_(flat_grads)
+            grad_norm = torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm)
 
             # take an optimization step
             self.optimizer.step()
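Both the forward and the backward pass now trap CUDA out-of-memory errors, report them via the `oom` flag, and free cached blocks instead of killing the worker. The core of that pattern, reduced to a standalone helper with illustrative names:

```python
import torch


def forward_backward_or_skip(model, criterion, sample, optimizer):
    """Run one training step; return None if the batch did not fit in memory."""
    optimizer.zero_grad()
    try:
        loss = criterion(model(sample['input']), sample['target'])
        loss.backward()
        return loss.item()
    except RuntimeError as e:
        if 'out of memory' in str(e):
            print('| WARNING: ran out of memory, skipping batch')
            optimizer.zero_grad()           # drop any partial gradients
            if hasattr(torch.cuda, 'empty_cache'):
                torch.cuda.empty_cache()    # return cached blocks to the driver
            return None
        raise
```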
@@ -221,41 +270,49 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             # reset loss
             self.loss = None
 
-        return grad_norm
+        return grad_norm, oom
 
-    def _model_grads(self):
-        return [p.grad for p in self.model.parameters() if p.requires_grad]
-
-    def _flat_model_grads(self):
-        grads = self._model_grads()
-        if not hasattr(self, '_flat_grads'):
-            num_params = sum(g.data.numel() for g in grads)
-            self._flat_grads = grads[0].data.new(num_params)
-        offset = 0
-        for grad in grads:
-            grad = grad.data.view(-1)
-            numel = grad.numel()
-            self._flat_grads[offset:offset+numel].copy_(grad)
-            offset += numel
-        return self._flat_grads
-
-    def _set_model_grads_(self, flat_grads):
-        grads = self._model_grads()
-        offset = 0
-        for grad in grads:
-            grad = grad.data.view(-1)
-            numel = grad.numel()
-            grad.copy_(flat_grads[offset:offset+numel])
-            offset += numel
-        assert offset == flat_grads.numel()
-
-    def _clip_grads_(self, flat_grads, clipv):
-        """nn.utils.clip_grad_norm for flattened (contiguous) tensors."""
-        norm = flat_grads.norm()
-        if clipv > 0 and norm > clipv:
-            coef = max(norm, 1e-6) / clipv
-            flat_grads.div_(coef)
-        return norm
+    def _all_reduce_and_rescale_grads(self, grad_denom, buffer_size=10485760):
+        """All-reduce and rescale gradients in chunks of the specified size."""
+        grads = [p.grad.data for p in self.model.parameters() if p.requires_grad]
+        buffer_t = grads[0].new(math.ceil(buffer_size / grads[0].element_size())).zero_()
+        buffer = []
+
+        def all_reduce_buffer():
+            # copy grads into buffer_t
+            offset = 0
+            for g in buffer:
+                numel = g.numel()
+                buffer_t[offset:offset+numel].copy_(g.view(-1))
+                offset += numel
+            # all-reduce and rescale
+            nccl.all_reduce(buffer_t[:offset])
+            buffer_t.div_(grad_denom)
+            # copy all-reduced buffer back into grads
+            offset = 0
+            for g in buffer:
+                numel = g.numel()
+                g.view(-1).copy_(buffer_t[offset:offset+numel])
+                offset += numel
+
+        filled = 0
+        for g in grads:
+            sz = g.numel() * g.element_size()
+            if sz > buffer_size:
+                # grad is bigger than buffer, all-reduce and rescale directly
+                nccl.all_reduce(g)
+                g.div_(grad_denom)
+            elif filled + sz > buffer_size:
+                # buffer is full, all-reduce and replace buffer with grad
+                all_reduce_buffer()
+                buffer = [g]
+                filled = sz
+            else:
+                # add grad to buffer
+                buffer.append(g)
+                filled += sz
+        if len(buffer) > 0:
+            all_reduce_buffer()
 
     def valid_step(self, samples):
         """Do forward pass in parallel."""
@@ -263,10 +320,11 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self._scatter_samples(samples, volatile=True)
 
         # forward pass
-        _sample_sizes, logging_outputs = Future.gen_tuple_list([
+        _sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
             self.call_async(rank, '_async_forward', eval=True)
             for rank in range(self.num_replicas)
         ])
+        assert sum(ooms_fwd) == 0
 
         # aggregate logging output
         logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
@@ -314,4 +372,10 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         if sample is None:
             self._sample = None
         else:
+            if hasattr(torch.cuda, 'empty_cache'):
+                # clear the caching allocator if this is the largest sample we've seen
+                if sample['target'].size(0) > self._max_bsz_seen:
+                    self._max_bsz_seen = sample['target'].size(0)
+                    torch.cuda.empty_cache()
             self._sample = utils.prepare_sample(sample, volatile=volatile, cuda_device=device_id)
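Calling `torch.cuda.empty_cache()` on every batch would be slow, so the trainer only clears the caching allocator when a batch larger than anything seen so far arrives, i.e. right before the allocation most likely to fragment or fail. The bookkeeping is just a running maximum (sketch):

```python
import torch

_max_bsz_seen = 0


def maybe_empty_cache(batch_size):
    """Free cached CUDA blocks the first time a batch of this size shows up."""
    global _max_bsz_seen
    if batch_size > _max_bsz_seen and hasattr(torch.cuda, 'empty_cache'):
        _max_bsz_seen = batch_size
        torch.cuda.empty_cache()
```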
@@ -49,8 +49,8 @@ def add_optimization_args(parser):
     group.add_argument('--optimizer', default='nag', metavar='OPT',
                        choices=MultiprocessingTrainer.OPTIMIZERS,
                        help='optimizer ({})'.format(', '.join(MultiprocessingTrainer.OPTIMIZERS)))
-    group.add_argument('--lr', '--learning-rate', default=0.25, type=float, metavar='LR',
-                       help='initial learning rate')
+    group.add_argument('--lr', '--learning-rate', default='0.25', metavar='LR1,LR2,...,LRn',
+                       help='learning rate for the first n epochs with all epochs >n using LRn')
     group.add_argument('--min-lr', metavar='LR', default=1e-5, type=float,
                        help='minimum learning rate')
     group.add_argument('--force-anneal', '--fa', default=0, type=int, metavar='N',
...
@@ -83,9 +83,9 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler, optim_
     torch_persistent_save(state_dict, filename)
 
-def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=None):
+def load_model_state(filename, model, cuda_device=None):
     if not os.path.exists(filename):
-        return None, []
+        return None, [], None
     if cuda_device is None:
         state = torch.load(filename)
     else:
@@ -94,18 +94,16 @@ def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=
             map_location=lambda s, l: default_restore_location(s, 'cuda:{}'.format(cuda_device))
         )
     state = _upgrade_state_dict(state)
+    state['model'] = model.upgrade_state_dict(state['model'])
 
     # load model parameters
-    model.load_state_dict(state['model'])
-
-    # only load optimizer and lr_scheduler if they match with the checkpoint
-    optim_history = state['optimizer_history']
-    last_optim = optim_history[-1]
-    if last_optim['criterion_name'] == criterion.__class__.__name__:
-        optimizer.load_state_dict(state['last_optimizer_state'])
-        lr_scheduler.best = last_optim['best_loss']
+    try:
+        model.load_state_dict(state['model'])
+    except:
+        raise Exception('Cannot load model parameters from checkpoint, '
+                        'please ensure that the architectures match')
 
-    return state['extra_state'], optim_history
+    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']
 
 def _upgrade_state_dict(state):
@@ -164,6 +162,7 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_di
     ensemble = []
     for state in states:
         model = build_model(args, src_dict, dst_dict)
+        state['model'] = model.upgrade_state_dict(state['model'])
         model.load_state_dict(state['model'])
         ensemble.append(model)
     return ensemble, args
...
@@ -53,18 +53,18 @@ def main():
     # record inferred languages in args, so that it's saved in checkpoints
     args.source_lang, args.target_lang = dataset.src, dataset.dst
 
+    if not torch.cuda.is_available():
+        raise NotImplementedError('Training on CPU is not supported')
+    args.num_gpus = torch.cuda.device_count()
+
     print(args)
     print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
     print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
     for split in splits:
         print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))
 
-    if not torch.cuda.is_available():
-        raise NotImplementedError('Training on CPU is not supported')
-    num_gpus = torch.cuda.device_count()
-
     print('| using {} GPUs (with max tokens per GPU = {} and max sentences per GPU = {})'.format(
-        num_gpus, args.max_tokens, args.max_sentences))
+        args.num_gpus, args.max_tokens, args.max_sentences))
 
     # Build model and criterion
     model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
@@ -102,11 +102,11 @@ def main():
     train_meter.start()
     while lr > args.min_lr and epoch <= max_epoch:
         # train for one epoch
-        train(args, epoch, batch_offset, trainer, dataset, max_positions_train, num_gpus)
+        train(args, epoch, batch_offset, trainer, dataset, max_positions_train)
 
         # evaluate on validate set
         for k, subset in enumerate(args.valid_subset.split(',')):
-            val_loss = validate(args, epoch, trainer, dataset, max_positions_valid, subset, num_gpus)
+            val_loss = validate(args, epoch, trainer, dataset, max_positions_valid, subset)
             if k == 0:
                 if not args.no_save:
                     # save checkpoint
@@ -130,7 +130,7 @@ def get_perplexity(loss):
         return float('inf')
 
-def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
+def train(args, epoch, batch_offset, trainer, dataset, max_positions):
     """Train the model for one epoch."""
     seed = args.seed + epoch
@@ -152,7 +152,7 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
     lr = trainer.get_lr()
     with utils.build_progress_bar(args, itr, epoch) as t:
-        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
+        for i, sample in data.skip_group_enumerator(t, args.num_gpus, batch_offset):
             loss_dict = trainer.train_step(sample)
             loss = loss_dict['loss']
             del loss_dict['loss']  # don't include in extra_meters or extra_postfix
@@ -222,7 +222,7 @@ def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
     trainer.save_checkpoint(last_filename, extra_state)
 
-def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
+def validate(args, epoch, trainer, dataset, max_positions, subset):
     """Evaluate the model on the validation set and return the average loss."""
     itr = dataset.eval_dataloader(
@@ -236,7 +236,7 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
     prefix = 'valid on \'{}\' subset'.format(subset)
     with utils.build_progress_bar(args, itr, epoch, prefix) as t:
-        for _, sample in data.skip_group_enumerator(t, ngpus):
+        for _, sample in data.skip_group_enumerator(t, args.num_gpus):
             loss_dict = trainer.valid_step(sample)
             loss = loss_dict['loss']
             del loss_dict['loss']  # don't include in extra_meters or extra_postfix
...