Commit e88ad84b authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Use --train-subset and --valid-subset properly

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/614

Differential Revision: D14712321

Pulled By: myleott

fbshipit-source-id: 8ef973c5d30ebccf0df0f1cabdddd590248a8f8d
parent 3efc39ee
...@@ -39,7 +39,7 @@ def main(args, init_distributed=False): ...@@ -39,7 +39,7 @@ def main(args, init_distributed=False):
task = tasks.setup_task(args) task = tasks.setup_task(args)
# Load dataset splits # Load dataset splits
load_dataset_splits(task, ['train', 'valid']) load_dataset_splits(args, task)
# Initialize distributed training (after data loading) # Initialize distributed training (after data loading)
if init_distributed: if init_distributed:
...@@ -64,8 +64,8 @@ def main(args, init_distributed=False): ...@@ -64,8 +64,8 @@ def main(args, init_distributed=False):
task.max_positions(), task.max_positions(),
model.max_positions(), model.max_positions(),
) )
dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions) dummy_batch = task.dataset(args.train_subset).get_dummy_batch(args.max_tokens, max_positions)
oom_batch = task.dataset('train').get_dummy_batch(1, max_positions) oom_batch = task.dataset(args.train_subset).get_dummy_batch(1, max_positions)
# Build trainer # Build trainer
trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch) trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch)
...@@ -358,19 +358,17 @@ def load_checkpoint(args, trainer, epoch_itr): ...@@ -358,19 +358,17 @@ def load_checkpoint(args, trainer, epoch_itr):
return False return False
def load_dataset_splits(args, task):
    """Load the training split and all validation splits into *task*.

    The training split named by ``args.train_subset`` is loaded with
    ``combine=True`` (all shards merged into one dataset).  Each
    comma-separated name in ``args.valid_subset`` is loaded with
    ``combine=False``, together with any numbered shard variants
    (``valid``, ``valid1``, ``valid2``, ...) that exist on disk.

    Args:
        args: parsed command-line namespace; must provide
            ``train_subset`` (str) and ``valid_subset``
            (comma-separated str).
        task: a fairseq task exposing ``load_dataset(split, combine=...)``.

    Raises:
        FileNotFoundError: if the base split of any validation subset
            (the ``k == 0`` variant) cannot be found.  Missing numbered
            shards (``k > 0``) simply terminate the shard scan.
    """
    task.load_dataset(args.train_subset, combine=True)
    for split in args.valid_subset.split(','):
        # Probe 'split', 'split1', 'split2', ... until a shard is missing.
        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            try:
                task.load_dataset(split_k, combine=False)
            except FileNotFoundError:
                if k > 0:
                    # Ran off the end of the shard sequence: done with
                    # this subset.
                    break
                # The base split itself is missing -- that is a real
                # error.  Bare `raise` re-raises with the original
                # traceback (clearer than `raise e`).
                raise
def distributed_main(i, args): def distributed_main(i, args):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment