"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "e92d43feb08e5b0a2b77d30245a33aeb2567f252"
Commit 82a9f923 authored by Myle Ott, committed by Facebook Github Bot
Browse files

Fix arg formatting in preprocess.py and add fmt control for black formatting (#399)

Summary:
Not switching to Black formatting just yet, but adding `# fmt: off` / `# fmt: on` directives around the argument definitions in case we decide to switch later.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/399

Differential Revision: D13364674

Pulled By: myleott

fbshipit-source-id: a20a11a18be3d583ee30eff770278fb4bd05b93c
parent 0693c351
...@@ -19,8 +19,10 @@ class CompositeLoss(FairseqCriterion): ...@@ -19,8 +19,10 @@ class CompositeLoss(FairseqCriterion):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add criterion-specific arguments to the parser.""" """Add criterion-specific arguments to the parser."""
# fmt: off
parser.add_argument('--underlying-criterion', type=str, metavar='VAL', required=True, parser.add_argument('--underlying-criterion', type=str, metavar='VAL', required=True,
help='underlying criterion to use for the composite loss') help='underlying criterion to use for the composite loss')
# fmt: on
def __init__(self, args, task): def __init__(self, args, task):
super().__init__(args, task) super().__init__(args, task)
......
...@@ -22,8 +22,10 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): ...@@ -22,8 +22,10 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add criterion-specific arguments to the parser.""" """Add criterion-specific arguments to the parser."""
# fmt: off
parser.add_argument('--label-smoothing', default=0., type=float, metavar='D', parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
help='epsilon for label smoothing, 0 means no label smoothing') help='epsilon for label smoothing, 0 means no label smoothing')
# fmt: on
def forward(self, model, sample, reduce=True): def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample. """Compute the loss for the given sample.
......
...@@ -48,6 +48,7 @@ class FConvModel(FairseqModel): ...@@ -48,6 +48,7 @@ class FConvModel(FairseqModel):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add model-specific arguments to the parser.""" """Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--dropout', type=float, metavar='D', parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability') help='dropout probability')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N', parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
...@@ -70,6 +71,7 @@ class FConvModel(FairseqModel): ...@@ -70,6 +71,7 @@ class FConvModel(FairseqModel):
help='share input and output embeddings (requires' help='share input and output embeddings (requires'
' --decoder-out-embed-dim and --decoder-embed-dim' ' --decoder-out-embed-dim and --decoder-embed-dim'
' to be equal)') ' to be equal)')
# fmt: on
@classmethod @classmethod
def build_model(cls, args, task): def build_model(cls, args, task):
......
...@@ -41,6 +41,7 @@ class FConvModelSelfAtt(FairseqModel): ...@@ -41,6 +41,7 @@ class FConvModelSelfAtt(FairseqModel):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add model-specific arguments to the parser.""" """Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--dropout', type=float, metavar='D', parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability') help='dropout probability')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N', parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
...@@ -75,6 +76,7 @@ class FConvModelSelfAtt(FairseqModel): ...@@ -75,6 +76,7 @@ class FConvModelSelfAtt(FairseqModel):
help='path to load checkpoint from pretrained model') help='path to load checkpoint from pretrained model')
parser.add_argument('--pretrained', type=str, metavar='EXPR', parser.add_argument('--pretrained', type=str, metavar='EXPR',
help='use pretrained model when training [True, ...]') help='use pretrained model when training [True, ...]')
# fmt: on
@classmethod @classmethod
def build_model(cls, args, task): def build_model(cls, args, task):
......
...@@ -25,6 +25,7 @@ class LSTMModel(FairseqModel): ...@@ -25,6 +25,7 @@ class LSTMModel(FairseqModel):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add model-specific arguments to the parser.""" """Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--dropout', type=float, metavar='D', parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability') help='dropout probability')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N', parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
...@@ -68,6 +69,7 @@ class LSTMModel(FairseqModel): ...@@ -68,6 +69,7 @@ class LSTMModel(FairseqModel):
parser.add_argument('--share-all-embeddings', default=False, action='store_true', parser.add_argument('--share-all-embeddings', default=False, action='store_true',
help='share encoder, decoder and output embeddings' help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)') ' (requires shared dictionary and embed dim)')
# fmt: on
@classmethod @classmethod
def build_model(cls, args, task): def build_model(cls, args, task):
......
...@@ -49,6 +49,7 @@ class TransformerModel(FairseqModel): ...@@ -49,6 +49,7 @@ class TransformerModel(FairseqModel):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add model-specific arguments to the parser.""" """Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--dropout', type=float, metavar='D', parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability') help='dropout probability')
parser.add_argument('--attention-dropout', type=float, metavar='D', parser.add_argument('--attention-dropout', type=float, metavar='D',
...@@ -93,6 +94,7 @@ class TransformerModel(FairseqModel): ...@@ -93,6 +94,7 @@ class TransformerModel(FairseqModel):
'Must be used with adaptive_loss criterion'), 'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D', parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections') help='sets adaptive softmax dropout for the tail projections')
# fmt: on
@classmethod @classmethod
def build_model(cls, args, task): def build_model(cls, args, task):
...@@ -153,6 +155,7 @@ class TransformerLanguageModel(FairseqLanguageModel): ...@@ -153,6 +155,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add model-specific arguments to the parser.""" """Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--dropout', default=0.1, type=float, metavar='D', parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
help='dropout probability') help='dropout probability')
parser.add_argument('--attention-dropout', default=0., type=float, metavar='D', parser.add_argument('--attention-dropout', default=0., type=float, metavar='D',
...@@ -205,6 +208,7 @@ class TransformerLanguageModel(FairseqLanguageModel): ...@@ -205,6 +208,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
help='if set, ties the projection weights of adaptive softmax and adaptive input') help='if set, ties the projection weights of adaptive softmax and adaptive input')
parser.add_argument('--decoder-learned-pos', action='store_true', parser.add_argument('--decoder-learned-pos', action='store_true',
help='use learned positional embeddings in the decoder') help='use learned positional embeddings in the decoder')
# fmt: on
@classmethod @classmethod
def build_model(cls, args, task): def build_model(cls, args, task):
......
...@@ -21,10 +21,12 @@ class FairseqAdam(FairseqOptimizer): ...@@ -21,10 +21,12 @@ class FairseqAdam(FairseqOptimizer):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add optimizer-specific arguments to the parser.""" """Add optimizer-specific arguments to the parser."""
# fmt: off
parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B', parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B',
help='betas for Adam optimizer') help='betas for Adam optimizer')
parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D', parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D',
help='epsilon for Adam optimizer') help='epsilon for Adam optimizer')
# fmt: on
@property @property
def optimizer_config(self): def optimizer_config(self):
......
...@@ -63,6 +63,7 @@ class CosineSchedule(FairseqLRScheduler): ...@@ -63,6 +63,7 @@ class CosineSchedule(FairseqLRScheduler):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add arguments to the parser for this LR scheduler.""" """Add arguments to the parser for this LR scheduler."""
# fmt: off
parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
help='warmup the learning rate linearly for the first N updates') help='warmup the learning rate linearly for the first N updates')
parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
...@@ -73,6 +74,7 @@ class CosineSchedule(FairseqLRScheduler): ...@@ -73,6 +74,7 @@ class CosineSchedule(FairseqLRScheduler):
help='factor to grow the length of each period') help='factor to grow the length of each period')
parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR', parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR',
help='initial number of updates per period') help='initial number of updates per period')
# fmt: on
def step(self, epoch, val_loss=None): def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch.""" """Update the learning rate at the end of the given epoch."""
...@@ -102,4 +104,4 @@ class CosineSchedule(FairseqLRScheduler): ...@@ -102,4 +104,4 @@ class CosineSchedule(FairseqLRScheduler):
self.lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i)) self.lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i))
self.optimizer.set_lr(self.lr) self.optimizer.set_lr(self.lr)
return self.lr return self.lr
\ No newline at end of file
...@@ -27,10 +27,12 @@ class FixedSchedule(FairseqLRScheduler): ...@@ -27,10 +27,12 @@ class FixedSchedule(FairseqLRScheduler):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add arguments to the parser for this LR scheduler.""" """Add arguments to the parser for this LR scheduler."""
# fmt: off
parser.add_argument('--force-anneal', '--fa', type=int, metavar='N', parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
help='force annealing at specified epoch') help='force annealing at specified epoch')
parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
help='warmup the learning rate linearly for the first N updates') help='warmup the learning rate linearly for the first N updates')
# fmt: on
def get_next_lr(self, epoch): def get_next_lr(self, epoch):
lrs = self.args.lr lrs = self.args.lr
......
...@@ -55,10 +55,12 @@ class InverseSquareRootSchedule(FairseqLRScheduler): ...@@ -55,10 +55,12 @@ class InverseSquareRootSchedule(FairseqLRScheduler):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add arguments to the parser for this LR scheduler.""" """Add arguments to the parser for this LR scheduler."""
# fmt: off
parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N', parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N',
help='warmup the learning rate linearly for the first N updates') help='warmup the learning rate linearly for the first N updates')
parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
help='initial learning rate during warmup phase; default is args.lr') help='initial learning rate during warmup phase; default is args.lr')
# fmt: on
def step(self, epoch, val_loss=None): def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch.""" """Update the learning rate at the end of the given epoch."""
......
...@@ -42,12 +42,14 @@ class TriangularSchedule(FairseqLRScheduler): ...@@ -42,12 +42,14 @@ class TriangularSchedule(FairseqLRScheduler):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add arguments to the parser for this LR scheduler.""" """Add arguments to the parser for this LR scheduler."""
# fmt: off
parser.add_argument('--max-lr', required=True, type=float, metavar='LR', parser.add_argument('--max-lr', required=True, type=float, metavar='LR',
help='max learning rate, must be more than args.lr') help='max learning rate, must be more than args.lr')
parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR', parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR',
help='initial number of updates per period (cycle length)') help='initial number of updates per period (cycle length)')
parser.add_argument('--shrink-min', action='store_true', parser.add_argument('--shrink-min', action='store_true',
help='if set, also shrinks min lr') help='if set, also shrinks min lr')
# fmt: on
def step(self, epoch, val_loss=None): def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch.""" """Update the learning rate at the end of the given epoch."""
......
...@@ -120,6 +120,7 @@ def parse_args_and_arch(parser, input_args=None, parse_known=False): ...@@ -120,6 +120,7 @@ def parse_args_and_arch(parser, input_args=None, parse_known=False):
def get_parser(desc, default_task='translation'): def get_parser(desc, default_task='translation'):
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# fmt: off
parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar') parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
parser.add_argument('--log-interval', type=int, default=1000, metavar='N', parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
help='log progress every N batches (when progress bar is disabled)') help='log progress every N batches (when progress bar is disabled)')
...@@ -134,17 +135,16 @@ def get_parser(desc, default_task='translation'): ...@@ -134,17 +135,16 @@ def get_parser(desc, default_task='translation'):
help='number of updates before increasing loss scale') help='number of updates before increasing loss scale')
# Task definitions can be found under fairseq/tasks/ # Task definitions can be found under fairseq/tasks/
parser.add_argument( parser.add_argument('--task', metavar='TASK', default=default_task,
'--task', metavar='TASK', default=default_task, choices=TASK_REGISTRY.keys(),
choices=TASK_REGISTRY.keys(), help='task')
help='task', # fmt: on
)
return parser return parser
def add_dataset_args(parser, train=False, gen=False): def add_dataset_args(parser, train=False, gen=False):
group = parser.add_argument_group('Dataset and data loading') group = parser.add_argument_group('Dataset and data loading')
# fmt: off
group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true', group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
help='ignore too long or too short lines in valid and test set') help='ignore too long or too short lines in valid and test set')
group.add_argument('--max-tokens', type=int, metavar='N', group.add_argument('--max-tokens', type=int, metavar='N',
...@@ -168,11 +168,13 @@ def add_dataset_args(parser, train=False, gen=False): ...@@ -168,11 +168,13 @@ def add_dataset_args(parser, train=False, gen=False):
help='shard generation over N shards') help='shard generation over N shards')
group.add_argument('--shard-id', default=0, type=int, metavar='ID', group.add_argument('--shard-id', default=0, type=int, metavar='ID',
help='id of the shard to generate (id < num_shards)') help='id of the shard to generate (id < num_shards)')
# fmt: on
return group return group
def add_distributed_training_args(parser): def add_distributed_training_args(parser):
group = parser.add_argument_group('Distributed training') group = parser.add_argument_group('Distributed training')
# fmt: off
group.add_argument('--distributed-world-size', type=int, metavar='N', group.add_argument('--distributed-world-size', type=int, metavar='N',
default=torch.cuda.device_count(), default=torch.cuda.device_count(),
help='total number of GPUs across all nodes (default: all visible GPUs)') help='total number of GPUs across all nodes (default: all visible GPUs)')
...@@ -196,11 +198,13 @@ def add_distributed_training_args(parser): ...@@ -196,11 +198,13 @@ def add_distributed_training_args(parser):
help='Don\'t shuffle batches between GPUs, this reduces overall ' help='Don\'t shuffle batches between GPUs, this reduces overall '
'randomness and may affect precision but avoids the cost of' 'randomness and may affect precision but avoids the cost of'
're-reading the data') 're-reading the data')
# fmt: on
return group return group
def add_optimization_args(parser): def add_optimization_args(parser):
group = parser.add_argument_group('Optimization') group = parser.add_argument_group('Optimization')
# fmt: off
group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N', group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N',
help='force stop training at specified epoch') help='force stop training at specified epoch')
group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N', group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N',
...@@ -235,12 +239,13 @@ def add_optimization_args(parser): ...@@ -235,12 +239,13 @@ def add_optimization_args(parser):
help='minimum learning rate') help='minimum learning rate')
group.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D', group.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D',
help='minimum loss scale (for FP16 training)') help='minimum loss scale (for FP16 training)')
# fmt: on
return group return group
def add_checkpoint_args(parser): def add_checkpoint_args(parser):
group = parser.add_argument_group('Checkpointing') group = parser.add_argument_group('Checkpointing')
# fmt: off
group.add_argument('--save-dir', metavar='DIR', default='checkpoints', group.add_argument('--save-dir', metavar='DIR', default='checkpoints',
help='path to save checkpoints') help='path to save checkpoints')
group.add_argument('--restore-file', default='checkpoint_last.pt', group.add_argument('--restore-file', default='checkpoint_last.pt',
...@@ -263,10 +268,12 @@ def add_checkpoint_args(parser): ...@@ -263,10 +268,12 @@ def add_checkpoint_args(parser):
help='only store last and best checkpoints') help='only store last and best checkpoints')
group.add_argument('--validate-interval', type=int, default=1, metavar='N', group.add_argument('--validate-interval', type=int, default=1, metavar='N',
help='validate every N epochs') help='validate every N epochs')
# fmt: on
return group return group
def add_common_eval_args(group): def add_common_eval_args(group):
# fmt: off
group.add_argument('--path', metavar='FILE', group.add_argument('--path', metavar='FILE',
help='path(s) to model file(s), colon separated') help='path(s) to model file(s), colon separated')
group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None, group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
...@@ -276,20 +283,24 @@ def add_common_eval_args(group): ...@@ -276,20 +283,24 @@ def add_common_eval_args(group):
help='only print final scores') help='only print final scores')
group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT', group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
help='a dictionary used to override model args at generation that were used during model training') help='a dictionary used to override model args at generation that were used during model training')
# fmt: on
def add_eval_lm_args(parser): def add_eval_lm_args(parser):
group = parser.add_argument_group('LM Evaluation') group = parser.add_argument_group('LM Evaluation')
add_common_eval_args(group) add_common_eval_args(group)
# fmt: off
group.add_argument('--output-word-probs', action='store_true', group.add_argument('--output-word-probs', action='store_true',
help='if set, outputs words and their predicted log probabilities to standard output') help='if set, outputs words and their predicted log probabilities to standard output')
group.add_argument('--output-word-stats', action='store_true', group.add_argument('--output-word-stats', action='store_true',
help='if set, outputs word statistics such as word count, average probability, etc') help='if set, outputs word statistics such as word count, average probability, etc')
# fmt: on
def add_generation_args(parser): def add_generation_args(parser):
group = parser.add_argument_group('Generation') group = parser.add_argument_group('Generation')
add_common_eval_args(group) add_common_eval_args(group)
# fmt: off
group.add_argument('--beam', default=5, type=int, metavar='N', group.add_argument('--beam', default=5, type=int, metavar='N',
help='beam size') help='beam size')
group.add_argument('--nbest', default=1, type=int, metavar='N', group.add_argument('--nbest', default=1, type=int, metavar='N',
...@@ -332,17 +343,21 @@ def add_generation_args(parser): ...@@ -332,17 +343,21 @@ def add_generation_args(parser):
help='strength of diversity penalty for Diverse Beam Search') help='strength of diversity penalty for Diverse Beam Search')
group.add_argument('--print-alignment', action='store_true', group.add_argument('--print-alignment', action='store_true',
help='if set, uses attention feedback to compute and print alignment to source tokens') help='if set, uses attention feedback to compute and print alignment to source tokens')
# fmt: on
return group return group
def add_interactive_args(parser): def add_interactive_args(parser):
group = parser.add_argument_group('Interactive') group = parser.add_argument_group('Interactive')
# fmt: off
group.add_argument('--buffer-size', default=0, type=int, metavar='N', group.add_argument('--buffer-size', default=0, type=int, metavar='N',
help='read this many sentences into a buffer before processing them') help='read this many sentences into a buffer before processing them')
# fmt: on
def add_model_args(parser): def add_model_args(parser):
group = parser.add_argument_group('Model configuration') group = parser.add_argument_group('Model configuration')
# fmt: off
# Model definitions can be found under fairseq/models/ # Model definitions can be found under fairseq/models/
# #
...@@ -351,17 +366,13 @@ def add_model_args(parser): ...@@ -351,17 +366,13 @@ def add_model_args(parser):
# 1) model defaults (lowest priority) # 1) model defaults (lowest priority)
# 2) --arch argument # 2) --arch argument
# 3) --encoder/decoder-* arguments (highest priority) # 3) --encoder/decoder-* arguments (highest priority)
group.add_argument( group.add_argument('--arch', '-a', default='fconv', metavar='ARCH', required=True,
'--arch', '-a', default='fconv', metavar='ARCH', required=True, choices=ARCH_MODEL_REGISTRY.keys(),
choices=ARCH_MODEL_REGISTRY.keys(), help='Model Architecture')
help='Model Architecture',
)
# Criterion definitions can be found under fairseq/criterions/ # Criterion definitions can be found under fairseq/criterions/
group.add_argument( group.add_argument('--criterion', default='cross_entropy', metavar='CRIT',
'--criterion', default='cross_entropy', metavar='CRIT', choices=CRITERION_REGISTRY.keys(),
choices=CRITERION_REGISTRY.keys(), help='Training Criterion')
help='Training Criterion', # fmt: on
)
return group return group
...@@ -65,10 +65,10 @@ for file in os.listdir(os.path.dirname(__file__)): ...@@ -65,10 +65,10 @@ for file in os.listdir(os.path.dirname(__file__)):
if task_name in TASK_REGISTRY: if task_name in TASK_REGISTRY:
parser = argparse.ArgumentParser(add_help=False) parser = argparse.ArgumentParser(add_help=False)
group_task = parser.add_argument_group('Task name') group_task = parser.add_argument_group('Task name')
group_task.add_argument( # fmt: off
'--task', metavar=task_name, group_task.add_argument('--task', metavar=task_name,
help='Enable this task with: ``--task=' + task_name + '``' help='Enable this task with: ``--task=' + task_name + '``')
) # fmt: on
group_args = parser.add_argument_group('Additional command-line arguments') group_args = parser.add_argument_group('Additional command-line arguments')
TASK_REGISTRY[task_name].add_args(group_args) TASK_REGISTRY[task_name].add_args(group_args)
globals()[task_name + '_parser'] = parser globals()[task_name + '_parser'] = parser
...@@ -52,6 +52,7 @@ class LanguageModelingTask(FairseqTask): ...@@ -52,6 +52,7 @@ class LanguageModelingTask(FairseqTask):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add task-specific arguments to the parser.""" """Add task-specific arguments to the parser."""
# fmt: off
parser.add_argument('data', help='path to data directory') parser.add_argument('data', help='path to data directory')
parser.add_argument('--sample-break-mode', parser.add_argument('--sample-break-mode',
choices=['none', 'complete', 'eos'], choices=['none', 'complete', 'eos'],
...@@ -71,6 +72,7 @@ class LanguageModelingTask(FairseqTask): ...@@ -71,6 +72,7 @@ class LanguageModelingTask(FairseqTask):
help='include future target') help='include future target')
parser.add_argument('--past-target', action='store_true', parser.add_argument('--past-target', action='store_true',
help='include past target') help='include past target')
# fmt: on
def __init__(self, args, dictionary, output_dictionary, targets=None): def __init__(self, args, dictionary, output_dictionary, targets=None):
super().__init__(args) super().__init__(args)
......
...@@ -47,6 +47,7 @@ class MultilingualTranslationTask(FairseqTask): ...@@ -47,6 +47,7 @@ class MultilingualTranslationTask(FairseqTask):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add task-specific arguments to the parser.""" """Add task-specific arguments to the parser."""
# fmt: off
parser.add_argument('data', metavar='DIR', help='path to data directory') parser.add_argument('data', metavar='DIR', help='path to data directory')
parser.add_argument('--lang-pairs', default=None, metavar='PAIRS', parser.add_argument('--lang-pairs', default=None, metavar='PAIRS',
help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr') help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr')
...@@ -64,6 +65,7 @@ class MultilingualTranslationTask(FairseqTask): ...@@ -64,6 +65,7 @@ class MultilingualTranslationTask(FairseqTask):
help='max number of tokens in the source sequence') help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N', parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence') help='max number of tokens in the target sequence')
# fmt: on
def __init__(self, args, dicts, training): def __init__(self, args, dicts, training):
super().__init__(args) super().__init__(args)
......
...@@ -43,6 +43,7 @@ class TranslationTask(FairseqTask): ...@@ -43,6 +43,7 @@ class TranslationTask(FairseqTask):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add task-specific arguments to the parser.""" """Add task-specific arguments to the parser."""
# fmt: off
parser.add_argument('data', nargs='+', help='path(s) to data directorie(s)') parser.add_argument('data', nargs='+', help='path(s) to data directorie(s)')
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
help='source language') help='source language')
...@@ -60,6 +61,7 @@ class TranslationTask(FairseqTask): ...@@ -60,6 +61,7 @@ class TranslationTask(FairseqTask):
help='max number of tokens in the target sequence') help='max number of tokens in the target sequence')
parser.add_argument('--upsample-primary', default=1, type=int, parser.add_argument('--upsample-primary', default=1, type=int,
help='amount to upsample primary dataset') help='amount to upsample primary dataset')
# fmt: on
def __init__(self, args, src_dict, tgt_dict): def __init__(self, args, src_dict, tgt_dict):
super().__init__(args) super().__init__(args)
......
...@@ -23,89 +23,45 @@ from multiprocessing import Pool, Manager, Process ...@@ -23,89 +23,45 @@ from multiprocessing import Pool, Manager, Process
def get_parser(): def get_parser():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( # fmt: off
"-s", "--source-lang", default=None, metavar="SRC", help="source language" parser.add_argument("-s", "--source-lang", default=None, metavar="SRC",
) help="source language")
parser.add_argument( parser.add_argument("-t", "--target-lang", default=None, metavar="TARGET",
"-t", "--target-lang", default=None, metavar="TARGET", help="target language" help="target language")
) parser.add_argument("--trainpref", metavar="FP", default=None,
parser.add_argument( help="train file prefix")
"--trainpref", metavar="FP", default=None, help="train file prefix" parser.add_argument("--validpref", metavar="FP", default=None,
) help="comma separated, valid file prefixes")
parser.add_argument( parser.add_argument("--testpref", metavar="FP", default=None,
"--validpref", help="comma separated, test file prefixes")
metavar="FP", parser.add_argument("--destdir", metavar="DIR", default="data-bin",
default=None, help="destination dir")
help="comma separated, valid file prefixes", parser.add_argument("--thresholdtgt", metavar="N", default=0, type=int,
) help="map words appearing less than threshold times to unknown")
parser.add_argument( parser.add_argument("--thresholdsrc", metavar="N", default=0, type=int,
"--testpref", help="map words appearing less than threshold times to unknown")
metavar="FP", parser.add_argument("--tgtdict", metavar="FP",
default=None, help="reuse given target dictionary")
help="comma separated, test file prefixes", parser.add_argument("--srcdict", metavar="FP",
) help="reuse given source dictionary")
parser.add_argument( parser.add_argument("--nwordstgt", metavar="N", default=-1, type=int,
"--destdir", metavar="DIR", default="data-bin", help="destination dir" help="number of target words to retain")
) parser.add_argument("--nwordssrc", metavar="N", default=-1, type=int,
parser.add_argument( help="number of source words to retain")
"--thresholdtgt", parser.add_argument("--alignfile", metavar="ALIGN", default=None,
metavar="N", help="an alignment file (optional)")
default=0, parser.add_argument("--output-format", metavar="FORMAT", default="binary",
type=int, choices=["binary", "raw"],
help="map words appearing less than threshold times to unknown", help="output format (optional)")
) parser.add_argument("--joined-dictionary", action="store_true",
parser.add_argument( help="Generate joined dictionary")
"--thresholdsrc", parser.add_argument("--only-source", action="store_true",
metavar="N", help="Only process the source language")
default=0, parser.add_argument("--padding-factor", metavar="N", default=8, type=int,
type=int, help="Pad dictionary size to be multiple of N")
help="map words appearing less than threshold times to unknown", parser.add_argument("--workers", metavar="N", default=1, type=int,
) help="number of parallel workers")
parser.add_argument("--tgtdict", metavar="FP", help="reuse given target dictionary") # fmt: on
parser.add_argument("--srcdict", metavar="FP", help="reuse given source dictionary")
parser.add_argument(
"--nwordstgt",
metavar="N",
default=-1,
type=int,
help="number of target words to retain",
)
parser.add_argument(
"--nwordssrc",
metavar="N",
default=-1,
type=int,
help="number of source words to retain",
)
parser.add_argument(
"--alignfile",
metavar="ALIGN",
default=None,
help="an alignment file (optional)",
)
parser.add_argument(
"--output-format",
metavar="FORMAT",
default="binary",
choices=["binary", "raw"],
help="output format (optional)",
)
parser.add_argument(
"--joined-dictionary", action="store_true", help="Generate joined dictionary"
)
parser.add_argument(
"--only-source", action="store_true", help="Only process the source language"
)
parser.add_argument(
"--padding-factor",
metavar="N",
default=8,
type=int,
help="Pad dictionary size to be multiple of N",
)
parser.add_argument(
"--workers", metavar="N", default=1, type=int, help="number of parallel workers"
)
return parser return parser
......
...@@ -19,12 +19,14 @@ from fairseq.data import dictionary ...@@ -19,12 +19,14 @@ from fairseq.data import dictionary
def get_parser(): def get_parser():
parser = argparse.ArgumentParser(description='Command-line script for BLEU scoring.') parser = argparse.ArgumentParser(description='Command-line script for BLEU scoring.')
# fmt: off
parser.add_argument('-s', '--sys', default='-', help='system output') parser.add_argument('-s', '--sys', default='-', help='system output')
parser.add_argument('-r', '--ref', required=True, help='references') parser.add_argument('-r', '--ref', required=True, help='references')
parser.add_argument('-o', '--order', default=4, metavar='N', parser.add_argument('-o', '--order', default=4, metavar='N',
type=int, help='consider ngrams up to this order') type=int, help='consider ngrams up to this order')
parser.add_argument('--ignore-case', action='store_true', parser.add_argument('--ignore-case', action='store_true',
help='case-insensitive scoring') help='case-insensitive scoring')
# fmt: on
return parser return parser
...@@ -44,7 +46,7 @@ def main(): ...@@ -44,7 +46,7 @@ def main():
for line in fd.readlines(): for line in fd.readlines():
if args.ignore_case: if args.ignore_case:
yield line.lower() yield line.lower()
else: else:
yield line yield line
def score(fdsys): def score(fdsys):
......
...@@ -86,33 +86,19 @@ def main(): ...@@ -86,33 +86,19 @@ def main():
description='Tool to average the params of input checkpoints to ' description='Tool to average the params of input checkpoints to '
'produce a new checkpoint', 'produce a new checkpoint',
) )
# fmt: off
parser.add_argument( parser.add_argument('--inputs', required=True, nargs='+',
'--inputs', help='Input checkpoint file paths.')
required=True, parser.add_argument('--output', required=True, metavar='FILE',
nargs='+', help='Write the new checkpoint containing the averaged weights to this path.')
help='Input checkpoint file paths.',
)
parser.add_argument(
'--output',
required=True,
metavar='FILE',
help='Write the new checkpoint containing the averaged weights to this '
'path.',
)
num_group = parser.add_mutually_exclusive_group() num_group = parser.add_mutually_exclusive_group()
num_group.add_argument( num_group.add_argument('--num-epoch-checkpoints', type=int,
'--num-epoch-checkpoints', help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
type=int, 'and average last this many of them.')
help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, ' num_group.add_argument('--num-update-checkpoints', type=int,
'and average last this many of them.', help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
) 'and average last this many of them.')
num_group.add_argument( # fmt: on
'--num-update-checkpoints',
type=int,
help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
'and average last this many of them.',
)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
......
...@@ -32,6 +32,7 @@ from itertools import zip_longest ...@@ -32,6 +32,7 @@ from itertools import zip_longest
def main(): def main():
parser = argparse.ArgumentParser(description='symmetric alignment builer') parser = argparse.ArgumentParser(description='symmetric alignment builer')
# fmt: off
parser.add_argument('--fast_align_dir', parser.add_argument('--fast_align_dir',
help='path to fast_align build directory') help='path to fast_align build directory')
parser.add_argument('--mosesdecoder_dir', parser.add_argument('--mosesdecoder_dir',
...@@ -47,6 +48,7 @@ def main(): ...@@ -47,6 +48,7 @@ def main():
'in the target language') 'in the target language')
parser.add_argument('--output_dir', parser.add_argument('--output_dir',
help='output directory') help='output directory')
# fmt: on
args = parser.parse_args() args = parser.parse_args()
fast_align_bin = os.path.join(args.fast_align_dir, 'fast_align') fast_align_bin = os.path.join(args.fast_align_dir, 'fast_align')
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment