Commit 4fa8760e authored by Myle Ott


Improve dataloader speed and deprecate concept of batch_offset (use --sample-without-replacement instead)
parent c52f6ea4
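
The gist of the change: instead of resuming mid-epoch from a saved batch_offset, the dataset now exposes a generator that yields one DataLoader per epoch (or per sub-epoch when --sample-without-replacement is set), and resuming from a checkpoint simply fast-forwards that generator. A minimal usage sketch, condensed from the new train.py code in this commit; the split name, argument values, and the surrounding loop are illustrative placeholders, and "dataset" is assumed to be a LanguageDatasets instance:

    # Sketch only; values are illustrative, and `dataset` is assumed to be a
    # LanguageDatasets instance (see the diff below for the real call sites).
    train_dataloader = dataset.train_dataloader_generator(
        'train',                         # split name (placeholder)
        max_tokens=4000,
        seed=1,
        sample_without_replacement=0,    # >0: yield sub-epoch dataloaders of that many batches
        shard_id=0, num_shards=1,
    )

    max_epoch = 10                       # placeholder
    for epoch in range(1, max_epoch + 1):
        itr = next(train_dataloader)     # shuffled batches for this (sub-)epoch
        for sample in itr:
            pass                         # train on `sample` here
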
@@ -134,21 +134,40 @@ class LanguageDatasets(object):
         assert self.src_dict.eos() == self.dst_dict.eos()
         assert self.src_dict.unk() == self.dst_dict.unk()
 
-    def train_dataloader(self, split, max_tokens=None,
-                         max_sentences=None, max_positions=(1024, 1024),
-                         seed=None, epoch=1, sample_without_replacement=0,
-                         sort_by_source_size=False, shard_id=0, num_shards=1):
+    def train_dataloader_generator(
+        self, split, max_tokens=None, max_sentences=None,
+        max_positions=(1024, 1024), seed=None, sample_without_replacement=0,
+        shard_id=0, num_shards=1
+    ):
         dataset = self.splits[split]
         with numpy_seed(seed):
-            batch_sampler = shuffled_batches_by_size(
+            batches = uneven_batches_by_size(
                 dataset.src, dataset.dst, max_tokens=max_tokens,
-                max_sentences=max_sentences, epoch=epoch,
-                sample=sample_without_replacement, max_positions=max_positions,
-                sort_by_source_size=sort_by_source_size)
-            batch_sampler = mask_batches(batch_sampler, shard_id=shard_id, num_shards=num_shards)
-        return torch.utils.data.DataLoader(
-            dataset, collate_fn=dataset.collater,
-            batch_sampler=batch_sampler)
+                max_sentences=max_sentences, max_positions=max_positions)
+            frozen_batches = tuple(batches)  # freeze
+
+        def dataloader(b):
+            b = mask_batches(b, shard_id=shard_id, num_shards=num_shards)  # shard dataset
+            return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collater, batch_sampler=b)
+
+        for epoch in itertools.count(1):
+            # set seed based on the seed and epoch number so that we get
+            # reproducible results when resuming from checkpoints
+            with numpy_seed(seed + epoch):
+                batches = list(frozen_batches)  # copy
+                np.random.shuffle(batches)
+
+            if sample_without_replacement > 0:
+                # emit sub-epoch dataloaders
+                while len(batches) >= sample_without_replacement:
+                    sampled_batches = batches[:sample_without_replacement]
+                    remaining_batches = batches[sample_without_replacement:]
+                    yield dataloader(sampled_batches)
+                    batches = remaining_batches
+                if len(batches) > 0:
+                    yield dataloader(batches)
+            else:
+                # emit full dataloader
+                yield dataloader(batches)
 
     def eval_dataloader(self, split, num_workers=0, max_tokens=None,
                         max_sentences=None, max_positions=(1024, 1024),
@@ -358,11 +377,9 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
         ignore_invalid_inputs, allow_different_src_lens=False))
 
 
-def shuffled_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
-                             epoch=1, sample=0, max_positions=(1024, 1024),
-                             sort_by_source_size=False):
-    """Returns batches of indices, bucketed by size and then shuffled. Batches
-    may contain sequences of different lengths."""
+def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None, max_positions=(1024, 1024)):
+    """Returns batches of indices bucketed by size. Batches may contain
+    sequences of different lengths."""
     assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
     if max_tokens is None:
         max_tokens = float('Inf')
@@ -378,26 +395,6 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
     batches = list(_make_batches(
         src, dst, indices, max_tokens, max_sentences, max_positions,
         ignore_invalid_inputs=True, allow_different_src_lens=True))
-    if not sort_by_source_size:
-        np.random.shuffle(batches)
-
-    if sample:
-        offset = (epoch - 1) * sample
-        while offset > len(batches):
-            np.random.shuffle(batches)
-            offset -= len(batches)
-        result = batches[offset:(offset + sample)]
-        while len(result) < sample:
-            np.random.shuffle(batches)
-            result += batches[:(sample - len(result))]
-        assert len(result) == sample, \
-            "batch length is not correct {}".format(len(result))
-        batches = result
-
     return batches
...
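
To make the sub-epoch splitting above concrete: with 10 batches in a shuffled epoch and --sample-without-replacement 4, the generator yields dataloaders of 4, 4 and 2 batches, each consumed by one call to train(); with the default of 0 it yields a single dataloader covering all 10 batches. A standalone illustration of just that splitting loop (illustrative numbers, not the fairseq code itself):

    # Mirrors the while/if structure in train_dataloader_generator above.
    batches = list(range(10))            # stand-in for 10 shuffled batches
    sample_without_replacement = 4
    chunks = []
    while len(batches) >= sample_without_replacement:
        chunks.append(batches[:sample_without_replacement])
        batches = batches[sample_without_replacement:]
    if len(batches) > 0:
        chunks.append(batches)
    print([len(c) for c in chunks])      # [4, 4, 2]
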
@@ -108,6 +108,10 @@ def add_dataset_args(parser, train=False, gen=False):
         group.add_argument('--max-sentences-valid', type=int, metavar='N',
                            help='maximum number of sentences in a validation batch'
                                 ' (defaults to --max-sentences)')
+        group.add_argument('--sample-without-replacement', default=0, type=int, metavar='N',
+                           help='If bigger than 0, use that number of mini-batches for each epoch,'
+                                ' where each sample is drawn randomly without replacement from the'
+                                ' dataset')
     if gen:
         group.add_argument('--gen-subset', default='test', metavar='SPLIT',
                            help='data subset to generate (train, valid, test)')
@@ -170,12 +174,6 @@ def add_optimization_args(parser):
     group.add_argument('--min-lr', default=1e-5, type=float, metavar='LR',
                        help='minimum learning rate')
-    group.add_argument('--sample-without-replacement', default=0, type=int, metavar='N',
-                       help='If bigger than 0, use that number of mini-batches for each epoch,'
-                            ' where each sample is drawn randomly without replacement from the'
-                            ' dataset')
-    group.add_argument('--curriculum', default=0, type=int, metavar='N',
-                       help='sort batches by source length for first N epochs')
     group.add_argument('--update-freq', default=1, type=int, metavar='N',
                        help='update parameters every N batches')
     return group
@@ -187,10 +185,10 @@ def add_checkpoint_args(parser):
                        help='path to save checkpoints')
     group.add_argument('--restore-file', default='checkpoint_last.pt',
                        help='filename in save-dir from which to load checkpoint')
-    group.add_argument('--save-interval', type=int, default=-1, metavar='N',
-                       help='save a checkpoint every N updates')
+    group.add_argument('--save-interval', type=int, default=1, metavar='N',
+                       help='save a checkpoint every N epochs')
     group.add_argument('--no-save', action='store_true',
-                       help='don\'t save models and checkpoints')
+                       help='don\'t save models or checkpoints')
     group.add_argument('--no-epoch-checkpoints', action='store_true',
                        help='only store last and best checkpoints')
     group.add_argument('--validate-interval', type=int, default=1, metavar='N',
...
@@ -55,19 +55,34 @@ def main(args):
         args.max_sentences,
     ))
 
+    # Initialize dataloader
+    train_dataloader = dataset.train_dataloader_generator(
+        args.train_subset,
+        max_tokens=args.max_tokens,
+        max_sentences=args.max_sentences,
+        max_positions=(
+            min(args.max_source_positions, trainer.get_model().max_encoder_positions()),
+            min(args.max_target_positions, trainer.get_model().max_decoder_positions())
+        ),
+        seed=args.seed,
+        sample_without_replacement=args.sample_without_replacement,
+        shard_id=args.distributed_rank,
+        num_shards=args.distributed_world_size,
+    )
+
     # Load the latest checkpoint if one is available
     os.makedirs(args.save_dir, exist_ok=True)
     checkpoint_path = os.path.join(args.save_dir, args.restore_file)
-    extra_state = trainer.load_checkpoint(checkpoint_path)
-    if extra_state is not None:
-        epoch = extra_state['epoch']
-        batch_offset = extra_state['batch_offset']
-        print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
-        if batch_offset == 0:
-            trainer.lr_step(epoch)
-            epoch += 1
-    else:
-        epoch, batch_offset = 1, 0
+    epoch = 1
+    if os.path.isfile(checkpoint_path):
+        extra_state = trainer.load_checkpoint(checkpoint_path)
+        if extra_state is not None:
+            epoch = extra_state['epoch']
+            print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
+            trainer.lr_step(epoch)
+            for i in range(epoch):
+                _ = next(train_dataloader)
+            epoch += 1
 
     # Train until the learning rate gets too small
     max_epoch = args.max_epoch or math.inf
@@ -77,24 +92,24 @@ def main(args):
     train_meter.start()
     while lr > args.min_lr and epoch <= max_epoch:
         # train for one epoch
-        train(args, trainer, dataset, epoch, batch_offset)
+        train(args, trainer, next(train_dataloader), epoch)
 
         # evaluate on validate set
+        first_val_loss = None
         if epoch % args.validate_interval == 0:
             for k, subset in enumerate(args.valid_subset.split(',')):
                 val_loss = validate(args, trainer, dataset, subset, epoch)
                 if k == 0:
-                    # only use first validation loss to update the learning schedule
-                    lr = trainer.lr_step(epoch, val_loss)
+                    first_val_loss = val_loss
 
-                    # save checkpoint
-                    if not args.no_save:
-                        save_checkpoint(trainer, args, epoch, 0, val_loss)
-        else:
-            lr = trainer.lr_step(epoch)
+        # only use first validation loss to update the learning rate
+        lr = trainer.lr_step(epoch, first_val_loss)
+
+        # save checkpoint
+        if not args.no_save and epoch % args.save_interval == 0:
+            save_checkpoint(trainer, args, epoch, first_val_loss)
 
         epoch += 1
-        batch_offset = 0
 
         if trainer.get_num_updates() >= max_update:
             break
@@ -103,7 +118,7 @@ def main(args):
     print('| done training in {:.1f} seconds'.format(train_meter.sum))
 
 
-def train(args, trainer, dataset, epoch, batch_offset):
+def train(args, trainer, itr, epoch):
     """Train the model for one epoch."""
 
     # Set seed based on args.seed and the epoch number so that we get
@@ -111,30 +126,6 @@ def train(args, trainer, dataset, epoch, batch_offset):
     seed = args.seed + epoch
     torch.manual_seed(seed)
 
-    # The max number of positions can be different for train and valid
-    # e.g., RNNs may support more positions at test time than seen in training
-    max_positions_train = (
-        min(args.max_source_positions, trainer.get_model().max_encoder_positions()),
-        min(args.max_target_positions, trainer.get_model().max_decoder_positions())
-    )
-
-    # Initialize dataloader, starting at batch_offset
-    itr = dataset.train_dataloader(
-        args.train_subset,
-        max_tokens=args.max_tokens,
-        max_sentences=args.max_sentences,
-        max_positions=max_positions_train,
-        seed=seed,
-        epoch=epoch,
-        sample_without_replacement=args.sample_without_replacement,
-        sort_by_source_size=(epoch <= args.curriculum),
-        shard_id=args.distributed_rank,
-        num_shards=args.distributed_world_size,
-    )
-    progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple')
-    epoch_size = len(itr)
-    itr = itertools.islice(progress, batch_offset, None)
-
     # reset training meters
     for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip']:
         meter = trainer.get_meter(k)
@@ -143,8 +134,10 @@ def train(args, trainer, dataset, epoch, batch_offset):
     extra_meters = collections.defaultdict(lambda: AverageMeter())
     max_update = args.max_update or math.inf
 
-    for i, sample in enumerate(itr, start=batch_offset):
-        if i < epoch_size - 1 and (i + 1) % args.update_freq > 0:
+    num_batches = len(itr)
+    progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple')
+    for i, sample in enumerate(progress):
+        if i < num_batches - 1 and (i + 1) % args.update_freq > 0:
             # buffer updates according to --update-freq
             trainer.train_step(sample, update_params=False)
             continue
@@ -164,15 +157,10 @@ def train(args, trainer, dataset, epoch, batch_offset):
         progress.log(stats)
 
         # ignore the first mini-batch in words-per-second calculation
-        if i == batch_offset:
+        if i == 0:
             trainer.get_meter('wps').reset()
 
-        # save mid-epoch checkpoints
-        num_updates = trainer.get_num_updates()
-        if args.save_interval > 0 and num_updates > 0 and num_updates % args.save_interval == 0:
-            save_checkpoint(trainer, args, epoch, i + 1)
-
-        if num_updates >= max_update:
+        if trainer.get_num_updates() >= max_update:
             break
 
     # log end-of-epoch stats
@@ -274,28 +262,22 @@ def get_perplexity(loss):
     return float('inf')
 
 
-def save_checkpoint(trainer, args, epoch, batch_offset, val_loss=None):
+def save_checkpoint(trainer, args, epoch, val_loss=None):
     extra_state = {
         'epoch': epoch,
-        'batch_offset': batch_offset,
         'val_loss': val_loss,
     }
 
-    if batch_offset == 0:
-        if not args.no_epoch_checkpoints:
-            epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
-            trainer.save_checkpoint(epoch_filename, extra_state)
-
-        assert val_loss is not None
-        if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
-            save_checkpoint.best = val_loss
-            best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
-            trainer.save_checkpoint(best_filename, extra_state)
-    elif not args.no_epoch_checkpoints:
-        epoch_filename = os.path.join(
-            args.save_dir, 'checkpoint{}_{}.pt'.format(epoch, batch_offset))
+    if not args.no_epoch_checkpoints:
+        epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
         trainer.save_checkpoint(epoch_filename, extra_state)
+
+    assert val_loss is not None
+    if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
+        save_checkpoint.best = val_loss
+        best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
+        trainer.save_checkpoint(best_filename, extra_state)
 
     last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
     trainer.save_checkpoint(last_filename, extra_state)
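
Note on resuming: main() now fast-forwards the dataloader generator with one next() call per completed epoch, and the generator reseeds each epoch's shuffle with numpy_seed(seed + epoch), so the batch order after a restart matches what an uninterrupted run would have produced. This replaces the old batch_offset bookkeeping.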