Commit ff68a9ef authored by Myle Ott

Add FairseqTask

A Task defines the data format, stores shared state (e.g., dictionaries), and provides helpers for building the model/criterion and computing the loss.
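As a concrete sketch, here is a minimal task modeled directly on the TestTranslationTask helper added to the test utilities in this commit; the FairseqTask base-class internals and the `models.build_model(args, task)` registry signature are assumptions, not shown in this diff:

```python
from fairseq.tasks import FairseqTask


class ToyTranslationTask(FairseqTask):
    """Minimal task sketch: owns the dictionaries (shared state) and
    brokers model construction. Mirrors TestTranslationTask below."""

    def __init__(self, args, src_dict, tgt_dict):
        super().__init__(args)
        # shared state (dictionaries) lives on the task, not in globals
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    @classmethod
    def setup_task(cls, args, src_dict=None, tgt_dict=None):
        return cls(args, src_dict, tgt_dict)

    def build_model(self, args):
        # assumed helper: the model registry now receives the task so
        # models can reach source_dictionary / target_dictionary
        from fairseq import models
        return models.build_model(args, self)

    @property
    def source_dictionary(self):
        return self.src_dict

    @property
    def target_dictionary(self):
        return self.tgt_dict
```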

Changes:
- Add TranslationTask and LanguageModelingTask. New tasks can be registered with the @register_task decorator.
- Add EpochBatchIterator to encapsulate batching and saving/restoring dataloader position.
- Remove the LEFT_PAD_* constants and make them configurable per task.
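A hedged sketch of the two new mechanics named above; the 'toy_translation' name and `train_dataset` are illustrative, and the import location of register_task is assumed:

```python
from fairseq import data
from fairseq.tasks import FairseqTask, register_task


@register_task('toy_translation')  # illustrative task name
class ToyTranslationTask(FairseqTask):
    ...  # as sketched earlier


# EpochBatchIterator snapshots its position via state_dict() and
# replays it via load_state_dict(); train.py stores this under
# extra_state['train_iterator'] (see the save_checkpoint /
# load_checkpoint hunks below).
epoch_itr = data.EpochBatchIterator(
    dataset=train_dataset,  # any dataset, e.g. a LanguagePairDataset
    max_tokens=6000,
)
state = epoch_itr.state_dict()    # {'epoch': ..., 'iterations_in_epoch': ..., 'shuffle': ...}
epoch_itr.load_state_dict(state)  # resume mid-epoch after a restart
```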
parent 2de93532
@@ -46,12 +46,13 @@ class TestLabelSmoothing(unittest.TestCase):
[0.05, 0.10, 0.2, 0.05, 0.2, 0.3, 0.10],
[0.05, 0.15, 0.3, 0.05, 0.1, 0.2, 0.15],
]).unsqueeze(0).expand(2, 3, 7) # add batch dimension
self.model = test_utils.TestModel.build_model(self.args, self.d, self.d)
self.task = test_utils.TestTranslationTask.setup_task(self.args, self.d, self.d)
self.model = self.task.build_model(self.args)
def test_nll_loss(self):
self.args.label_smoothing = 0.1
nll_crit = CrossEntropyCriterion(self.args, self.d, self.d)
smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
nll_crit = CrossEntropyCriterion(self.args, self.task)
smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
self.assertLess(abs(nll_loss - nll_logging_output['loss']), 1e-6)
@@ -59,7 +60,7 @@ class TestLabelSmoothing(unittest.TestCase):
def test_padding(self):
self.args.label_smoothing = 0.1
crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
loss, _, logging_output = crit(self.model, self.sample)
def get_one_no_padding(idx):
@@ -68,7 +69,7 @@ class TestLabelSmoothing(unittest.TestCase):
sample1 = next(test_utils.dummy_dataloader([self.data[idx]]))
args1 = copy.copy(self.args)
args1.probs = args1.probs[idx, :, :].unsqueeze(0)
model1 = test_utils.TestModel.build_model(args1, self.d, self.d)
model1 = self.task.build_model(args1)
loss1, _, _ = crit(model1, sample1)
return loss1
@@ -78,15 +79,15 @@ class TestLabelSmoothing(unittest.TestCase):
def test_reduction(self):
self.args.label_smoothing = 0.1
crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
loss, _, logging_output = crit(self.model, self.sample, reduce=True)
unreduced_loss, _, _ = crit(self.model, self.sample, reduce=False)
self.assertAlmostEqual(loss, unreduced_loss.sum())
def test_zero_eps(self):
self.args.label_smoothing = 0.0
nll_crit = CrossEntropyCriterion(self.args, self.d, self.d)
smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
nll_crit = CrossEntropyCriterion(self.args, self.task)
smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
self.assertAlmostEqual(nll_loss, smooth_loss)
......
@@ -80,10 +80,12 @@ class TestSequenceGenerator(unittest.TestCase):
]),
]
self.model = test_utils.TestModel.build_model(args, d, d)
task = test_utils.TestTranslationTask.setup_task(args, d, d)
self.model = task.build_model(args)
self.tgt_dict = task.target_dictionary
def test_with_normalization(self):
generator = SequenceGenerator([self.model])
generator = SequenceGenerator([self.model], self.tgt_dict)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
@@ -102,7 +104,7 @@ class TestSequenceGenerator(unittest.TestCase):
def test_without_normalization(self):
# Sentence 1: unchanged from the normalized case
# Sentence 2: beams swap order
generator = SequenceGenerator([self.model], normalize_scores=False)
generator = SequenceGenerator([self.model], self.tgt_dict, normalize_scores=False)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
@@ -120,7 +122,7 @@ class TestSequenceGenerator(unittest.TestCase):
def test_with_lenpen_favoring_short_hypos(self):
lenpen = 0.6
generator = SequenceGenerator([self.model], len_penalty=lenpen)
generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
@@ -138,7 +140,7 @@ class TestSequenceGenerator(unittest.TestCase):
def test_with_lenpen_favoring_long_hypos(self):
lenpen = 5.0
generator = SequenceGenerator([self.model], len_penalty=lenpen)
generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
@@ -155,7 +157,7 @@ class TestSequenceGenerator(unittest.TestCase):
self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6], lenpen=lenpen)
def test_maxlen(self):
generator = SequenceGenerator([self.model], maxlen=2)
generator = SequenceGenerator([self.model], self.tgt_dict, maxlen=2)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
@@ -172,7 +174,7 @@ class TestSequenceGenerator(unittest.TestCase):
self.assertHypoScore(hypos[1][1], [0.3, 0.9, 0.01])
def test_no_stop_early(self):
generator = SequenceGenerator([self.model], stop_early=False)
generator = SequenceGenerator([self.model], self.tgt_dict, stop_early=False)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
......
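For reference, the updated construction pattern these tests exercise; `model`, `task`, `src_tokens`, and `src_lengths` come from the test fixture, and the import path is assumed from the test module:

```python
from fairseq.sequence_generator import SequenceGenerator

# the target dictionary is now an explicit constructor argument
# rather than module-level state (part of dropping LEFT_PAD_*)
generator = SequenceGenerator([model], task.target_dictionary,
                              len_penalty=0.6, normalize_scores=True)
hypos = generator.generate(src_tokens, src_lengths, beam_size=2)
```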
@@ -83,8 +83,9 @@ class TestSequenceScorer(unittest.TestCase):
[0.3, 0.7], # sentence 3
]
model = test_utils.TestModel.build_model(args, d, d)
scorer = SequenceScorer([model])
task = test_utils.TestTranslationTask.setup_task(args, d, d)
model = task.build_model(args)
scorer = SequenceScorer([model], task.target_dictionary)
for id, _src, _ref, hypos in scorer.score_batched_itr(data_itr):
self.assertHypoTokens(hypos[0], data[id]['target'])
self.assertHypoScore(hypos[0], expected_scores[id])
......
@@ -8,23 +8,45 @@
import contextlib
from io import StringIO
import unittest
from unittest.mock import MagicMock, patch
import torch
from fairseq import data
import train
def mock_trainer(epoch, num_updates, end_of_epoch):
def mock_trainer(epoch, num_updates, iterations_in_epoch):
trainer = MagicMock()
trainer.load_checkpoint.return_value = {'epoch': epoch, 'end_of_epoch': end_of_epoch}
trainer.load_checkpoint.return_value = {
'train_iterator': {
'epoch': epoch,
'iterations_in_epoch': iterations_in_epoch,
'shuffle': False,
},
}
trainer.get_num_updates.return_value = num_updates
return trainer
def mock_loader(length):
loader = MagicMock()
loader.__next__.return_value = list(range(length))
return loader
def mock_dict():
d = MagicMock()
d.pad.return_value = 1
d.eos.return_value = 2
d.unk.return_value = 3
return d
def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
tokens = torch.LongTensor(list(range(epoch_size)))
tokens_ds = data.TokenBlockDataset(tokens, [len(tokens)], 1, include_targets=False)
trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
epoch_itr = data.EpochBatchIterator(
dataset=data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False),
max_tokens=1,
)
return trainer, epoch_itr
class TestLoadCheckpoint(unittest.TestCase):
@@ -40,29 +62,41 @@ class TestLoadCheckpoint(unittest.TestCase):
def test_load_partial_checkpoint(self):
with contextlib.redirect_stdout(StringIO()):
trainer = mock_trainer(2, 200, False)
loader = mock_loader(150)
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
self.assertEqual(epoch, 2)
self.assertEqual(next(ds), 50)
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
train.load_checkpoint(MagicMock(), trainer, epoch_itr)
self.assertEqual(epoch_itr.epoch, 2)
self.assertEqual(epoch_itr.iterations_in_epoch, 50)
itr = epoch_itr.next_epoch_itr(shuffle=False)
self.assertEqual(epoch_itr.epoch, 2)
self.assertEqual(epoch_itr.iterations_in_epoch, 50)
self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 50)
self.assertEqual(epoch_itr.iterations_in_epoch, 51)
def test_load_full_checkpoint(self):
with contextlib.redirect_stdout(StringIO()):
trainer = mock_trainer(2, 300, True)
loader = mock_loader(150)
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
self.assertEqual(epoch, 3)
self.assertEqual(next(iter(ds)), 0)
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
train.load_checkpoint(MagicMock(), trainer, epoch_itr)
itr = epoch_itr.next_epoch_itr(shuffle=False)
self.assertEqual(epoch_itr.epoch, 3)
self.assertEqual(epoch_itr.iterations_in_epoch, 0)
self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)
def test_load_no_checkpoint(self):
with contextlib.redirect_stdout(StringIO()):
trainer = mock_trainer(0, 0, False)
loader = mock_loader(150)
trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
self.patches['os.path.isfile'].return_value = False
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
self.assertEqual(epoch, 1)
self.assertEqual(next(iter(ds)), 0)
train.load_checkpoint(MagicMock(), trainer, epoch_itr)
itr = epoch_itr.next_epoch_itr(shuffle=False)
self.assertEqual(epoch_itr.epoch, 1)
self.assertEqual(epoch_itr.iterations_in_epoch, 0)
self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)
def tearDown(self):
patch.stopall()
......
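Schematically, what these tests verify: load_checkpoint() no longer returns an (epoch, dataloader) pair but replays the iterator in place (names as defined in this test module):

```python
trainer, epoch_itr = get_trainer_and_epoch_itr(
    epoch=2, epoch_size=150, num_updates=200, iterations_in_epoch=50)
train.load_checkpoint(MagicMock(), trainer, epoch_itr)  # mutates epoch_itr
assert (epoch_itr.epoch, epoch_itr.iterations_in_epoch) == (2, 50)
itr = epoch_itr.next_epoch_itr(shuffle=False)  # resumes at batch 50
```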
@@ -8,18 +8,19 @@
import torch
from torch.autograd import Variable
from fairseq.data.language_pair_dataset import collate
from fairseq import utils
from fairseq.data import dictionary
from fairseq.data import Dictionary
from fairseq.data.language_pair_dataset import collate
from fairseq.models import (
FairseqEncoder,
FairseqIncrementalDecoder,
FairseqModel,
)
from fairseq.tasks import FairseqTask
def dummy_dictionary(vocab_size, prefix='token_'):
d = dictionary.Dictionary()
d = Dictionary()
for i in range(vocab_size):
token = prefix + str(i)
d.add_symbol(token)
@@ -46,14 +47,7 @@ def dummy_dataloader(
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
collate_fn=(
lambda samples: collate(
samples,
padding_idx,
eos_idx,
has_target=True,
)
),
collate_fn=(lambda samples: collate(samples, padding_idx, eos_idx)),
)
return iter(dataloader)
@@ -71,14 +65,38 @@ class TestDataset(torch.utils.data.Dataset):
return len(self.data)
class TestTranslationTask(FairseqTask):
def __init__(self, args, src_dict, tgt_dict, model):
super().__init__(args)
self.src_dict = src_dict
self.tgt_dict = tgt_dict
self.model = model
@classmethod
def setup_task(cls, args, src_dict=None, tgt_dict=None, model=None):
return cls(args, src_dict, tgt_dict, model)
def build_model(self, args):
return TestModel.build_model(args, self)
@property
def source_dictionary(self):
return self.src_dict
@property
def target_dictionary(self):
return self.tgt_dict
class TestModel(FairseqModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@classmethod
def build_model(cls, args, src_dict, dst_dict):
encoder = TestEncoder(args, src_dict)
decoder = TestIncrementalDecoder(args, dst_dict)
def build_model(cls, args, task):
encoder = TestEncoder(args, task.source_dictionary)
decoder = TestIncrementalDecoder(args, task.target_dictionary)
return cls(encoder, decoder)
......
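In short, the helpers above route all construction through the task. A sketch, where `args`, `src_dict`, and `tgt_dict` are assumed from the surrounding test setup and build_criterion is assumed to be provided by the FairseqTask base class (train.py below relies on it):

```python
task = TestTranslationTask.setup_task(args, src_dict, tgt_dict)
model = task.build_model(args)           # was TestModel.build_model(args, src_dict, dst_dict)
criterion = task.build_criterion(args)   # assumed FairseqTask helper
assert task.source_dictionary is src_dict
assert task.target_dictionary is tgt_dict
```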
@@ -7,14 +7,12 @@
# can be found in the PATENTS file in the same directory.
import collections
import itertools
import os
import math
import torch
from itertools import islice
from fairseq import criterions, models, options, progress_bar, utils
from fairseq.data import data_utils, data_loaders
from fairseq import data, distributed_utils, options, progress_bar, tasks, utils
from fairseq.fp16_trainer import FP16Trainer
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
@@ -23,7 +21,6 @@ from fairseq.meters import AverageMeter, StopwatchMeter
def main(args):
if args.max_tokens is None:
args.max_tokens = 6000
print(args)
if not torch.cuda.is_available():
@@ -31,27 +28,25 @@ def main(args):
torch.cuda.set_device(args.device_id)
torch.manual_seed(args.seed)
# Load dataset
splits = ['train', 'valid']
dataset = load_dataset(args, splits)
print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
for split in splits:
print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))
# Setup task, e.g., translation, language modeling, etc.
task = tasks.setup_task(args)
model = models.build_model(args, dataset.src_dict, dataset.dst_dict)
# Load dataset splits
load_dataset_splits(args, task, ['train', 'valid'])
criterion = criterions.build_criterion(args, dataset.src_dict, dataset.dst_dict)
# Build model and criterion
model = task.build_model(args)
criterion = task.build_criterion(args)
print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
print('| num. model params: {}'.format(sum(p.data.numel() for p in model.parameters())))
print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters())))
# Build trainer
if args.fp16:
trainer = FP16Trainer(args, model, criterion)
trainer = FP16Trainer(args, task, model, criterion)
else:
if torch.cuda.get_device_capability(0)[0] >= 7:
print('| NOTICE: your device may support faster training with --fp16')
trainer = Trainer(args, model, criterion)
trainer = Trainer(args, task, model, criterion)
print('| training on {} GPUs'.format(args.distributed_world_size))
print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
args.max_tokens,
@@ -59,25 +54,24 @@ def main(args):
))
# Initialize dataloader
train_dataloader = dataset.train_dataloader_generator(
args.train_subset,
max_positions = trainer.get_model().max_positions()
epoch_itr = data.EpochBatchIterator(
dataset=task.dataset(args.train_subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences,
max_positions=(
min(args.max_source_positions, trainer.get_model().max_encoder_positions()),
min(args.max_target_positions, trainer.get_model().max_decoder_positions())
),
max_sentences=args.max_sentences_valid,
max_positions=max_positions,
ignore_invalid_inputs=True,
required_batch_size_multiple=8,
seed=args.seed,
sample_without_replacement=args.sample_without_replacement,
shard_id=args.distributed_rank,
num_shards=args.distributed_world_size,
shard_id=args.distributed_rank,
)
# Load the latest checkpoint if one is available
epoch, next_ds = load_checkpoint(args, trainer, train_dataloader)
load_checkpoint(args, trainer, epoch_itr)
# Send a dummy batch to warm the caching allocator
dummy_batch = data_utils.get_dummy_batch(args.max_tokens, dataset.src_dict, dataset.dst_dict)
dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
trainer.dummy_train_step(dummy_batch)
# Train until the learning rate gets too small
@@ -88,58 +82,41 @@ def main(args):
train_meter.start()
valid_losses = [None]
valid_subsets = args.valid_subset.split(',')
while lr > args.min_lr and epoch <= max_epoch and trainer.get_num_updates() < max_update:
while lr > args.min_lr and epoch_itr.epoch <= max_epoch and trainer.get_num_updates() < max_update:
# train for one epoch
train(args, trainer, next_ds, epoch, dataset)
train(args, trainer, task, epoch_itr)
if epoch % args.validate_interval == 0:
valid_losses = validate(args, trainer, dataset, valid_subsets, epoch)
if epoch_itr.epoch % args.validate_interval == 0:
valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
# only use first validation loss to update the learning rate
lr = trainer.lr_step(epoch, valid_losses[0])
lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
# save checkpoint
if epoch % args.save_interval == 0:
save_checkpoint(args, trainer, epoch, end_of_epoch=True, val_loss=valid_losses[0])
epoch += 1
next_ds = next(train_dataloader)
if epoch_itr.epoch % args.save_interval == 0:
save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
train_meter.stop()
print('| done training in {:.1f} seconds'.format(train_meter.sum))
def load_dataset(args, splits):
is_raw = not data_utils.has_binary_files(args.data, splits)
dataset = data_loaders.load_dataset(args, splits, is_raw)
return dataset
def train(args, trainer, itr, epoch, dataset):
def train(args, trainer, task, epoch_itr):
"""Train the model for one epoch."""
# Set seed based on args.seed and the epoch number so that we get
# reproducible results when resuming from checkpoints
seed = args.seed + epoch
torch.manual_seed(seed)
# reset training meters
for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip']:
meter = trainer.get_meter(k)
if meter is not None:
meter.reset()
# Initialize data iterator
itr = epoch_itr.next_epoch_itr()
progress = progress_bar.build_progress_bar(args, itr, epoch_itr.epoch, no_progress_bar='simple')
# update parameters every N batches
if epoch <= len(args.update_freq):
update_freq = args.update_freq[epoch - 1]
if epoch_itr.epoch <= len(args.update_freq):
update_freq = args.update_freq[epoch_itr.epoch - 1]
else:
update_freq = args.update_freq[-1]
extra_meters = collections.defaultdict(lambda: AverageMeter())
first_valid = args.valid_subset.split(',')[0]
max_update = args.max_update or math.inf
num_batches = len(itr)
progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple')
for i, sample in enumerate(progress):
num_batches = len(epoch_itr)
for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch):
if i < num_batches - 1 and (i + 1) % update_freq > 0:
# buffer updates according to --update-freq
trainer.train_step(sample, update_params=False)
@@ -165,8 +142,8 @@ def train(args, trainer, itr, epoch, dataset):
num_updates = trainer.get_num_updates()
if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0:
valid_losses = validate(args, trainer, dataset, [first_valid], epoch)
save_checkpoint(args, trainer, epoch, end_of_epoch=False, val_loss=valid_losses[0])
valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
if num_updates >= max_update:
break
@@ -177,6 +154,12 @@ def train(args, trainer, itr, epoch, dataset):
stats[k] = meter.avg
progress.print(stats)
# reset training meters
for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip']:
meter = trainer.get_meter(k)
if meter is not None:
meter.reset()
def get_training_stats(trainer):
stats = collections.OrderedDict()
@@ -202,27 +185,24 @@ def get_training_stats(trainer):
return stats
def validate(args, trainer, dataset, subsets, epoch):
def validate(args, trainer, task, epoch_itr, subsets):
"""Evaluate the model on the validation set(s) and return the losses."""
valid_losses = []
for subset in subsets:
# Initialize dataloader
max_positions_valid = (
trainer.get_model().max_encoder_positions(),
trainer.get_model().max_decoder_positions(),
)
itr = dataset.eval_dataloader(
subset,
# Initialize data iterator
itr = data.EpochBatchIterator(
dataset=task.dataset(subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences_valid,
max_positions=max_positions_valid,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
descending=True, # largest batch first to warm the caching allocator
shard_id=args.distributed_rank,
max_positions=trainer.get_model().max_positions(),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=8,
seed=args.seed,
num_shards=args.distributed_world_size,
)
shard_id=args.distributed_rank,
).next_epoch_itr(shuffle=False)
progress = progress_bar.build_progress_bar(
args, itr, epoch,
args, itr, epoch_itr.epoch,
prefix='valid on \'{}\' subset'.format(subset),
no_progress_bar='simple'
)
@@ -232,8 +212,8 @@ def validate(args, trainer, dataset, subsets, epoch):
meter = trainer.get_meter(k)
if meter is not None:
meter.reset()
extra_meters = collections.defaultdict(lambda: AverageMeter())
for sample in progress:
log_output = trainer.valid_step(sample)
@@ -274,9 +254,11 @@ def get_perplexity(loss):
return float('inf')
def save_checkpoint(args, trainer, epoch, end_of_epoch, val_loss):
if args.no_save or args.distributed_rank > 0:
def save_checkpoint(args, trainer, epoch_itr, val_loss):
if args.no_save or not distributed_utils.is_master(args):
return
epoch = epoch_itr.epoch
end_of_epoch = epoch_itr.end_of_epoch()
updates = trainer.get_num_updates()
checkpoint_conds = collections.OrderedDict()
@@ -298,11 +280,9 @@ def save_checkpoint(args, trainer, epoch, end_of_epoch, val_loss):
if val_loss is not None:
save_checkpoint.best = min(val_loss, prev_best)
extra_state = {
'best': prev_best,
'end_of_epoch': end_of_epoch,
'epoch': epoch,
'best': save_checkpoint.best,
'train_iterator': epoch_itr.state_dict(),
'val_loss': val_loss,
'wall_time': trainer.get_meter('wall').elapsed_time,
}
checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
@@ -325,46 +305,36 @@ def save_checkpoint(args, trainer, epoch, end_of_epoch, val_loss):
os.remove(old_chk)
def load_checkpoint(args, trainer, train_dataloader):
def load_checkpoint(args, trainer, epoch_itr):
"""Load a checkpoint and replay dataloader to match."""
os.makedirs(args.save_dir, exist_ok=True)
checkpoint_path = os.path.join(args.save_dir, args.restore_file)
epoch = 1
ds = None
if os.path.isfile(checkpoint_path):
extra_state = trainer.load_checkpoint(checkpoint_path)
if extra_state is not None:
epoch = extra_state['epoch']
end_of_epoch = extra_state.get('end_of_epoch', True)
trainer_updates = trainer.get_num_updates()
# replay train iterator to match checkpoint
epoch_itr.load_state_dict(extra_state['train_iterator'])
print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))
trainer.lr_step(epoch_itr.epoch)
trainer.lr_step_update(trainer.get_num_updates())
if 'best' in extra_state:
save_checkpoint.best = extra_state['best']
print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
trainer.lr_step(epoch)
updates = 0
for i in range(epoch):
ds = next(train_dataloader)
updates += len(ds)
if not end_of_epoch and ds is not None and updates > trainer_updates:
completed_batches = len(ds) - (updates - trainer_updates)
assert completed_batches >= 0
ds = iter(ds)
print('| resuming from batch {}'.format(completed_batches + 1))
# consume completed batches
next(islice(ds, completed_batches, completed_batches), None)
else:
if not end_of_epoch:
print('| WARNING: checkpoint is not at end of epoch')
ds = next(train_dataloader)
epoch += 1
trainer.get_meter('wall').reset(init=extra_state.get('wall_time', 0))
return epoch, ds or next(train_dataloader)
def load_dataset_splits(args, task, splits):
for split in splits:
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
try:
task.load_dataset(split_k)
print('| {} {} {} examples'.format(args.data, split_k, len(task.dataset(split_k))))
except FileNotFoundError as e:
if k > 0:
break
raise e
if __name__ == '__main__':
......
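The new load_dataset_splits() also establishes a shard-naming convention ('train', 'train1', 'train2', ...). A small self-contained illustration of that naming loop:

```python
import itertools


def split_names(split, available):
    """Mirror load_dataset_splits() naming: probe 'train', 'train1',
    'train2', ... and stop at the first missing shard; the base split
    itself must exist."""
    names = []
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        if split_k not in available:
            if k > 0:
                break
            raise FileNotFoundError(split_k)
        names.append(split_k)
    return names


assert split_names('train', {'train', 'train1', 'valid'}) == ['train', 'train1']
```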