Commit 3b2cecda authored by Naman Goyal's avatar Naman Goyal Committed by Facebook GitHub Bot
Browse files

1) Replaced f-strings with str.format calls 2) Fixed an error caused by the --max-positions argument

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/787

Differential Revision: D16562052

fbshipit-source-id: 640e30b2378ec917d60092558d3088a77f9741cb
parent e75cff5f
...@@ -41,8 +41,6 @@ class SentencePredictionTask(FairseqTask): ...@@ -41,8 +41,6 @@ class SentencePredictionTask(FairseqTask):
"""Add task-specific arguments to the parser.""" """Add task-specific arguments to the parser."""
parser.add_argument('data', metavar='FILE', parser.add_argument('data', metavar='FILE',
help='file prefix for data') help='file prefix for data')
parser.add_argument('--max-positions', type=int, default=512,
help='max input length')
parser.add_argument('--num-classes', type=int, default=-1, parser.add_argument('--num-classes', type=int, default=-1,
help='number of classes') help='number of classes')
parser.add_argument('--init-token', type=int, default=None, parser.add_argument('--init-token', type=int, default=None,
...@@ -160,7 +158,7 @@ class SentencePredictionTask(FairseqTask): ...@@ -160,7 +158,7 @@ class SentencePredictionTask(FairseqTask):
) )
) )
else: else:
label_path = f"{get_path('label', split)}.label" label_path = "{0}.label".format(get_path('label', split))
if os.path.exists(label_path): if os.path.exists(label_path):
dataset.update( dataset.update(
target=RawLabelDataset([ target=RawLabelDataset([
...@@ -182,7 +180,7 @@ class SentencePredictionTask(FairseqTask): ...@@ -182,7 +180,7 @@ class SentencePredictionTask(FairseqTask):
sort_order=[shuffle], sort_order=[shuffle],
) )
print(f"| Loaded {split} with #samples: {len(dataset)}") print("| Loaded {0} with #samples: {1}".format(split, len(dataset)))
self.datasets[split] = dataset self.datasets[split] = dataset
return self.datasets[split] return self.datasets[split]
......
...@@ -9,7 +9,6 @@ Train a new model on one or across multiple GPUs. ...@@ -9,7 +9,6 @@ Train a new model on one or across multiple GPUs.
import collections import collections
import math import math
import os
import random import random
import torch import torch
...@@ -258,7 +257,7 @@ def get_valid_stats(trainer, args, extra_meters=None): ...@@ -258,7 +257,7 @@ def get_valid_stats(trainer, args, extra_meters=None):
stats['ppl'] = utils.get_perplexity(nll_loss.avg) stats['ppl'] = utils.get_perplexity(nll_loss.avg)
stats['num_updates'] = trainer.get_num_updates() stats['num_updates'] = trainer.get_num_updates()
if hasattr(checkpoint_utils.save_checkpoint, 'best'): if hasattr(checkpoint_utils.save_checkpoint, 'best'):
key = f'best_{args.best_checkpoint_metric}' key = 'best_{0}'.format(args.best_checkpoint_metric)
best_function = max if args.maximize_best_checkpoint_metric else min best_function = max if args.maximize_best_checkpoint_metric else min
current_metric = None current_metric = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment