Commit 82a9f923 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Fix arg formatting in preprocess.py and add fmt control for black formatting (#399)

Summary:
Not switching to Black formatting just yet, but adding fmt: off directives in case we decide to later.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/399

Differential Revision: D13364674

Pulled By: myleott

fbshipit-source-id: a20a11a18be3d583ee30eff770278fb4bd05b93c
parent 0693c351
...@@ -19,8 +19,10 @@ class CompositeLoss(FairseqCriterion): ...@@ -19,8 +19,10 @@ class CompositeLoss(FairseqCriterion):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add criterion-specific arguments to the parser.""" """Add criterion-specific arguments to the parser."""
# fmt: off
parser.add_argument('--underlying-criterion', type=str, metavar='VAL', required=True, parser.add_argument('--underlying-criterion', type=str, metavar='VAL', required=True,
help='underlying criterion to use for the composite loss') help='underlying criterion to use for the composite loss')
# fmt: on
def __init__(self, args, task): def __init__(self, args, task):
super().__init__(args, task) super().__init__(args, task)
......
...@@ -22,8 +22,10 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): ...@@ -22,8 +22,10 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add criterion-specific arguments to the parser.""" """Add criterion-specific arguments to the parser."""
# fmt: off
parser.add_argument('--label-smoothing', default=0., type=float, metavar='D', parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
help='epsilon for label smoothing, 0 means no label smoothing') help='epsilon for label smoothing, 0 means no label smoothing')
# fmt: on
def forward(self, model, sample, reduce=True): def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample. """Compute the loss for the given sample.
......
...@@ -48,6 +48,7 @@ class FConvModel(FairseqModel): ...@@ -48,6 +48,7 @@ class FConvModel(FairseqModel):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add model-specific arguments to the parser.""" """Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--dropout', type=float, metavar='D', parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability') help='dropout probability')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N', parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
...@@ -70,6 +71,7 @@ class FConvModel(FairseqModel): ...@@ -70,6 +71,7 @@ class FConvModel(FairseqModel):
help='share input and output embeddings (requires' help='share input and output embeddings (requires'
' --decoder-out-embed-dim and --decoder-embed-dim' ' --decoder-out-embed-dim and --decoder-embed-dim'
' to be equal)') ' to be equal)')
# fmt: on
@classmethod @classmethod
def build_model(cls, args, task): def build_model(cls, args, task):
......
...@@ -41,6 +41,7 @@ class FConvModelSelfAtt(FairseqModel): ...@@ -41,6 +41,7 @@ class FConvModelSelfAtt(FairseqModel):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add model-specific arguments to the parser.""" """Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--dropout', type=float, metavar='D', parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability') help='dropout probability')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N', parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
...@@ -75,6 +76,7 @@ class FConvModelSelfAtt(FairseqModel): ...@@ -75,6 +76,7 @@ class FConvModelSelfAtt(FairseqModel):
help='path to load checkpoint from pretrained model') help='path to load checkpoint from pretrained model')
parser.add_argument('--pretrained', type=str, metavar='EXPR', parser.add_argument('--pretrained', type=str, metavar='EXPR',
help='use pretrained model when training [True, ...]') help='use pretrained model when training [True, ...]')
# fmt: on
@classmethod @classmethod
def build_model(cls, args, task): def build_model(cls, args, task):
......
...@@ -25,6 +25,7 @@ class LSTMModel(FairseqModel): ...@@ -25,6 +25,7 @@ class LSTMModel(FairseqModel):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add model-specific arguments to the parser.""" """Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--dropout', type=float, metavar='D', parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability') help='dropout probability')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N', parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
...@@ -68,6 +69,7 @@ class LSTMModel(FairseqModel): ...@@ -68,6 +69,7 @@ class LSTMModel(FairseqModel):
parser.add_argument('--share-all-embeddings', default=False, action='store_true', parser.add_argument('--share-all-embeddings', default=False, action='store_true',
help='share encoder, decoder and output embeddings' help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)') ' (requires shared dictionary and embed dim)')
# fmt: on
@classmethod @classmethod
def build_model(cls, args, task): def build_model(cls, args, task):
......
...@@ -49,6 +49,7 @@ class TransformerModel(FairseqModel): ...@@ -49,6 +49,7 @@ class TransformerModel(FairseqModel):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add model-specific arguments to the parser.""" """Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--dropout', type=float, metavar='D', parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability') help='dropout probability')
parser.add_argument('--attention-dropout', type=float, metavar='D', parser.add_argument('--attention-dropout', type=float, metavar='D',
...@@ -93,6 +94,7 @@ class TransformerModel(FairseqModel): ...@@ -93,6 +94,7 @@ class TransformerModel(FairseqModel):
'Must be used with adaptive_loss criterion'), 'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D', parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections') help='sets adaptive softmax dropout for the tail projections')
# fmt: on
@classmethod @classmethod
def build_model(cls, args, task): def build_model(cls, args, task):
...@@ -153,6 +155,7 @@ class TransformerLanguageModel(FairseqLanguageModel): ...@@ -153,6 +155,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add model-specific arguments to the parser.""" """Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--dropout', default=0.1, type=float, metavar='D', parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
help='dropout probability') help='dropout probability')
parser.add_argument('--attention-dropout', default=0., type=float, metavar='D', parser.add_argument('--attention-dropout', default=0., type=float, metavar='D',
...@@ -205,6 +208,7 @@ class TransformerLanguageModel(FairseqLanguageModel): ...@@ -205,6 +208,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
help='if set, ties the projection weights of adaptive softmax and adaptive input') help='if set, ties the projection weights of adaptive softmax and adaptive input')
parser.add_argument('--decoder-learned-pos', action='store_true', parser.add_argument('--decoder-learned-pos', action='store_true',
help='use learned positional embeddings in the decoder') help='use learned positional embeddings in the decoder')
# fmt: on
@classmethod @classmethod
def build_model(cls, args, task): def build_model(cls, args, task):
......
...@@ -21,10 +21,12 @@ class FairseqAdam(FairseqOptimizer): ...@@ -21,10 +21,12 @@ class FairseqAdam(FairseqOptimizer):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add optimizer-specific arguments to the parser.""" """Add optimizer-specific arguments to the parser."""
# fmt: off
parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B', parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B',
help='betas for Adam optimizer') help='betas for Adam optimizer')
parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D', parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D',
help='epsilon for Adam optimizer') help='epsilon for Adam optimizer')
# fmt: on
@property @property
def optimizer_config(self): def optimizer_config(self):
......
...@@ -63,6 +63,7 @@ class CosineSchedule(FairseqLRScheduler): ...@@ -63,6 +63,7 @@ class CosineSchedule(FairseqLRScheduler):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add arguments to the parser for this LR scheduler.""" """Add arguments to the parser for this LR scheduler."""
# fmt: off
parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
help='warmup the learning rate linearly for the first N updates') help='warmup the learning rate linearly for the first N updates')
parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
...@@ -73,6 +74,7 @@ class CosineSchedule(FairseqLRScheduler): ...@@ -73,6 +74,7 @@ class CosineSchedule(FairseqLRScheduler):
help='factor to grow the length of each period') help='factor to grow the length of each period')
parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR', parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR',
help='initial number of updates per period') help='initial number of updates per period')
# fmt: on
def step(self, epoch, val_loss=None): def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch.""" """Update the learning rate at the end of the given epoch."""
......
...@@ -27,10 +27,12 @@ class FixedSchedule(FairseqLRScheduler): ...@@ -27,10 +27,12 @@ class FixedSchedule(FairseqLRScheduler):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add arguments to the parser for this LR scheduler.""" """Add arguments to the parser for this LR scheduler."""
# fmt: off
parser.add_argument('--force-anneal', '--fa', type=int, metavar='N', parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
help='force annealing at specified epoch') help='force annealing at specified epoch')
parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
help='warmup the learning rate linearly for the first N updates') help='warmup the learning rate linearly for the first N updates')
# fmt: on
def get_next_lr(self, epoch): def get_next_lr(self, epoch):
lrs = self.args.lr lrs = self.args.lr
......
...@@ -55,10 +55,12 @@ class InverseSquareRootSchedule(FairseqLRScheduler): ...@@ -55,10 +55,12 @@ class InverseSquareRootSchedule(FairseqLRScheduler):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add arguments to the parser for this LR scheduler.""" """Add arguments to the parser for this LR scheduler."""
# fmt: off
parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N', parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N',
help='warmup the learning rate linearly for the first N updates') help='warmup the learning rate linearly for the first N updates')
parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
help='initial learning rate during warmup phase; default is args.lr') help='initial learning rate during warmup phase; default is args.lr')
# fmt: on
def step(self, epoch, val_loss=None): def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch.""" """Update the learning rate at the end of the given epoch."""
......
...@@ -42,12 +42,14 @@ class TriangularSchedule(FairseqLRScheduler): ...@@ -42,12 +42,14 @@ class TriangularSchedule(FairseqLRScheduler):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add arguments to the parser for this LR scheduler.""" """Add arguments to the parser for this LR scheduler."""
# fmt: off
parser.add_argument('--max-lr', required=True, type=float, metavar='LR', parser.add_argument('--max-lr', required=True, type=float, metavar='LR',
help='max learning rate, must be more than args.lr') help='max learning rate, must be more than args.lr')
parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR', parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR',
help='initial number of updates per period (cycle length)') help='initial number of updates per period (cycle length)')
parser.add_argument('--shrink-min', action='store_true', parser.add_argument('--shrink-min', action='store_true',
help='if set, also shrinks min lr') help='if set, also shrinks min lr')
# fmt: on
def step(self, epoch, val_loss=None): def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch.""" """Update the learning rate at the end of the given epoch."""
......
...@@ -120,6 +120,7 @@ def parse_args_and_arch(parser, input_args=None, parse_known=False): ...@@ -120,6 +120,7 @@ def parse_args_and_arch(parser, input_args=None, parse_known=False):
def get_parser(desc, default_task='translation'): def get_parser(desc, default_task='translation'):
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# fmt: off
parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar') parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
parser.add_argument('--log-interval', type=int, default=1000, metavar='N', parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
help='log progress every N batches (when progress bar is disabled)') help='log progress every N batches (when progress bar is disabled)')
...@@ -134,17 +135,16 @@ def get_parser(desc, default_task='translation'): ...@@ -134,17 +135,16 @@ def get_parser(desc, default_task='translation'):
help='number of updates before increasing loss scale') help='number of updates before increasing loss scale')
# Task definitions can be found under fairseq/tasks/ # Task definitions can be found under fairseq/tasks/
parser.add_argument( parser.add_argument('--task', metavar='TASK', default=default_task,
'--task', metavar='TASK', default=default_task,
choices=TASK_REGISTRY.keys(), choices=TASK_REGISTRY.keys(),
help='task', help='task')
) # fmt: on
return parser return parser
def add_dataset_args(parser, train=False, gen=False): def add_dataset_args(parser, train=False, gen=False):
group = parser.add_argument_group('Dataset and data loading') group = parser.add_argument_group('Dataset and data loading')
# fmt: off
group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true', group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
help='ignore too long or too short lines in valid and test set') help='ignore too long or too short lines in valid and test set')
group.add_argument('--max-tokens', type=int, metavar='N', group.add_argument('--max-tokens', type=int, metavar='N',
...@@ -168,11 +168,13 @@ def add_dataset_args(parser, train=False, gen=False): ...@@ -168,11 +168,13 @@ def add_dataset_args(parser, train=False, gen=False):
help='shard generation over N shards') help='shard generation over N shards')
group.add_argument('--shard-id', default=0, type=int, metavar='ID', group.add_argument('--shard-id', default=0, type=int, metavar='ID',
help='id of the shard to generate (id < num_shards)') help='id of the shard to generate (id < num_shards)')
# fmt: on
return group return group
def add_distributed_training_args(parser): def add_distributed_training_args(parser):
group = parser.add_argument_group('Distributed training') group = parser.add_argument_group('Distributed training')
# fmt: off
group.add_argument('--distributed-world-size', type=int, metavar='N', group.add_argument('--distributed-world-size', type=int, metavar='N',
default=torch.cuda.device_count(), default=torch.cuda.device_count(),
help='total number of GPUs across all nodes (default: all visible GPUs)') help='total number of GPUs across all nodes (default: all visible GPUs)')
...@@ -196,11 +198,13 @@ def add_distributed_training_args(parser): ...@@ -196,11 +198,13 @@ def add_distributed_training_args(parser):
help='Don\'t shuffle batches between GPUs, this reduces overall ' help='Don\'t shuffle batches between GPUs, this reduces overall '
'randomness and may affect precision but avoids the cost of' 'randomness and may affect precision but avoids the cost of'
're-reading the data') 're-reading the data')
# fmt: on
return group return group
def add_optimization_args(parser): def add_optimization_args(parser):
group = parser.add_argument_group('Optimization') group = parser.add_argument_group('Optimization')
# fmt: off
group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N', group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N',
help='force stop training at specified epoch') help='force stop training at specified epoch')
group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N', group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N',
...@@ -235,12 +239,13 @@ def add_optimization_args(parser): ...@@ -235,12 +239,13 @@ def add_optimization_args(parser):
help='minimum learning rate') help='minimum learning rate')
group.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D', group.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D',
help='minimum loss scale (for FP16 training)') help='minimum loss scale (for FP16 training)')
# fmt: on
return group return group
def add_checkpoint_args(parser): def add_checkpoint_args(parser):
group = parser.add_argument_group('Checkpointing') group = parser.add_argument_group('Checkpointing')
# fmt: off
group.add_argument('--save-dir', metavar='DIR', default='checkpoints', group.add_argument('--save-dir', metavar='DIR', default='checkpoints',
help='path to save checkpoints') help='path to save checkpoints')
group.add_argument('--restore-file', default='checkpoint_last.pt', group.add_argument('--restore-file', default='checkpoint_last.pt',
...@@ -263,10 +268,12 @@ def add_checkpoint_args(parser): ...@@ -263,10 +268,12 @@ def add_checkpoint_args(parser):
help='only store last and best checkpoints') help='only store last and best checkpoints')
group.add_argument('--validate-interval', type=int, default=1, metavar='N', group.add_argument('--validate-interval', type=int, default=1, metavar='N',
help='validate every N epochs') help='validate every N epochs')
# fmt: on
return group return group
def add_common_eval_args(group): def add_common_eval_args(group):
# fmt: off
group.add_argument('--path', metavar='FILE', group.add_argument('--path', metavar='FILE',
help='path(s) to model file(s), colon separated') help='path(s) to model file(s), colon separated')
group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None, group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
...@@ -276,20 +283,24 @@ def add_common_eval_args(group): ...@@ -276,20 +283,24 @@ def add_common_eval_args(group):
help='only print final scores') help='only print final scores')
group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT', group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
help='a dictionary used to override model args at generation that were used during model training') help='a dictionary used to override model args at generation that were used during model training')
# fmt: on
def add_eval_lm_args(parser): def add_eval_lm_args(parser):
group = parser.add_argument_group('LM Evaluation') group = parser.add_argument_group('LM Evaluation')
add_common_eval_args(group) add_common_eval_args(group)
# fmt: off
group.add_argument('--output-word-probs', action='store_true', group.add_argument('--output-word-probs', action='store_true',
help='if set, outputs words and their predicted log probabilities to standard output') help='if set, outputs words and their predicted log probabilities to standard output')
group.add_argument('--output-word-stats', action='store_true', group.add_argument('--output-word-stats', action='store_true',
help='if set, outputs word statistics such as word count, average probability, etc') help='if set, outputs word statistics such as word count, average probability, etc')
# fmt: on
def add_generation_args(parser): def add_generation_args(parser):
group = parser.add_argument_group('Generation') group = parser.add_argument_group('Generation')
add_common_eval_args(group) add_common_eval_args(group)
# fmt: off
group.add_argument('--beam', default=5, type=int, metavar='N', group.add_argument('--beam', default=5, type=int, metavar='N',
help='beam size') help='beam size')
group.add_argument('--nbest', default=1, type=int, metavar='N', group.add_argument('--nbest', default=1, type=int, metavar='N',
...@@ -332,17 +343,21 @@ def add_generation_args(parser): ...@@ -332,17 +343,21 @@ def add_generation_args(parser):
help='strength of diversity penalty for Diverse Beam Search') help='strength of diversity penalty for Diverse Beam Search')
group.add_argument('--print-alignment', action='store_true', group.add_argument('--print-alignment', action='store_true',
help='if set, uses attention feedback to compute and print alignment to source tokens') help='if set, uses attention feedback to compute and print alignment to source tokens')
# fmt: on
return group return group
def add_interactive_args(parser): def add_interactive_args(parser):
group = parser.add_argument_group('Interactive') group = parser.add_argument_group('Interactive')
# fmt: off
group.add_argument('--buffer-size', default=0, type=int, metavar='N', group.add_argument('--buffer-size', default=0, type=int, metavar='N',
help='read this many sentences into a buffer before processing them') help='read this many sentences into a buffer before processing them')
# fmt: on
def add_model_args(parser): def add_model_args(parser):
group = parser.add_argument_group('Model configuration') group = parser.add_argument_group('Model configuration')
# fmt: off
# Model definitions can be found under fairseq/models/ # Model definitions can be found under fairseq/models/
# #
...@@ -351,17 +366,13 @@ def add_model_args(parser): ...@@ -351,17 +366,13 @@ def add_model_args(parser):
# 1) model defaults (lowest priority) # 1) model defaults (lowest priority)
# 2) --arch argument # 2) --arch argument
# 3) --encoder/decoder-* arguments (highest priority) # 3) --encoder/decoder-* arguments (highest priority)
group.add_argument( group.add_argument('--arch', '-a', default='fconv', metavar='ARCH', required=True,
'--arch', '-a', default='fconv', metavar='ARCH', required=True,
choices=ARCH_MODEL_REGISTRY.keys(), choices=ARCH_MODEL_REGISTRY.keys(),
help='Model Architecture', help='Model Architecture')
)
# Criterion definitions can be found under fairseq/criterions/ # Criterion definitions can be found under fairseq/criterions/
group.add_argument( group.add_argument('--criterion', default='cross_entropy', metavar='CRIT',
'--criterion', default='cross_entropy', metavar='CRIT',
choices=CRITERION_REGISTRY.keys(), choices=CRITERION_REGISTRY.keys(),
help='Training Criterion', help='Training Criterion')
) # fmt: on
return group return group
...@@ -65,10 +65,10 @@ for file in os.listdir(os.path.dirname(__file__)): ...@@ -65,10 +65,10 @@ for file in os.listdir(os.path.dirname(__file__)):
if task_name in TASK_REGISTRY: if task_name in TASK_REGISTRY:
parser = argparse.ArgumentParser(add_help=False) parser = argparse.ArgumentParser(add_help=False)
group_task = parser.add_argument_group('Task name') group_task = parser.add_argument_group('Task name')
group_task.add_argument( # fmt: off
'--task', metavar=task_name, group_task.add_argument('--task', metavar=task_name,
help='Enable this task with: ``--task=' + task_name + '``' help='Enable this task with: ``--task=' + task_name + '``')
) # fmt: on
group_args = parser.add_argument_group('Additional command-line arguments') group_args = parser.add_argument_group('Additional command-line arguments')
TASK_REGISTRY[task_name].add_args(group_args) TASK_REGISTRY[task_name].add_args(group_args)
globals()[task_name + '_parser'] = parser globals()[task_name + '_parser'] = parser
...@@ -52,6 +52,7 @@ class LanguageModelingTask(FairseqTask): ...@@ -52,6 +52,7 @@ class LanguageModelingTask(FairseqTask):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add task-specific arguments to the parser.""" """Add task-specific arguments to the parser."""
# fmt: off
parser.add_argument('data', help='path to data directory') parser.add_argument('data', help='path to data directory')
parser.add_argument('--sample-break-mode', parser.add_argument('--sample-break-mode',
choices=['none', 'complete', 'eos'], choices=['none', 'complete', 'eos'],
...@@ -71,6 +72,7 @@ class LanguageModelingTask(FairseqTask): ...@@ -71,6 +72,7 @@ class LanguageModelingTask(FairseqTask):
help='include future target') help='include future target')
parser.add_argument('--past-target', action='store_true', parser.add_argument('--past-target', action='store_true',
help='include past target') help='include past target')
# fmt: on
def __init__(self, args, dictionary, output_dictionary, targets=None): def __init__(self, args, dictionary, output_dictionary, targets=None):
super().__init__(args) super().__init__(args)
......
...@@ -47,6 +47,7 @@ class MultilingualTranslationTask(FairseqTask): ...@@ -47,6 +47,7 @@ class MultilingualTranslationTask(FairseqTask):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add task-specific arguments to the parser.""" """Add task-specific arguments to the parser."""
# fmt: off
parser.add_argument('data', metavar='DIR', help='path to data directory') parser.add_argument('data', metavar='DIR', help='path to data directory')
parser.add_argument('--lang-pairs', default=None, metavar='PAIRS', parser.add_argument('--lang-pairs', default=None, metavar='PAIRS',
help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr') help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr')
...@@ -64,6 +65,7 @@ class MultilingualTranslationTask(FairseqTask): ...@@ -64,6 +65,7 @@ class MultilingualTranslationTask(FairseqTask):
help='max number of tokens in the source sequence') help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N', parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence') help='max number of tokens in the target sequence')
# fmt: on
def __init__(self, args, dicts, training): def __init__(self, args, dicts, training):
super().__init__(args) super().__init__(args)
......
...@@ -43,6 +43,7 @@ class TranslationTask(FairseqTask): ...@@ -43,6 +43,7 @@ class TranslationTask(FairseqTask):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add task-specific arguments to the parser.""" """Add task-specific arguments to the parser."""
# fmt: off
parser.add_argument('data', nargs='+', help='path(s) to data directorie(s)') parser.add_argument('data', nargs='+', help='path(s) to data directorie(s)')
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
help='source language') help='source language')
...@@ -60,6 +61,7 @@ class TranslationTask(FairseqTask): ...@@ -60,6 +61,7 @@ class TranslationTask(FairseqTask):
help='max number of tokens in the target sequence') help='max number of tokens in the target sequence')
parser.add_argument('--upsample-primary', default=1, type=int, parser.add_argument('--upsample-primary', default=1, type=int,
help='amount to upsample primary dataset') help='amount to upsample primary dataset')
# fmt: on
def __init__(self, args, src_dict, tgt_dict): def __init__(self, args, src_dict, tgt_dict):
super().__init__(args) super().__init__(args)
......
...@@ -23,89 +23,45 @@ from multiprocessing import Pool, Manager, Process ...@@ -23,89 +23,45 @@ from multiprocessing import Pool, Manager, Process
def get_parser(): def get_parser():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( # fmt: off
"-s", "--source-lang", default=None, metavar="SRC", help="source language" parser.add_argument("-s", "--source-lang", default=None, metavar="SRC",
) help="source language")
parser.add_argument( parser.add_argument("-t", "--target-lang", default=None, metavar="TARGET",
"-t", "--target-lang", default=None, metavar="TARGET", help="target language" help="target language")
) parser.add_argument("--trainpref", metavar="FP", default=None,
parser.add_argument( help="train file prefix")
"--trainpref", metavar="FP", default=None, help="train file prefix" parser.add_argument("--validpref", metavar="FP", default=None,
) help="comma separated, valid file prefixes")
parser.add_argument( parser.add_argument("--testpref", metavar="FP", default=None,
"--validpref", help="comma separated, test file prefixes")
metavar="FP", parser.add_argument("--destdir", metavar="DIR", default="data-bin",
default=None, help="destination dir")
help="comma separated, valid file prefixes", parser.add_argument("--thresholdtgt", metavar="N", default=0, type=int,
) help="map words appearing less than threshold times to unknown")
parser.add_argument( parser.add_argument("--thresholdsrc", metavar="N", default=0, type=int,
"--testpref", help="map words appearing less than threshold times to unknown")
metavar="FP", parser.add_argument("--tgtdict", metavar="FP",
default=None, help="reuse given target dictionary")
help="comma separated, test file prefixes", parser.add_argument("--srcdict", metavar="FP",
) help="reuse given source dictionary")
parser.add_argument( parser.add_argument("--nwordstgt", metavar="N", default=-1, type=int,
"--destdir", metavar="DIR", default="data-bin", help="destination dir" help="number of target words to retain")
) parser.add_argument("--nwordssrc", metavar="N", default=-1, type=int,
parser.add_argument( help="number of source words to retain")
"--thresholdtgt", parser.add_argument("--alignfile", metavar="ALIGN", default=None,
metavar="N", help="an alignment file (optional)")
default=0, parser.add_argument("--output-format", metavar="FORMAT", default="binary",
type=int,
help="map words appearing less than threshold times to unknown",
)
parser.add_argument(
"--thresholdsrc",
metavar="N",
default=0,
type=int,
help="map words appearing less than threshold times to unknown",
)
parser.add_argument("--tgtdict", metavar="FP", help="reuse given target dictionary")
parser.add_argument("--srcdict", metavar="FP", help="reuse given source dictionary")
parser.add_argument(
"--nwordstgt",
metavar="N",
default=-1,
type=int,
help="number of target words to retain",
)
parser.add_argument(
"--nwordssrc",
metavar="N",
default=-1,
type=int,
help="number of source words to retain",
)
parser.add_argument(
"--alignfile",
metavar="ALIGN",
default=None,
help="an alignment file (optional)",
)
parser.add_argument(
"--output-format",
metavar="FORMAT",
default="binary",
choices=["binary", "raw"], choices=["binary", "raw"],
help="output format (optional)", help="output format (optional)")
) parser.add_argument("--joined-dictionary", action="store_true",
parser.add_argument( help="Generate joined dictionary")
"--joined-dictionary", action="store_true", help="Generate joined dictionary" parser.add_argument("--only-source", action="store_true",
) help="Only process the source language")
parser.add_argument( parser.add_argument("--padding-factor", metavar="N", default=8, type=int,
"--only-source", action="store_true", help="Only process the source language" help="Pad dictionary size to be multiple of N")
) parser.add_argument("--workers", metavar="N", default=1, type=int,
parser.add_argument( help="number of parallel workers")
"--padding-factor", # fmt: on
metavar="N",
default=8,
type=int,
help="Pad dictionary size to be multiple of N",
)
parser.add_argument(
"--workers", metavar="N", default=1, type=int, help="number of parallel workers"
)
return parser return parser
......
...@@ -19,12 +19,14 @@ from fairseq.data import dictionary ...@@ -19,12 +19,14 @@ from fairseq.data import dictionary
def get_parser(): def get_parser():
parser = argparse.ArgumentParser(description='Command-line script for BLEU scoring.') parser = argparse.ArgumentParser(description='Command-line script for BLEU scoring.')
# fmt: off
parser.add_argument('-s', '--sys', default='-', help='system output') parser.add_argument('-s', '--sys', default='-', help='system output')
parser.add_argument('-r', '--ref', required=True, help='references') parser.add_argument('-r', '--ref', required=True, help='references')
parser.add_argument('-o', '--order', default=4, metavar='N', parser.add_argument('-o', '--order', default=4, metavar='N',
type=int, help='consider ngrams up to this order') type=int, help='consider ngrams up to this order')
parser.add_argument('--ignore-case', action='store_true', parser.add_argument('--ignore-case', action='store_true',
help='case-insensitive scoring') help='case-insensitive scoring')
# fmt: on
return parser return parser
......
...@@ -86,33 +86,19 @@ def main(): ...@@ -86,33 +86,19 @@ def main():
description='Tool to average the params of input checkpoints to ' description='Tool to average the params of input checkpoints to '
'produce a new checkpoint', 'produce a new checkpoint',
) )
# fmt: off
parser.add_argument( parser.add_argument('--inputs', required=True, nargs='+',
'--inputs', help='Input checkpoint file paths.')
required=True, parser.add_argument('--output', required=True, metavar='FILE',
nargs='+', help='Write the new checkpoint containing the averaged weights to this path.')
help='Input checkpoint file paths.',
)
parser.add_argument(
'--output',
required=True,
metavar='FILE',
help='Write the new checkpoint containing the averaged weights to this '
'path.',
)
num_group = parser.add_mutually_exclusive_group() num_group = parser.add_mutually_exclusive_group()
num_group.add_argument( num_group.add_argument('--num-epoch-checkpoints', type=int,
'--num-epoch-checkpoints',
type=int,
help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, ' help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
'and average last this many of them.', 'and average last this many of them.')
) num_group.add_argument('--num-update-checkpoints', type=int,
num_group.add_argument(
'--num-update-checkpoints',
type=int,
help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, ' help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
'and average last this many of them.', 'and average last this many of them.')
) # fmt: on
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
......
...@@ -32,6 +32,7 @@ from itertools import zip_longest ...@@ -32,6 +32,7 @@ from itertools import zip_longest
def main(): def main():
parser = argparse.ArgumentParser(description='symmetric alignment builer') parser = argparse.ArgumentParser(description='symmetric alignment builer')
# fmt: off
parser.add_argument('--fast_align_dir', parser.add_argument('--fast_align_dir',
help='path to fast_align build directory') help='path to fast_align build directory')
parser.add_argument('--mosesdecoder_dir', parser.add_argument('--mosesdecoder_dir',
...@@ -47,6 +48,7 @@ def main(): ...@@ -47,6 +48,7 @@ def main():
'in the target language') 'in the target language')
parser.add_argument('--output_dir', parser.add_argument('--output_dir',
help='output directory') help='output directory')
# fmt: on
args = parser.parse_args() args = parser.parse_args()
fast_align_bin = os.path.join(args.fast_align_dir, 'fast_align') fast_align_bin = os.path.join(args.fast_align_dir, 'fast_align')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment