Commit 4fa8760e authored by Myle Ott

Improve dataloader speed and deprecate concept of batch_offset (use --sample-without-replacement instead)
parent c52f6ea4
@@ -134,21 +134,40 @@ class LanguageDatasets(object):
         assert self.src_dict.eos() == self.dst_dict.eos()
         assert self.src_dict.unk() == self.dst_dict.unk()

-    def train_dataloader(self, split, max_tokens=None,
-                         max_sentences=None, max_positions=(1024, 1024),
-                         seed=None, epoch=1, sample_without_replacement=0,
-                         sort_by_source_size=False, shard_id=0, num_shards=1):
+    def train_dataloader_generator(
+        self, split, max_tokens=None, max_sentences=None,
+        max_positions=(1024, 1024), seed=None, sample_without_replacement=0,
+        shard_id=0, num_shards=1
+    ):
         dataset = self.splits[split]
         with numpy_seed(seed):
-            batch_sampler = shuffled_batches_by_size(
+            batches = uneven_batches_by_size(
                 dataset.src, dataset.dst, max_tokens=max_tokens,
-                max_sentences=max_sentences, epoch=epoch,
-                sample=sample_without_replacement, max_positions=max_positions,
-                sort_by_source_size=sort_by_source_size)
-            batch_sampler = mask_batches(batch_sampler, shard_id=shard_id, num_shards=num_shards)
-        return torch.utils.data.DataLoader(
-            dataset, collate_fn=dataset.collater,
-            batch_sampler=batch_sampler)
+                max_sentences=max_sentences, max_positions=max_positions)
+            frozen_batches = tuple(batches)  # freeze
+
+        def dataloader(b):
+            b = mask_batches(b, shard_id=shard_id, num_shards=num_shards)  # shard dataset
+            return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collater, batch_sampler=b)
+
+        for epoch in itertools.count(1):
+            # set seed based on the seed and epoch number so that we get
+            # reproducible results when resuming from checkpoints
+            with numpy_seed(seed + epoch):
+                batches = list(frozen_batches)  # copy
+                np.random.shuffle(batches)
+
+            if sample_without_replacement > 0:
+                # emit sub-epoch dataloaders
+                while len(batches) >= sample_without_replacement:
+                    sampled_batches = batches[:sample_without_replacement]
+                    remaining_batches = batches[sample_without_replacement:]
+                    yield dataloader(sampled_batches)
+                    batches = remaining_batches
+                if len(batches) > 0:
+                    yield dataloader(batches)
+            else:
+                # emit full dataloader
+                yield dataloader(batches)

     def eval_dataloader(self, split, num_workers=0, max_tokens=None,
                         max_sentences=None, max_positions=(1024, 1024),
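The new train_dataloader_generator buckets the batches once, freezes them, and then yields one DataLoader per (sub-)epoch, reshuffling with a seed derived from the epoch number. A minimal, self-contained sketch of that pattern, using toy batch lists instead of fairseq's IndexedDataset and a plain slice standing in for mask_batches (the order of sharding vs. sub-epoch splitting is simplified here):

    import itertools
    import numpy as np

    def batch_generator(frozen_batches, seed, sample_without_replacement=0,
                        shard_id=0, num_shards=1):
        # yields one list of batches per (sub-)epoch, reshuffled reproducibly
        for epoch in itertools.count(1):
            rng = np.random.RandomState(seed + epoch)  # reproducible per epoch
            batches = list(frozen_batches)
            rng.shuffle(batches)
            batches = batches[shard_id::num_shards]    # stand-in for mask_batches sharding
            if sample_without_replacement > 0:
                # emit fixed-size sub-epochs, then the remainder
                while len(batches) >= sample_without_replacement:
                    yield batches[:sample_without_replacement]
                    batches = batches[sample_without_replacement:]
                if batches:
                    yield batches
            else:
                yield batches

    # toy usage: 10 "batches" of indices, sub-epochs of 4 batches each
    gen = batch_generator([list(range(i, i + 2)) for i in range(0, 20, 2)],
                          seed=1, sample_without_replacement=4)
    print(next(gen))  # first sub-epoch (4 batches)
    print(next(gen))  # second sub-epoch (4 batches)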
@@ -358,11 +377,9 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
         ignore_invalid_inputs, allow_different_src_lens=False))


-def shuffled_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
-                             epoch=1, sample=0, max_positions=(1024, 1024),
-                             sort_by_source_size=False):
-    """Returns batches of indices, bucketed by size and then shuffled. Batches
-    may contain sequences of different lengths."""
+def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None, max_positions=(1024, 1024)):
+    """Returns batches of indices bucketed by size. Batches may contain
+    sequences of different lengths."""
     assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
     if max_tokens is None:
         max_tokens = float('Inf')
@@ -378,26 +395,6 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
     batches = list(_make_batches(
         src, dst, indices, max_tokens, max_sentences, max_positions,
         ignore_invalid_inputs=True, allow_different_src_lens=True))
-    if not sort_by_source_size:
-        np.random.shuffle(batches)
-
-    if sample:
-        offset = (epoch - 1) * sample
-        while offset > len(batches):
-            np.random.shuffle(batches)
-            offset -= len(batches)
-        result = batches[offset:(offset + sample)]
-        while len(result) < sample:
-            np.random.shuffle(batches)
-            result += batches[:(sample - len(result))]
-        assert len(result) == sample, \
-            "batch length is not correct {}".format(len(result))
-        batches = result
-
     return batches
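uneven_batches_by_size defers the actual grouping to _make_batches, which is not part of this diff. The following is a rough, hypothetical sketch of the general idea of token-budgeted batching (greedy grouping under a padded-token budget); names and details are illustrative, not fairseq's implementation:

    def make_token_batches(lengths, max_tokens=4000, max_sentences=None):
        """Greedy token-budgeted batching: group sample indices so that the
        padded token count (batch size * longest item) stays under max_tokens."""
        def padded_tokens(batch):
            return len(batch) * max(lengths[i] for i in batch)

        batches, batch = [], []
        for idx in range(len(lengths)):
            if batch and (padded_tokens(batch + [idx]) > max_tokens
                          or (max_sentences is not None and len(batch) + 1 > max_sentences)):
                batches.append(batch)  # close the current batch before it overflows
                batch = []
            batch.append(idx)
        if batch:
            batches.append(batch)
        return batches

    print(make_token_batches([3, 5, 7, 4, 9, 2], max_tokens=16))  # [[0, 1], [2, 3], [4], [5]]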
@@ -108,6 +108,10 @@ def add_dataset_args(parser, train=False, gen=False):
         group.add_argument('--max-sentences-valid', type=int, metavar='N',
                            help='maximum number of sentences in a validation batch'
                                 ' (defaults to --max-sentences)')
+        group.add_argument('--sample-without-replacement', default=0, type=int, metavar='N',
+                           help='If bigger than 0, use that number of mini-batches for each epoch,'
+                                ' where each sample is drawn randomly without replacement from the'
+                                ' dataset')
     if gen:
         group.add_argument('--gen-subset', default='test', metavar='SPLIT',
                            help='data subset to generate (train, valid, test)')
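As a worked example with hypothetical numbers: with 1000 shuffled batches and --sample-without-replacement 300, one pass over the data is split into sub-epochs of 300, 300, 300, and 100 batches, each served by its own dataloader:

    # hypothetical sizes, not taken from this commit
    num_batches, sample = 1000, 300
    sub_epochs = [sample] * (num_batches // sample)
    if num_batches % sample:
        sub_epochs.append(num_batches % sample)
    print(sub_epochs)  # [300, 300, 300, 100] -> four dataloaders per pass over the data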
@@ -170,12 +174,6 @@ def add_optimization_args(parser):
     group.add_argument('--min-lr', default=1e-5, type=float, metavar='LR',
                        help='minimum learning rate')
-    group.add_argument('--sample-without-replacement', default=0, type=int, metavar='N',
-                       help='If bigger than 0, use that number of mini-batches for each epoch,'
-                            ' where each sample is drawn randomly without replacement from the'
-                            ' dataset')
     group.add_argument('--curriculum', default=0, type=int, metavar='N',
                        help='sort batches by source length for first N epochs')
     group.add_argument('--update-freq', default=1, type=int, metavar='N',
                        help='update parameters every N batches')
     return group
@@ -187,10 +185,10 @@ def add_checkpoint_args(parser):
                        help='path to save checkpoints')
     group.add_argument('--restore-file', default='checkpoint_last.pt',
                        help='filename in save-dir from which to load checkpoint')
-    group.add_argument('--save-interval', type=int, default=-1, metavar='N',
-                       help='save a checkpoint every N updates')
+    group.add_argument('--save-interval', type=int, default=1, metavar='N',
+                       help='save a checkpoint every N epochs')
     group.add_argument('--no-save', action='store_true',
-                       help='don\'t save models and checkpoints')
+                       help='don\'t save models or checkpoints')
     group.add_argument('--no-epoch-checkpoints', action='store_true',
                        help='only store last and best checkpoints')
     group.add_argument('--validate-interval', type=int, default=1, metavar='N',
@@ -55,19 +55,34 @@ def main(args):
         args.max_sentences,
     ))

+    # Initialize dataloader
+    train_dataloader = dataset.train_dataloader_generator(
+        args.train_subset,
+        max_tokens=args.max_tokens,
+        max_sentences=args.max_sentences,
+        max_positions=(
+            min(args.max_source_positions, trainer.get_model().max_encoder_positions()),
+            min(args.max_target_positions, trainer.get_model().max_decoder_positions())
+        ),
+        seed=args.seed,
+        sample_without_replacement=args.sample_without_replacement,
+        shard_id=args.distributed_rank,
+        num_shards=args.distributed_world_size,
+    )
+
     # Load the latest checkpoint if one is available
     os.makedirs(args.save_dir, exist_ok=True)
     checkpoint_path = os.path.join(args.save_dir, args.restore_file)
-    extra_state = trainer.load_checkpoint(checkpoint_path)
-    if extra_state is not None:
-        epoch = extra_state['epoch']
-        batch_offset = extra_state['batch_offset']
-        print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
-        if batch_offset == 0:
+    epoch = 1
+    if os.path.isfile(checkpoint_path):
+        extra_state = trainer.load_checkpoint(checkpoint_path)
+        if extra_state is not None:
+            epoch = extra_state['epoch']
+            print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
             trainer.lr_step(epoch)
+            for i in range(epoch):
+                _ = next(train_dataloader)
             epoch += 1
-    else:
-        epoch, batch_offset = 1, 0

     # Train until the learning rate gets too small
     max_epoch = args.max_epoch or math.inf
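Since batch order is now produced by the seeded generator, resuming from a checkpoint simply fast-forwards the generator past the (sub-)epochs that were already trained, instead of restoring a batch_offset. A toy sketch of that resume step, assuming a generator like the one sketched earlier (each next() corresponds to one completed (sub-)epoch):

    def resume(train_dataloader, saved_epoch):
        """Skip already-trained (sub-)epochs so the next next(train_dataloader)
        returns the same batches an uninterrupted run would have produced."""
        for _ in range(saved_epoch):
            _ = next(train_dataloader)  # discard consumed epochs
        return saved_epoch + 1          # training resumes at the following epoch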
@@ -77,24 +92,24 @@
     train_meter.start()
     while lr > args.min_lr and epoch <= max_epoch:
         # train for one epoch
-        train(args, trainer, dataset, epoch, batch_offset)
+        train(args, trainer, next(train_dataloader), epoch)

         # evaluate on validate set
+        first_val_loss = None
         if epoch % args.validate_interval == 0:
             for k, subset in enumerate(args.valid_subset.split(',')):
                 val_loss = validate(args, trainer, dataset, subset, epoch)
                 if k == 0:
-                    # only use first validation loss to update the learning schedule
-                    lr = trainer.lr_step(epoch, val_loss)
+                    first_val_loss = val_loss

-                    # save checkpoint
-                    if not args.no_save:
-                        save_checkpoint(trainer, args, epoch, 0, val_loss)
-        else:
-            lr = trainer.lr_step(epoch)
+        # only use first validation loss to update the learning rate
+        lr = trainer.lr_step(epoch, first_val_loss)
+
+        # save checkpoint
+        if not args.no_save and epoch % args.save_interval == 0:
+            save_checkpoint(trainer, args, epoch, first_val_loss)

         epoch += 1
-        batch_offset = 0

         if trainer.get_num_updates() >= max_update:
             break
@@ -103,7 +118,7 @@
     print('| done training in {:.1f} seconds'.format(train_meter.sum))


-def train(args, trainer, dataset, epoch, batch_offset):
+def train(args, trainer, itr, epoch):
    """Train the model for one epoch."""

    # Set seed based on args.seed and the epoch number so that we get
@@ -111,30 +126,6 @@ def train(args, trainer, dataset, epoch, batch_offset):
    seed = args.seed + epoch
    torch.manual_seed(seed)

-    # The max number of positions can be different for train and valid
-    # e.g., RNNs may support more positions at test time than seen in training
-    max_positions_train = (
-        min(args.max_source_positions, trainer.get_model().max_encoder_positions()),
-        min(args.max_target_positions, trainer.get_model().max_decoder_positions())
-    )
-
-    # Initialize dataloader, starting at batch_offset
-    itr = dataset.train_dataloader(
-        args.train_subset,
-        max_tokens=args.max_tokens,
-        max_sentences=args.max_sentences,
-        max_positions=max_positions_train,
-        seed=seed,
-        epoch=epoch,
-        sample_without_replacement=args.sample_without_replacement,
-        sort_by_source_size=(epoch <= args.curriculum),
-        shard_id=args.distributed_rank,
-        num_shards=args.distributed_world_size,
-    )
-    progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple')
-    epoch_size = len(itr)
-    itr = itertools.islice(progress, batch_offset, None)
-
    # reset training meters
    for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip']:
        meter = trainer.get_meter(k)
@@ -143,8 +134,10 @@ def train(args, trainer, dataset, epoch, batch_offset):
    extra_meters = collections.defaultdict(lambda: AverageMeter())
    max_update = args.max_update or math.inf

-    for i, sample in enumerate(itr, start=batch_offset):
-        if i < epoch_size - 1 and (i + 1) % args.update_freq > 0:
+    num_batches = len(itr)
+    progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple')
+    for i, sample in enumerate(progress):
+        if i < num_batches - 1 and (i + 1) % args.update_freq > 0:
            # buffer updates according to --update-freq
            trainer.train_step(sample, update_params=False)
            continue
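The --update-freq buffering above accumulates gradients over several consecutive batches and only applies a parameter update on the last batch of each group (or on the final batch of the epoch). A minimal sketch of the same condition, assuming a trainer whose train_step(sample, update_params=...) buffers or applies updates as shown in the diff:

    def train_one_epoch(trainer, batches, update_freq=1):
        num_batches = len(batches)
        for i, sample in enumerate(batches):
            # update on every update_freq-th batch and on the last batch of the epoch,
            # otherwise just accumulate gradients
            last_in_group = (i + 1) % update_freq == 0 or i == num_batches - 1
            trainer.train_step(sample, update_params=last_in_group)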
@@ -164,15 +157,10 @@ def train(args, trainer, dataset, epoch, batch_offset):
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
-        if i == batch_offset:
+        if i == 0:
            trainer.get_meter('wps').reset()

-        # save mid-epoch checkpoints
-        num_updates = trainer.get_num_updates()
-        if args.save_interval > 0 and num_updates > 0 and num_updates % args.save_interval == 0:
-            save_checkpoint(trainer, args, epoch, i + 1)
-
-        if num_updates >= max_update:
+        if trainer.get_num_updates() >= max_update:
            break

    # log end-of-epoch stats
@@ -274,28 +262,22 @@ def get_perplexity(loss):
        return float('inf')


-def save_checkpoint(trainer, args, epoch, batch_offset, val_loss=None):
+def save_checkpoint(trainer, args, epoch, val_loss=None):
    extra_state = {
        'epoch': epoch,
-        'batch_offset': batch_offset,
        'val_loss': val_loss,
    }

-    if batch_offset == 0:
-        if not args.no_epoch_checkpoints:
-            epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
-            trainer.save_checkpoint(epoch_filename, extra_state)
-
-        assert val_loss is not None
-        if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
-            save_checkpoint.best = val_loss
-            best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
-            trainer.save_checkpoint(best_filename, extra_state)
-    elif not args.no_epoch_checkpoints:
-        epoch_filename = os.path.join(
-            args.save_dir, 'checkpoint{}_{}.pt'.format(epoch, batch_offset))
+    if not args.no_epoch_checkpoints:
+        epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
        trainer.save_checkpoint(epoch_filename, extra_state)
+
+    assert val_loss is not None
+    if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
+        save_checkpoint.best = val_loss
+        best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
+        trainer.save_checkpoint(best_filename, extra_state)

    last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
    trainer.save_checkpoint(last_filename, extra_state)
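save_checkpoint now always writes the per-epoch, best, and last checkpoints at epoch boundaries, tracking the best validation loss as an attribute on the function itself. A small standalone illustration of that attribute-based bookkeeping (track_best is a hypothetical name, not part of this commit):

    def track_best(val_loss):
        """Return True when val_loss improves on the best value seen so far,
        mirroring the save_checkpoint.best bookkeeping above."""
        if not hasattr(track_best, 'best') or val_loss < track_best.best:
            track_best.best = val_loss
            return True
        return False

    print(track_best(4.2))  # True  (first value is always the best so far)
    print(track_best(4.5))  # False (worse than 4.2)
    print(track_best(3.9))  # True  (new best)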