"vscode:/vscode.git/clone" did not exist on "1bbf8bdcebd762e9ee60ea8123cde57500867f40"
Commit e88ad84b authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Use --train-subset and --valid-subset properly

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/614

Differential Revision: D14712321

Pulled By: myleott

fbshipit-source-id: 8ef973c5d30ebccf0df0f1cabdddd590248a8f8d
parent 3efc39ee
......@@ -39,7 +39,7 @@ def main(args, init_distributed=False):
task = tasks.setup_task(args)
# Load dataset splits
load_dataset_splits(task, ['train', 'valid'])
load_dataset_splits(args, task)
# Initialize distributed training (after data loading)
if init_distributed:
......@@ -64,8 +64,8 @@ def main(args, init_distributed=False):
task.max_positions(),
model.max_positions(),
)
dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
oom_batch = task.dataset('train').get_dummy_batch(1, max_positions)
dummy_batch = task.dataset(args.train_subset).get_dummy_batch(args.max_tokens, max_positions)
oom_batch = task.dataset(args.train_subset).get_dummy_batch(1, max_positions)
# Build trainer
trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch)
......@@ -358,11 +358,9 @@ def load_checkpoint(args, trainer, epoch_itr):
return False
def load_dataset_splits(task, splits):
for split in splits:
if split == 'train':
task.load_dataset(split, combine=True)
else:
def load_dataset_splits(args, task):
task.load_dataset(args.train_subset, combine=True)
for split in args.valid_subset.split(','):
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment