Commit ff68a9ef authored by Myle Ott

Add FairseqTask

A Task defines the data format, stores shared state (e.g., dictionaries), and provides helpers for building the model/criterion and calculating the loss.
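
In the new flow, the rest of the code goes through the task object. A minimal usage sketch assembled from the calls visible in this commit (`args` is the parsed command-line namespace; nothing else is assumed):

    from fairseq import tasks

    task = tasks.setup_task(args)           # instantiate the task selected by args, e.g. translation
    task.load_dataset('train')              # the task owns data loading and the data format
    model = task.build_model(args)          # replaces models.build_model(args, src_dict, dst_dict)
    criterion = task.build_criterion(args)  # replaces criterions.build_criterion(args, ...)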

Changes:
- Add TranslationTask and LanguageModelingTask. New tasks can be registered with the @register_task decorator (see the registration sketch below).
- Add EpochBatchIterator to encapsulate batching and saving/restoring the dataloader position (see the iterator sketch below).
- Remove the LEFT_PAD_* constants and make them configurable per task.
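
A minimal sketch of registering a custom task. The task name, dictionary filenames, and the `load_dataset` body are hypothetical; `register_task` is assumed importable from `fairseq.tasks` alongside `FairseqTask`, and the overridden hooks mirror the ones this commit adds to the test task:

    import os

    from fairseq.data import Dictionary
    from fairseq.tasks import FairseqTask, register_task

    @register_task('toy_translation')  # hypothetical task name
    class ToyTranslationTask(FairseqTask):

        def __init__(self, args, src_dict, tgt_dict):
            super().__init__(args)
            self.src_dict = src_dict
            self.tgt_dict = tgt_dict

        @classmethod
        def setup_task(cls, args, **kwargs):
            # shared state (dictionaries) is loaded once, up front
            src_dict = Dictionary.load(os.path.join(args.data, 'dict.src.txt'))  # hypothetical path
            tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.tgt.txt'))  # hypothetical path
            return cls(args, src_dict, tgt_dict)

        def load_dataset(self, split):
            ...  # load `split` and make it available via self.dataset(split)

        @property
        def source_dictionary(self):
            return self.src_dict

        @property
        def target_dictionary(self):
            return self.tgt_dict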
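
And a sketch of the mid-epoch save/restore behavior that EpochBatchIterator adds, using only calls and state-dict keys that appear in this diff (`train_dataset` and the `max_tokens` value are placeholders):

    from fairseq import data

    epoch_itr = data.EpochBatchIterator(
        dataset=train_dataset,  # placeholder: any dataset usable by the task
        max_tokens=4096,
    )

    itr = epoch_itr.next_epoch_itr(shuffle=False)
    next(itr)  # consume one batch; epoch_itr.iterations_in_epoch is now 1

    # checkpointing stores the iterator position alongside the model...
    state = epoch_itr.state_dict()  # e.g. {'epoch': 1, 'iterations_in_epoch': 1, 'shuffle': False}

    # ...and restoring it replays the iterator to the same position
    epoch_itr.load_state_dict(state)
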
parent 2de93532
@@ -46,12 +46,13 @@ class TestLabelSmoothing(unittest.TestCase):
             [0.05, 0.10, 0.2, 0.05, 0.2, 0.3, 0.10],
             [0.05, 0.15, 0.3, 0.05, 0.1, 0.2, 0.15],
         ]).unsqueeze(0).expand(2, 3, 7)  # add batch dimension
-        self.model = test_utils.TestModel.build_model(self.args, self.d, self.d)
+        self.task = test_utils.TestTranslationTask.setup_task(self.args, self.d, self.d)
+        self.model = self.task.build_model(self.args)

     def test_nll_loss(self):
         self.args.label_smoothing = 0.1
-        nll_crit = CrossEntropyCriterion(self.args, self.d, self.d)
-        smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
+        nll_crit = CrossEntropyCriterion(self.args, self.task)
+        smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
         nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
         smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
         self.assertLess(abs(nll_loss - nll_logging_output['loss']), 1e-6)
@@ -59,7 +60,7 @@ class TestLabelSmoothing(unittest.TestCase):
     def test_padding(self):
         self.args.label_smoothing = 0.1
-        crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
+        crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
         loss, _, logging_output = crit(self.model, self.sample)

         def get_one_no_padding(idx):
@@ -68,7 +69,7 @@ class TestLabelSmoothing(unittest.TestCase):
             sample1 = next(test_utils.dummy_dataloader([self.data[idx]]))
             args1 = copy.copy(self.args)
             args1.probs = args1.probs[idx, :, :].unsqueeze(0)
-            model1 = test_utils.TestModel.build_model(args1, self.d, self.d)
+            model1 = self.task.build_model(args1)
             loss1, _, _ = crit(model1, sample1)
             return loss1
@@ -78,15 +79,15 @@ class TestLabelSmoothing(unittest.TestCase):
     def test_reduction(self):
         self.args.label_smoothing = 0.1
-        crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
+        crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
         loss, _, logging_output = crit(self.model, self.sample, reduce=True)
         unreduced_loss, _, _ = crit(self.model, self.sample, reduce=False)
         self.assertAlmostEqual(loss, unreduced_loss.sum())

     def test_zero_eps(self):
         self.args.label_smoothing = 0.0
-        nll_crit = CrossEntropyCriterion(self.args, self.d, self.d)
-        smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
+        nll_crit = CrossEntropyCriterion(self.args, self.task)
+        smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
         nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
         smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
         self.assertAlmostEqual(nll_loss, smooth_loss)
...
@@ -80,10 +80,12 @@ class TestSequenceGenerator(unittest.TestCase):
             ]),
         ]
-        self.model = test_utils.TestModel.build_model(args, d, d)
+        task = test_utils.TestTranslationTask.setup_task(args, d, d)
+        self.model = task.build_model(args)
+        self.tgt_dict = task.target_dictionary

     def test_with_normalization(self):
-        generator = SequenceGenerator([self.model])
+        generator = SequenceGenerator([self.model], self.tgt_dict)
         hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
@@ -102,7 +104,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_without_normalization(self):
         # Sentence 1: unchanged from the normalized case
         # Sentence 2: beams swap order
-        generator = SequenceGenerator([self.model], normalize_scores=False)
+        generator = SequenceGenerator([self.model], self.tgt_dict, normalize_scores=False)
         hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
@@ -120,7 +122,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_with_lenpen_favoring_short_hypos(self):
         lenpen = 0.6
-        generator = SequenceGenerator([self.model], len_penalty=lenpen)
+        generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
         hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
@@ -138,7 +140,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_with_lenpen_favoring_long_hypos(self):
         lenpen = 5.0
-        generator = SequenceGenerator([self.model], len_penalty=lenpen)
+        generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
         hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
@@ -155,7 +157,7 @@ class TestSequenceGenerator(unittest.TestCase):
         self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6], lenpen=lenpen)

     def test_maxlen(self):
-        generator = SequenceGenerator([self.model], maxlen=2)
+        generator = SequenceGenerator([self.model], self.tgt_dict, maxlen=2)
         hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
@@ -172,7 +174,7 @@ class TestSequenceGenerator(unittest.TestCase):
         self.assertHypoScore(hypos[1][1], [0.3, 0.9, 0.01])

     def test_no_stop_early(self):
-        generator = SequenceGenerator([self.model], stop_early=False)
+        generator = SequenceGenerator([self.model], self.tgt_dict, stop_early=False)
         hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
...
@@ -83,8 +83,9 @@ class TestSequenceScorer(unittest.TestCase):
            [0.3, 0.7],  # sentence 3
        ]
-        model = test_utils.TestModel.build_model(args, d, d)
-        scorer = SequenceScorer([model])
+        task = test_utils.TestTranslationTask.setup_task(args, d, d)
+        model = task.build_model(args)
+        scorer = SequenceScorer([model], task.target_dictionary)
         for id, _src, _ref, hypos in scorer.score_batched_itr(data_itr):
             self.assertHypoTokens(hypos[0], data[id]['target'])
             self.assertHypoScore(hypos[0], expected_scores[id])
...
@@ -8,23 +8,45 @@
 import contextlib
 from io import StringIO
 import unittest
 from unittest.mock import MagicMock, patch

+import torch
+
+from fairseq import data
+
 import train


-def mock_trainer(epoch, num_updates, end_of_epoch):
+def mock_trainer(epoch, num_updates, iterations_in_epoch):
     trainer = MagicMock()
-    trainer.load_checkpoint.return_value = {'epoch': epoch, 'end_of_epoch': end_of_epoch}
+    trainer.load_checkpoint.return_value = {
+        'train_iterator': {
+            'epoch': epoch,
+            'iterations_in_epoch': iterations_in_epoch,
+            'shuffle': False,
+        },
+    }
     trainer.get_num_updates.return_value = num_updates
     return trainer


-def mock_loader(length):
-    loader = MagicMock()
-    loader.__next__.return_value = list(range(length))
-    return loader
+def mock_dict():
+    d = MagicMock()
+    d.pad.return_value = 1
+    d.eos.return_value = 2
+    d.unk.return_value = 3
+    return d
+
+
+def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
+    tokens = torch.LongTensor(list(range(epoch_size)))
+    tokens_ds = data.TokenBlockDataset(tokens, [len(tokens)], 1, include_targets=False)
+    trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
+    epoch_itr = data.EpochBatchIterator(
+        dataset=data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False),
+        max_tokens=1,
+    )
+    return trainer, epoch_itr


 class TestLoadCheckpoint(unittest.TestCase):
@@ -40,29 +62,41 @@ class TestLoadCheckpoint(unittest.TestCase):
     def test_load_partial_checkpoint(self):
         with contextlib.redirect_stdout(StringIO()):
-            trainer = mock_trainer(2, 200, False)
-            loader = mock_loader(150)
-            epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
-            self.assertEqual(epoch, 2)
-            self.assertEqual(next(ds), 50)
+            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
+            train.load_checkpoint(MagicMock(), trainer, epoch_itr)
+            self.assertEqual(epoch_itr.epoch, 2)
+            self.assertEqual(epoch_itr.iterations_in_epoch, 50)
+
+            itr = epoch_itr.next_epoch_itr(shuffle=False)
+            self.assertEqual(epoch_itr.epoch, 2)
+            self.assertEqual(epoch_itr.iterations_in_epoch, 50)
+            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 50)
+            self.assertEqual(epoch_itr.iterations_in_epoch, 51)

     def test_load_full_checkpoint(self):
         with contextlib.redirect_stdout(StringIO()):
-            trainer = mock_trainer(2, 300, True)
-            loader = mock_loader(150)
-            epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
-            self.assertEqual(epoch, 3)
-            self.assertEqual(next(iter(ds)), 0)
+            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
+            train.load_checkpoint(MagicMock(), trainer, epoch_itr)
+            itr = epoch_itr.next_epoch_itr(shuffle=False)
+
+            self.assertEqual(epoch_itr.epoch, 3)
+            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
+            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)

     def test_load_no_checkpoint(self):
         with contextlib.redirect_stdout(StringIO()):
-            trainer = mock_trainer(0, 0, False)
-            loader = mock_loader(150)
+            trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
             self.patches['os.path.isfile'].return_value = False
-            epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
-            self.assertEqual(epoch, 1)
-            self.assertEqual(next(iter(ds)), 0)
+            train.load_checkpoint(MagicMock(), trainer, epoch_itr)
+            itr = epoch_itr.next_epoch_itr(shuffle=False)
+
+            self.assertEqual(epoch_itr.epoch, 1)
+            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
+            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)

     def tearDown(self):
         patch.stopall()
...
@@ -8,18 +8,19 @@
 import torch
 from torch.autograd import Variable

-from fairseq.data.language_pair_dataset import collate
 from fairseq import utils
-from fairseq.data import dictionary
+from fairseq.data import Dictionary
+from fairseq.data.language_pair_dataset import collate
 from fairseq.models import (
     FairseqEncoder,
     FairseqIncrementalDecoder,
     FairseqModel,
 )
+from fairseq.tasks import FairseqTask


 def dummy_dictionary(vocab_size, prefix='token_'):
-    d = dictionary.Dictionary()
+    d = Dictionary()
     for i in range(vocab_size):
         token = prefix + str(i)
         d.add_symbol(token)
@@ -46,14 +47,7 @@ def dummy_dataloader(
     dataloader = torch.utils.data.DataLoader(
         dataset,
         batch_size=batch_size,
-        collate_fn=(
-            lambda samples: collate(
-                samples,
-                padding_idx,
-                eos_idx,
-                has_target=True,
-            )
-        ),
+        collate_fn=(lambda samples: collate(samples, padding_idx, eos_idx)),
     )
     return iter(dataloader)
@@ -71,14 +65,38 @@ class TestDataset(torch.utils.data.Dataset):
         return len(self.data)


+class TestTranslationTask(FairseqTask):
+
+    def __init__(self, args, src_dict, tgt_dict, model):
+        super().__init__(args)
+        self.src_dict = src_dict
+        self.tgt_dict = tgt_dict
+        self.model = model
+
+    @classmethod
+    def setup_task(cls, args, src_dict=None, tgt_dict=None, model=None):
+        return cls(args, src_dict, tgt_dict, model)
+
+    def build_model(self, args):
+        return TestModel.build_model(args, self)
+
+    @property
+    def source_dictionary(self):
+        return self.src_dict
+
+    @property
+    def target_dictionary(self):
+        return self.tgt_dict
+
+
 class TestModel(FairseqModel):
     def __init__(self, encoder, decoder):
         super().__init__(encoder, decoder)

     @classmethod
-    def build_model(cls, args, src_dict, dst_dict):
-        encoder = TestEncoder(args, src_dict)
-        decoder = TestIncrementalDecoder(args, dst_dict)
+    def build_model(cls, args, task):
+        encoder = TestEncoder(args, task.source_dictionary)
+        decoder = TestIncrementalDecoder(args, task.target_dictionary)
         return cls(encoder, decoder)
...
@@ -7,14 +7,12 @@
 # can be found in the PATENTS file in the same directory.

 import collections
+import itertools
 import os
 import math
 import torch

-from itertools import islice
-
-from fairseq import criterions, models, options, progress_bar, utils
-from fairseq.data import data_utils, data_loaders
+from fairseq import data, distributed_utils, options, progress_bar, tasks, utils
 from fairseq.fp16_trainer import FP16Trainer
 from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
@@ -23,7 +21,6 @@ from fairseq.meters import AverageMeter, StopwatchMeter
 def main(args):
     if args.max_tokens is None:
         args.max_tokens = 6000
     print(args)
     if not torch.cuda.is_available():
@@ -31,27 +28,25 @@ def main(args):
     torch.cuda.set_device(args.device_id)
     torch.manual_seed(args.seed)

-    # Load dataset
-    splits = ['train', 'valid']
-    dataset = load_dataset(args, splits)
-    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
-    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
-    for split in splits:
-        print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))
+    # Setup task, e.g., translation, language modeling, etc.
+    task = tasks.setup_task(args)

-    model = models.build_model(args, dataset.src_dict, dataset.dst_dict)
-    criterion = criterions.build_criterion(args, dataset.src_dict, dataset.dst_dict)
+    # Load dataset splits
+    load_dataset_splits(args, task, ['train', 'valid'])
+
+    # Build model and criterion
+    model = task.build_model(args)
+    criterion = task.build_criterion(args)
     print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
-    print('| num. model params: {}'.format(sum(p.data.numel() for p in model.parameters())))
+    print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters())))

     # Build trainer
     if args.fp16:
-        trainer = FP16Trainer(args, model, criterion)
+        trainer = FP16Trainer(args, task, model, criterion)
     else:
         if torch.cuda.get_device_capability(0)[0] >= 7:
             print('| NOTICE: your device may support faster training with --fp16')
-        trainer = Trainer(args, model, criterion)
+        trainer = Trainer(args, task, model, criterion)

     print('| training on {} GPUs'.format(args.distributed_world_size))
     print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
         args.max_tokens,
@@ -59,25 +54,24 @@ def main(args):
     ))

     # Initialize dataloader
-    train_dataloader = dataset.train_dataloader_generator(
-        args.train_subset,
+    max_positions = trainer.get_model().max_positions()
+    epoch_itr = data.EpochBatchIterator(
+        dataset=task.dataset(args.train_subset),
         max_tokens=args.max_tokens,
-        max_sentences=args.max_sentences,
-        max_positions=(
-            min(args.max_source_positions, trainer.get_model().max_encoder_positions()),
-            min(args.max_target_positions, trainer.get_model().max_decoder_positions())
-        ),
+        max_sentences=args.max_sentences_valid,
+        max_positions=max_positions,
+        ignore_invalid_inputs=True,
+        required_batch_size_multiple=8,
         seed=args.seed,
-        sample_without_replacement=args.sample_without_replacement,
-        shard_id=args.distributed_rank,
         num_shards=args.distributed_world_size,
+        shard_id=args.distributed_rank,
     )

     # Load the latest checkpoint if one is available
-    epoch, next_ds = load_checkpoint(args, trainer, train_dataloader)
+    load_checkpoint(args, trainer, epoch_itr)

     # Send a dummy batch to warm the caching allocator
-    dummy_batch = data_utils.get_dummy_batch(args.max_tokens, dataset.src_dict, dataset.dst_dict)
+    dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
     trainer.dummy_train_step(dummy_batch)

     # Train until the learning rate gets too small
@@ -88,58 +82,41 @@ def main(args):
     train_meter.start()
     valid_losses = [None]
     valid_subsets = args.valid_subset.split(',')
-    while lr > args.min_lr and epoch <= max_epoch and trainer.get_num_updates() < max_update:
+    while lr > args.min_lr and epoch_itr.epoch <= max_epoch and trainer.get_num_updates() < max_update:
         # train for one epoch
-        train(args, trainer, next_ds, epoch, dataset)
+        train(args, trainer, task, epoch_itr)

-        if epoch % args.validate_interval == 0:
-            valid_losses = validate(args, trainer, dataset, valid_subsets, epoch)
+        if epoch_itr.epoch % args.validate_interval == 0:
+            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

         # only use first validation loss to update the learning rate
-        lr = trainer.lr_step(epoch, valid_losses[0])
+        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

         # save checkpoint
-        if epoch % args.save_interval == 0:
-            save_checkpoint(args, trainer, epoch, end_of_epoch=True, val_loss=valid_losses[0])
-
-        epoch += 1
-        next_ds = next(train_dataloader)
+        if epoch_itr.epoch % args.save_interval == 0:
+            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
     train_meter.stop()
     print('| done training in {:.1f} seconds'.format(train_meter.sum))


-def load_dataset(args, splits):
-    is_raw = not data_utils.has_binary_files(args.data, splits)
-    dataset = data_loaders.load_dataset(args, splits, is_raw)
-    return dataset
-
-
-def train(args, trainer, itr, epoch, dataset):
+def train(args, trainer, task, epoch_itr):
     """Train the model for one epoch."""
-    # Set seed based on args.seed and the epoch number so that we get
-    # reproducible results when resuming from checkpoints
-    seed = args.seed + epoch
-    torch.manual_seed(seed)
-
-    # reset training meters
-    for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip']:
-        meter = trainer.get_meter(k)
-        if meter is not None:
-            meter.reset()
+    # Initialize data iterator
+    itr = epoch_itr.next_epoch_itr()
+    progress = progress_bar.build_progress_bar(args, itr, epoch_itr.epoch, no_progress_bar='simple')

     # update parameters every N batches
-    if epoch <= len(args.update_freq):
-        update_freq = args.update_freq[epoch - 1]
+    if epoch_itr.epoch <= len(args.update_freq):
+        update_freq = args.update_freq[epoch_itr.epoch - 1]
     else:
         update_freq = args.update_freq[-1]

     extra_meters = collections.defaultdict(lambda: AverageMeter())
     first_valid = args.valid_subset.split(',')[0]
     max_update = args.max_update or math.inf
-    num_batches = len(itr)
-    progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple')
-    for i, sample in enumerate(progress):
+    num_batches = len(epoch_itr)
+    for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch):
         if i < num_batches - 1 and (i + 1) % update_freq > 0:
             # buffer updates according to --update-freq
             trainer.train_step(sample, update_params=False)
@@ -165,8 +142,8 @@ def train(args, trainer, itr, epoch, dataset):
         num_updates = trainer.get_num_updates()
         if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0:
-            valid_losses = validate(args, trainer, dataset, [first_valid], epoch)
-            save_checkpoint(args, trainer, epoch, end_of_epoch=False, val_loss=valid_losses[0])
+            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
+            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

         if num_updates >= max_update:
             break
@@ -177,6 +154,12 @@
         stats[k] = meter.avg
     progress.print(stats)

+    # reset training meters
+    for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip']:
+        meter = trainer.get_meter(k)
+        if meter is not None:
+            meter.reset()
+

 def get_training_stats(trainer):
     stats = collections.OrderedDict()
@@ -202,27 +185,24 @@ def get_training_stats(trainer):
     return stats


-def validate(args, trainer, dataset, subsets, epoch):
+def validate(args, trainer, task, epoch_itr, subsets):
     """Evaluate the model on the validation set(s) and return the losses."""
     valid_losses = []
     for subset in subsets:
-        # Initialize dataloader
-        max_positions_valid = (
-            trainer.get_model().max_encoder_positions(),
-            trainer.get_model().max_decoder_positions(),
-        )
-        itr = dataset.eval_dataloader(
-            subset,
+        # Initialize data iterator
+        itr = data.EpochBatchIterator(
+            dataset=task.dataset(subset),
             max_tokens=args.max_tokens,
             max_sentences=args.max_sentences_valid,
-            max_positions=max_positions_valid,
-            skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
-            descending=True,  # largest batch first to warm the caching allocator
-            shard_id=args.distributed_rank,
+            max_positions=trainer.get_model().max_positions(),
+            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
+            required_batch_size_multiple=8,
+            seed=args.seed,
             num_shards=args.distributed_world_size,
-        )
+            shard_id=args.distributed_rank,
+        ).next_epoch_itr(shuffle=False)
         progress = progress_bar.build_progress_bar(
-            args, itr, epoch,
+            args, itr, epoch_itr.epoch,
             prefix='valid on \'{}\' subset'.format(subset),
             no_progress_bar='simple'
         )
@@ -232,8 +212,8 @@ def validate(args, trainer, dataset, subsets, epoch):
             meter = trainer.get_meter(k)
             if meter is not None:
                 meter.reset()
         extra_meters = collections.defaultdict(lambda: AverageMeter())
         for sample in progress:
             log_output = trainer.valid_step(sample)
@@ -274,9 +254,11 @@ def get_perplexity(loss):
         return float('inf')


-def save_checkpoint(args, trainer, epoch, end_of_epoch, val_loss):
-    if args.no_save or args.distributed_rank > 0:
+def save_checkpoint(args, trainer, epoch_itr, val_loss):
+    if args.no_save or not distributed_utils.is_master(args):
         return
+    epoch = epoch_itr.epoch
+    end_of_epoch = epoch_itr.end_of_epoch()
     updates = trainer.get_num_updates()

     checkpoint_conds = collections.OrderedDict()
@@ -298,11 +280,9 @@ def save_checkpoint(args, trainer, epoch, end_of_epoch, val_loss):
     if val_loss is not None:
         save_checkpoint.best = min(val_loss, prev_best)
     extra_state = {
-        'best': prev_best,
-        'end_of_epoch': end_of_epoch,
-        'epoch': epoch,
+        'best': save_checkpoint.best,
+        'train_iterator': epoch_itr.state_dict(),
         'val_loss': val_loss,
-        'wall_time': trainer.get_meter('wall').elapsed_time,
     }

     checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
@@ -325,46 +305,36 @@ def save_checkpoint(args, trainer, epoch, end_of_epoch, val_loss):
             os.remove(old_chk)


-def load_checkpoint(args, trainer, train_dataloader):
-    """Load a checkpoint and replay dataloader to match."""
+def load_checkpoint(args, trainer, epoch_itr):
     os.makedirs(args.save_dir, exist_ok=True)
     checkpoint_path = os.path.join(args.save_dir, args.restore_file)
-    epoch = 1
-    ds = None
     if os.path.isfile(checkpoint_path):
         extra_state = trainer.load_checkpoint(checkpoint_path)
         if extra_state is not None:
-            epoch = extra_state['epoch']
-            end_of_epoch = extra_state.get('end_of_epoch', True)
-            trainer_updates = trainer.get_num_updates()
+            # replay train iterator to match checkpoint
+            epoch_itr.load_state_dict(extra_state['train_iterator'])
+
+            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
+                checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))
+
+            trainer.lr_step(epoch_itr.epoch)
+            trainer.lr_step_update(trainer.get_num_updates())
             if 'best' in extra_state:
                 save_checkpoint.best = extra_state['best']
-            print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
-            trainer.lr_step(epoch)
-            updates = 0
-            for i in range(epoch):
-                ds = next(train_dataloader)
-                updates += len(ds)
-            if not end_of_epoch and ds is not None and updates > trainer_updates:
-                completed_batches = len(ds) - (updates - trainer_updates)
-                assert completed_batches >= 0
-                ds = iter(ds)
-                print('| resuming from batch {}'.format(completed_batches + 1))
-                # consume completed batches
-                next(islice(ds, completed_batches, completed_batches), None)
-            else:
-                if not end_of_epoch:
-                    print('| WARNING: checkpoint is not at end of epoch')
-                ds = next(train_dataloader)
-                epoch += 1
-
-            trainer.get_meter('wall').reset(init=extra_state.get('wall_time', 0))
-    return epoch, ds or next(train_dataloader)
+
+
+def load_dataset_splits(args, task, splits):
+    for split in splits:
+        for k in itertools.count():
+            split_k = split + (str(k) if k > 0 else '')
+            try:
+                task.load_dataset(split_k)
+                print('| {} {} {} examples'.format(args.data, split_k, len(task.dataset(split_k))))
+            except FileNotFoundError as e:
+                if k > 0:
+                    break
+                raise e


 if __name__ == '__main__':
...