# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import torch

from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
from fairseq.criterions import CRITERION_REGISTRY
from fairseq.optim import OPTIMIZER_REGISTRY
from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY


def get_training_parser():
    parser = get_parser('Trainer')
    add_dataset_args(parser, train=True, gen=True)
    add_distributed_training_args(parser)
    add_model_args(parser)
    add_optimization_args(parser)
    add_checkpoint_args(parser)
    add_inference_args(parser)
    add_perf_args(parser)
    return parser


def get_inference_parser():
    parser = get_parser('Generation')
    add_dataset_args(parser, gen=True)
    add_inference_args(parser)
    add_perf_args(parser)
    return parser


def parse_args_and_arch(parser, input_args=None, parse_known=False):
    # The parser doesn't know about model/criterion/optimizer-specific args, so
    # we parse twice. First we parse the model/criterion/optimizer, then we
    # parse a second time after adding the *-specific arguments.
    # If input_args is given, we will parse those args instead of sys.argv.
    args, _ = parser.parse_known_args(input_args)

    # Add model-specific args to parser.
    if hasattr(args, 'arch'):
        model_specific_group = parser.add_argument_group(
            'Model-specific configuration',
            # Only include attributes which are explicitly given as command-line
            # arguments or which have default values.
            argument_default=argparse.SUPPRESS,
        )
        ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)

    # Add *-specific args to parser.
    if hasattr(args, 'optimizer'):
        OPTIMIZER_REGISTRY[args.optimizer].add_args(parser)
    if hasattr(args, 'lr_scheduler'):
        LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)

    # Parse a second time.
    if parse_known:
        args, extra = parser.parse_known_args(input_args)
    else:
        args = parser.parse_args(input_args)
        extra = None

    # Post-process args.
    if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
        args.max_sentences_valid = args.max_sentences
    args.max_positions = (args.max_source_positions, args.max_target_positions)
    if hasattr(args, 'target_bleu') and (args.online_eval or args.target_bleu) and not args.remove_bpe:
        args.remove_bpe = '@@ '

    # Apply architecture configuration.
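    # ARCH_CONFIG_REGISTRY maps each architecture name to a function that fills
    # in default values for model hyperparameters the user did not set on the
    # command line (argparse.SUPPRESS above keeps unset model-specific args off
    # the namespace so those defaults can take effect).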
    if hasattr(args, 'arch'):
        ARCH_CONFIG_REGISTRY[args.arch](args)

    if parse_known:
        return args, extra
    else:
        return args


def get_parser(desc):
    parser = argparse.ArgumentParser(
        description='Facebook AI Research Sequence-to-Sequence Toolkit -- ' + desc)
    parser.add_argument('--log-interval', type=int, default=500, metavar='N',
                        help='print aggregated stats and flush json log every N iterations')
    parser.add_argument('--seed', default=1, type=int, metavar='N',
                        help='pseudo random number generator seed')
    parser.add_argument('--amp', action='store_true',
                        help='use Automatic Mixed Precision')
    parser.add_argument('--stat-file', type=str, default='run_log.json',
                        help='name of the file containing DLLogger output')
    parser.add_argument('--save-dir', metavar='DIR', default='results',
                        help='path to save checkpoints and logs')
    parser.add_argument('--do-sanity-check', action='store_true',
                        help='run evaluation on the test set before starting training')
    return parser


def add_dataset_args(parser, train=False, gen=False):
    group = parser.add_argument_group('Dataset and data loading')
    group.add_argument('--max-tokens', type=int, metavar='N',
                       help='maximum number of tokens in a batch')
    group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
                       help='maximum number of sentences in a batch')
    parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                        help='source language')
    parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                        help='target language')
    parser.add_argument('--raw-text', action='store_true',
                        help='load raw text dataset')
    parser.add_argument('--left-pad-source', default=True, type=bool, metavar='BOOL',
                        help='pad the source on the left (default: True)')
    parser.add_argument('--left-pad-target', default=False, type=bool, metavar='BOOL',
                        help='pad the target on the left (default: False)')
    parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
                        help='max number of tokens in the source sequence')
    parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                        help='max number of tokens in the target sequence')
    parser.add_argument('--pad-sequence', default=1, type=int, metavar='N',
                        help='pad sequences to a multiple of N')
    if train:
        parser.add_argument('data', metavar='DIR', help='path to data directory')
        group.add_argument('--train-subset', default='train', metavar='SPLIT',
                           choices=['train', 'valid', 'test'],
                           help='data subset to use for training (train, valid, test)')
        group.add_argument('--valid-subset', default='valid', metavar='SPLIT',
                           help='comma-separated list of data subsets to use for validation'
                                ' (train, valid, valid1, test, test1)')
        group.add_argument('--max-sentences-valid', type=int, metavar='N',
                           help='maximum number of sentences in a validation batch'
                                ' (defaults to --max-sentences)')
    if gen:
        group.add_argument('--gen-subset', default='test', metavar='SPLIT',
                           help='data subset to generate (train, valid, test)')
        group.add_argument('--num-shards', default=1, type=int, metavar='N',
                           help='shard generation over N shards')
        group.add_argument('--shard-id', default=0, type=int, metavar='ID',
                           help='id of the shard to generate (id < num_shards)')
    return group


def add_distributed_training_args(parser):
    group = parser.add_argument_group('Distributed training')
    group.add_argument('--distributed-world-size', type=int, metavar='N',
                       default=torch.cuda.device_count(),
                       help='total number of GPUs across all nodes (default: all visible GPUs)')
    group.add_argument('--distributed-rank', default=os.getenv('LOCAL_RANK', 0),
                       type=int, help='rank of the current worker')
    group.add_argument('--local_rank', default=0, type=int,
                       help='rank of the current worker')
    group.add_argument('--distributed-backend', default='nccl', type=str,
                       help='distributed backend')
    group.add_argument('--distributed-init-method', default=None, type=str,
                       help='typically tcp://hostname:port that will be used to '
                            'establish initial connection')
    group.add_argument('--distributed-port', default=-1, type=int,
                       help='port number (not required if using --distributed-init-method)')
    group.add_argument('--device-id', default=0, type=int,
                       help='which GPU to use (usually configured automatically)')
    return group


def add_optimization_args(parser):
    group = parser.add_argument_group('Optimization')
    group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N',
                       help='force stop training at specified epoch')
    group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N',
                       help='force stop training at specified update')
    group.add_argument('--target-bleu', default=0.0, type=float, metavar='TARGET',
                       help='force stop training after reaching target BLEU')
    group.add_argument('--clip-norm', default=25, type=float, metavar='NORM',
                       help='clip threshold of gradients')
    group.add_argument('--update-freq', default=[1], nargs='+', type=int,
                       help='update parameters every N_i batches, when in epoch i')

    # Optimizer definitions can be found under fairseq/optim/
    group.add_argument('--optimizer', default='nag', metavar='OPT',
                       choices=OPTIMIZER_REGISTRY.keys(),
                       help='optimizer: {} (default: nag)'.format(', '.join(OPTIMIZER_REGISTRY.keys())))
    group.add_argument('--lr', '--learning-rate', default=[0.25], nargs='+', type=float,
                       help='learning rate for the first N epochs; all epochs >N using LR_N'
                            ' (note: this may be interpreted differently depending on --lr-scheduler)')
    group.add_argument('--momentum', default=0.99, type=float, metavar='M',
                       help='momentum factor')
    group.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                       help='weight decay')

    # Learning rate schedulers can be found under fairseq/optim/lr_scheduler/
    group.add_argument('--lr-scheduler', default='reduce_lr_on_plateau',
                       help='learning rate scheduler: {} (default: reduce_lr_on_plateau)'.format(
                           ', '.join(LR_SCHEDULER_REGISTRY.keys())))
    group.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
                       help='learning rate shrink factor for annealing, lr_new = (lr * lr_shrink)')
    group.add_argument('--min-lr', default=1e-5, type=float, metavar='LR',
                       help='minimum learning rate')

    # Criterion args
    parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
                        help='epsilon for label smoothing, 0 means no label smoothing')
    return group


def add_checkpoint_args(parser):
    group = parser.add_argument_group('Checkpointing')
    group.add_argument('--restore-file', default='checkpoint_last.pt',
                       help='filename in save-dir from which to load checkpoint')
    group.add_argument('--save-interval', type=int, default=1, metavar='N',
                       help='save a checkpoint every N epochs')
    group.add_argument('--no-save', action='store_true',
                       help='don\'t save models or checkpoints')
    group.add_argument('--no-epoch-checkpoints', action='store_true',
                       help='only store last and best checkpoints')
    group.add_argument('--validate-interval', type=int, default=1, metavar='N',
                       help='validate every N epochs')
    return group


def add_common_eval_args(group):
    group.add_argument('--path', metavar='FILE',
                       help='path(s) to model file(s), colon separated')
    group.add_argument('--file', metavar='FILE',
                       default=None, type=str,
                       help='path to a file with input data for inference')
    group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
                       help='remove BPE tokens before scoring')
    group.add_argument('--cpu', action='store_true',
                       help='generate on CPU')
    group.add_argument('--quiet', action='store_true',
                       help='only print final scores')


def add_inference_args(parser):
    group = parser.add_argument_group('Generation')
    add_common_eval_args(group)
    group.add_argument('--beam', default=4, type=int, metavar='N',
                       help='beam size')
    group.add_argument('--nbest', default=1, type=int, metavar='N',
                       help='number of hypotheses to output')
    group.add_argument('--max-len-a', default=0, type=float, metavar='N',
                       help=('generate sequences of maximum length ax + b, '
                             'where x is the source length'))
    group.add_argument('--max-len-b', default=200, type=int, metavar='N',
                       help=('generate sequences of maximum length ax + b, '
                             'where x is the source length'))
    group.add_argument('--min-len', default=1, type=float, metavar='N',
                       help='minimum generation length')
    group.add_argument('--no-early-stop', action='store_true',
                       help=('continue searching even after finalizing k=beam '
                             'hypotheses; this is more correct, but increases '
                             'generation time by 50%%'))
    group.add_argument('--unnormalized', action='store_true',
                       help='compare unnormalized hypothesis scores')
    group.add_argument('--no-beamable-mm', action='store_true',
                       help='don\'t use BeamableMM in attention layers')
    group.add_argument('--lenpen', default=1, type=float,
                       help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
    group.add_argument('--unkpen', default=0, type=float,
                       help='unknown word penalty: <0 produces more unks, >0 produces fewer')
    group.add_argument('--replace-unk', nargs='?', const=True, default=None,
                       help='perform unknown replacement (optionally with alignment dictionary)')
    group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
                       help='initialize generation by target prefix of given length')
    group.add_argument('--sampling', action='store_true',
                       help='sample hypotheses instead of using beam search')
    group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS',
                       help='sample from top K likely next words instead of all words')
    group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
                       help='temperature for random sampling')
    group.add_argument('--print-alignment', action='store_true',
                       help='if set, uses attention feedback to compute and print alignment to source tokens')
    group.add_argument('--online-eval', action='store_true',
                       help='score the model at the end of each epoch')
    group.add_argument('--save-predictions', action='store_true',
                       help='save predictions produced with online evaluation')
    group.add_argument('--test-cased-bleu', action='store_true',
                       help='use cased BLEU for online evaluation')
    group.add_argument('--bpe-codes', default=None, type=str, metavar='CODES',
                       help='file with BPE codes')
    group.add_argument('--buffer-size', default=64, type=int, metavar='N',
                       help='read this many sentences into a buffer before processing them')
    group.add_argument('--fp16', action='store_true',
                       help='use fp16 precision')
    return group


def add_model_args(parser):
    group = parser.add_argument_group('Model configuration')

    # Model definitions can be found under fairseq/models/
    #
    # The model architecture can be specified in several ways.
    # In increasing order of priority:
    # 1) model defaults (lowest priority)
    # 2) --arch argument
    group.add_argument(
        '--arch', '-a', default='fconv', metavar='ARCH', required=True,
        choices=ARCH_MODEL_REGISTRY.keys(),
        help='model architecture: {} (default: fconv)'.format(
            ', '.join(ARCH_MODEL_REGISTRY.keys())),
    )

    # Criterion definitions can be found under fairseq/criterions/
    group.add_argument(
        '--criterion', default='cross_entropy', metavar='CRIT',
        choices=CRITERION_REGISTRY.keys(),
        help='training criterion: {} (default: cross_entropy)'.format(
            ', '.join(CRITERION_REGISTRY.keys())),
    )
    return group


def add_perf_args(parser):
    group = parser.add_argument_group('Performance')
    group.add_argument('--fuse-dropout-add', action='store_true',
                       help='fuse dropout and residual adds')
    group.add_argument('--fuse-relu-dropout', action='store_true',
                       help='fuse ReLU and dropout')
    group.add_argument('--fuse-layer-norm', action='store_true',
                       help='use APEX\'s FusedLayerNorm instead of torch.nn.LayerNorm')
    return group
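

#-------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the original module):
# build the generation parser and let parse_args_and_arch() resolve any
# registry-specific arguments in its second parsing pass. The checkpoint path
# below is a hypothetical placeholder.
if __name__ == '__main__':
    example_parser = get_inference_parser()
    example_args = parse_args_and_arch(
        example_parser,
        input_args=['--path', 'results/checkpoint_best.pt', '--beam', '5'],
    )
    print(example_args.beam, example_args.max_positions)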