Commit 9e8a8c05: Initial commit
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.optim
class FairseqOptimizer(object):
def __init__(self, args, params):
super().__init__()
self.args = args
self.params = params
@staticmethod
def add_args(parser):
"""Add optimizer-specific arguments to the parser."""
pass
@property
def optimizer(self):
"""Return a torch.optim.optimizer.Optimizer instance."""
if not hasattr(self, '_optimizer'):
raise NotImplementedError
if not isinstance(self._optimizer, torch.optim.Optimizer):
raise ValueError('_optimizer must be an instance of torch.optim.Optimizer')
return self._optimizer
@property
def optimizer_config(self):
"""
Return a kwarg dictionary that will be used to override optimizer
args stored in checkpoints. This allows us to load a checkpoint and
resume training using a different set of optimizer args, e.g., with a
different learning rate.
"""
raise NotImplementedError
def get_lr(self):
"""Return the current learning rate."""
return self.optimizer.param_groups[0]['lr']
def set_lr(self, lr):
"""Set the learning rate."""
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
def state_dict(self):
"""Return the optimizer's state dict."""
return self.optimizer.state_dict()
def load_state_dict(self, state_dict):
"""Load an optimizer state dict.
In general we should prefer the configuration of the existing optimizer
instance (e.g., learning rate) over that found in the state_dict. This
allows us to resume training from a checkpoint using a new set of
optimizer args.
"""
self.optimizer.load_state_dict(state_dict)
# override learning rate, momentum, etc. with latest values
for group in self.optimizer.param_groups:
group.update(self.optimizer_config)
def step(self, closure=None):
"""Performs a single optimization step."""
return self.optimizer.step(closure)
def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
for group in self.optimizer.param_groups:
for p in group['params']:
p.grad = None
return self.optimizer.zero_grad()
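# Illustrative sketch (assumes a concrete subclass such as the SGD wrapper that
# appears later in this commit): load_state_dict() re-applies optimizer_config
# after loading, so command-line hyperparameters win over checkpointed ones.
def _example_resume(optimizer, checkpoint_state):
    """`optimizer` is a FairseqOptimizer subclass; `checkpoint_state` comes from state_dict()."""
    optimizer.load_state_dict(checkpoint_state)
    # param_groups now carry the lr from optimizer_config, not the checkpoint,
    # so passing a new --lr when resuming takes effect immediately
    assert optimizer.get_lr() == optimizer.optimizer_config['lr']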
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import importlib
import os
from .fairseq_lr_scheduler import FairseqLRScheduler
LR_SCHEDULER_REGISTRY = {}
def build_lr_scheduler(args, optimizer):
return LR_SCHEDULER_REGISTRY[args.lr_scheduler](args, optimizer)
def register_lr_scheduler(name):
"""Decorator to register a new LR scheduler."""
def register_lr_scheduler_cls(cls):
if name in LR_SCHEDULER_REGISTRY:
raise ValueError('Cannot register duplicate LR scheduler ({})'.format(name))
if not issubclass(cls, FairseqLRScheduler):
raise ValueError('LR Scheduler ({}: {}) must extend FairseqLRScheduler'.format(name, cls.__name__))
LR_SCHEDULER_REGISTRY[name] = cls
return cls
return register_lr_scheduler_cls
# automatically import any Python files in the optim/lr_scheduler/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('fairseq.optim.lr_scheduler.' + module)
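# Illustrative sketch (hypothetical scheduler, not part of this commit): registering
# a FairseqLRScheduler subclass is all that is needed to expose it via --lr-scheduler.
@register_lr_scheduler('constant_example')
class _ConstantScheduleExample(FairseqLRScheduler):
    """Keep the learning rate fixed at args.lr[0]."""
    def __init__(self, args, optimizer):
        super().__init__(args, optimizer)
        self.optimizer.set_lr(args.lr[0])
    def step(self, epoch, val_loss=None):
        super().step(epoch, val_loss)  # tracks the best validation loss
        return self.optimizer.get_lr()
    def step_update(self, num_updates):
        return self.optimizer.get_lr()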
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from .. import FairseqOptimizer
class FairseqLRScheduler(object):
def __init__(self, args, optimizer):
super().__init__()
if not isinstance(optimizer, FairseqOptimizer):
raise ValueError('optimizer must be an instance of FairseqOptimizer')
self.args = args
self.optimizer = optimizer
self.best = None
@staticmethod
def add_args(parser):
"""Add arguments to the parser for this LR scheduler."""
pass
def state_dict(self):
"""Return the LR scheduler state dict."""
return {'best': self.best}
def load_state_dict(self, state_dict):
"""Load an LR scheduler state dict."""
self.best = state_dict['best']
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
if val_loss is not None:
if self.best is None:
self.best = val_loss
else:
self.best = min(self.best, val_loss)
def step_update(self, num_updates):
"""Update the learning rate after each update."""
return self.optimizer.get_lr()
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from . import FairseqLRScheduler, register_lr_scheduler
@register_lr_scheduler('fixed')
class FixedSchedule(FairseqLRScheduler):
"""Decay the LR on a fixed schedule."""
def __init__(self, args, optimizer):
super().__init__(args, optimizer)
# set defaults
args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0
self.lr = args.lr[0]
if args.warmup_updates > 0:
self.warmup_factor = 1. / args.warmup_updates
else:
self.warmup_factor = 1
@staticmethod
def add_args(parser):
"""Add arguments to the parser for this LR scheduler."""
parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
help='force annealing at specified epoch')
parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
help='warmup the learning rate linearly for the first N updates')
def get_next_lr(self, epoch):
lrs = self.args.lr
if self.args.force_anneal is None or epoch < self.args.force_anneal:
# use fixed LR schedule
next_lr = lrs[min(epoch, len(lrs) - 1)]
else:
# anneal based on lr_shrink
next_lr = lrs[-1] * self.args.lr_shrink ** (epoch + 1 - self.args.force_anneal)
return next_lr
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
super().step(epoch, val_loss)
self.lr = self.get_next_lr(epoch)
self.optimizer.set_lr(self.warmup_factor * self.lr)
return self.optimizer.get_lr()
def step_update(self, num_updates):
"""Update the learning rate after each update."""
if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates:
self.warmup_factor = num_updates / float(self.args.warmup_updates)
self.optimizer.set_lr(self.warmup_factor * self.lr)
return self.optimizer.get_lr()
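# Worked example (hypothetical values): with --lr 0.25,0.05 --force-anneal 5
# --lr-shrink 0.1, get_next_lr() above reduces to the helper below.
def _example_fixed_lr(epoch, lrs=(0.25, 0.05), force_anneal=5, lr_shrink=0.1):
    if force_anneal is None or epoch < force_anneal:
        return lrs[min(epoch, len(lrs) - 1)]
    return lrs[-1] * lr_shrink ** (epoch + 1 - force_anneal)
# epochs 0..5 give 0.25, 0.05, 0.05, 0.05, 0.05, 0.005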
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from . import FairseqLRScheduler, register_lr_scheduler
@register_lr_scheduler('inverse_sqrt')
class InverseSquareRootSchedule(FairseqLRScheduler):
"""Decay the LR based on the inverse square root of the update number.
We also support a warmup phase where we linearly increase the learning rate
from some initial learning rate (`--warmup-init-lr`) until the configured
learning rate (`--lr`). Thereafter we decay proportional to the number of
updates, with a decay factor set to align with the configured learning rate.
During warmup:
lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
lr = lrs[update_num]
After warmup:
lr = decay_factor / sqrt(update_num)
where
decay_factor = args.lr * sqrt(args.warmup_updates)
"""
def __init__(self, args, optimizer):
super().__init__(args, optimizer)
if len(args.lr) > 1:
raise ValueError(
'Cannot use a fixed learning rate schedule with inverse_sqrt.'
' Consider --lr-scheduler=fixed instead.'
)
warmup_end_lr = args.lr[0]
if args.warmup_init_lr < 0:
args.warmup_init_lr = warmup_end_lr
# warm up the learning rate linearly for the first args.warmup_updates updates
self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
# then, decay prop. to the inverse square root of the update number
self.decay_factor = warmup_end_lr * args.warmup_updates**0.5
# initial learning rate
self.lr = args.warmup_init_lr
self.optimizer.set_lr(self.lr)
@staticmethod
def add_args(parser):
"""Add arguments to the parser for this LR scheduler."""
parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N',
help='warmup the learning rate linearly for the first N updates')
parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
help='initial learning rate during warmup phase; default is args.lr')
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
super().step(epoch, val_loss)
# we don't change the learning rate at epoch boundaries
return self.optimizer.get_lr()
def step_update(self, num_updates):
"""Update the learning rate after each update."""
if num_updates < self.args.warmup_updates:
self.lr = self.args.warmup_init_lr + num_updates*self.lr_step
else:
self.lr = self.decay_factor * num_updates**-0.5
self.optimizer.set_lr(self.lr)
return self.lr
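# Worked example (hypothetical values): --lr 5e-4 --warmup-init-lr 1e-7
# --warmup-updates 4000 rises linearly to 5e-4 by update 4000, then decays as
# decay_factor / sqrt(num_updates).
def _example_inverse_sqrt_lr(num_updates, lr=5e-4, warmup_init_lr=1e-7, warmup_updates=4000):
    if num_updates < warmup_updates:
        return warmup_init_lr + num_updates * (lr - warmup_init_lr) / warmup_updates
    return lr * warmup_updates ** 0.5 * num_updates ** -0.5
# e.g. update 4000 gives ~5e-4 and update 16000 gives ~2.5e-4 (halved, since updates quadrupled)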
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.optim.lr_scheduler
from . import FairseqLRScheduler, register_lr_scheduler
@register_lr_scheduler('reduce_lr_on_plateau')
class ReduceLROnPlateau(FairseqLRScheduler):
"""Decay the LR by a factor every time the validation loss plateaus."""
def __init__(self, args, optimizer):
super().__init__(args, optimizer)
if len(args.lr) > 1:
raise ValueError(
'Cannot use a fixed learning rate schedule with reduce_lr_on_plateau.'
' Consider --lr-scheduler=fixed instead.'
)
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer.optimizer, patience=0, factor=args.lr_shrink)
def state_dict(self):
"""Return the LR scheduler state dict."""
return {
'best': self.lr_scheduler.best,
'last_epoch': self.lr_scheduler.last_epoch,
}
def load_state_dict(self, state_dict):
"""Load an LR scheduler state dict."""
self.lr_scheduler.best = state_dict['best']
if 'last_epoch' in state_dict:
self.lr_scheduler.last_epoch = state_dict['last_epoch']
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
if val_loss is not None:
self.lr_scheduler.step(val_loss, epoch)
else:
self.lr_scheduler.last_epoch = epoch
return self.optimizer.get_lr()
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from torch.optim.optimizer import Optimizer, required
from . import FairseqOptimizer, register_optimizer
@register_optimizer('nag')
class FairseqNAG(FairseqOptimizer):
def __init__(self, args, params):
super().__init__(args, params)
self._optimizer = NAG(params, **self.optimizer_config)
@property
def optimizer_config(self):
"""
Return a kwarg dictionary that will be used to override optimizer
args stored in checkpoints. This allows us to load a checkpoint and
resume training using a different set of optimizer args, e.g., with a
different learning rate.
"""
return {
'lr': self.args.lr[0],
'momentum': self.args.momentum,
'weight_decay': self.args.weight_decay,
}
class NAG(Optimizer):
def __init__(self, params, lr=required, momentum=0, weight_decay=0):
defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay)
super(NAG, self).__init__(params, defaults)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
weight_decay = group['weight_decay']
momentum = group['momentum']
lr = group['lr']
lr_old = group.get('lr_old', lr)
lr_correct = lr / lr_old
for p in group['params']:
if p.grad is None:
continue
d_p = p.grad.data
param_state = self.state[p]
if 'momentum_buffer' not in param_state:
param_state['momentum_buffer'] = d_p.clone().zero_()
buf = param_state['momentum_buffer']
if weight_decay != 0:
# apply weight decay directly to the parameters (decoupled from the gradient step)
p.data.mul_(1 - lr * weight_decay)
# Nesterov-style update: move along the (LR-corrected) momentum buffer, then
# take the gradient step, and finally refresh the momentum buffer
p.data.add_(buf, alpha=momentum * momentum * lr_correct)
p.data.add_(d_p, alpha=-(1 + momentum) * lr)
buf.mul_(momentum * lr_correct).add_(d_p, alpha=-lr)
group['lr_old'] = lr
return loss
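def _example_nag_usage():
    """Usage sketch (hypothetical toy model): NAG follows the standard
    torch.optim.Optimizer interface, so it can be driven like any optimizer."""
    import torch
    import torch.nn as nn
    model = nn.Linear(10, 2)
    opt = NAG(model.parameters(), lr=0.25, momentum=0.99, weight_decay=0.0)
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    opt.step()       # applies the Nesterov-style update defined above
    opt.zero_grad()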
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.optim
from . import FairseqOptimizer, register_optimizer
@register_optimizer('sgd')
class SGD(FairseqOptimizer):
def __init__(self, args, params):
super().__init__(args, params)
self._optimizer = torch.optim.SGD(params, **self.optimizer_config)
@property
def optimizer_config(self):
"""
Return a kwarg dictionary that will be used to override optimizer
args stored in checkpoints. This allows us to load a checkpoint and
resume training using a different set of optimizer args, e.g., with a
different learning rate.
"""
return {
'lr': self.args.lr[0],
'momentum': self.args.momentum,
'weight_decay': self.args.weight_decay,
}
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import argparse
import os
import torch
from fairseq.criterions import CRITERION_REGISTRY
from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
from fairseq.optim import OPTIMIZER_REGISTRY
from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
from fairseq.tasks import TASK_REGISTRY
def get_training_parser(default_task='translation'):
parser = get_parser('Trainer', default_task)
add_dataset_args(parser, train=True, gen=True)
add_distributed_training_args(parser)
add_model_args(parser)
add_optimization_args(parser)
add_checkpoint_args(parser)
add_generation_args(parser)
add_perf_args(parser)
return parser
def get_generation_parser(interactive=False, default_task='translation'):
parser = get_parser('Generation', default_task)
add_dataset_args(parser, gen=True)
add_generation_args(parser)
add_perf_args(parser)
if interactive:
add_interactive_args(parser)
return parser
def get_eval_lm_parser(default_task='language_modeling'):
parser = get_parser('Evaluate Language Model', default_task)
add_dataset_args(parser, gen=True)
add_eval_lm_args(parser)
return parser
def eval_str_list(x, type=float):
if x is None:
return None
if isinstance(x, str):
x = eval(x)
try:
return list(map(type, x))
except TypeError:
return [type(x)]
def eval_bool(x, default=False):
if x is None:
return default
try:
return bool(eval(x))
except TypeError:
return default
def parse_args_and_arch(parser, input_args=None, parse_known=False):
# The parser doesn't know about model/criterion/optimizer-specific args, so
# we parse twice. First we parse the model/criterion/optimizer, then we
# parse a second time after adding the *-specific arguments.
# If input_args is given, we will parse those args instead of sys.argv.
args, _ = parser.parse_known_args(input_args)
# Add model-specific args to parser.
if hasattr(args, 'arch'):
model_specific_group = parser.add_argument_group(
'Model-specific configuration',
# Only include attributes which are explicitly given as command-line
# arguments or which have default values.
argument_default=argparse.SUPPRESS,
)
ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)
# Add *-specific args to parser.
if hasattr(args, 'criterion'):
CRITERION_REGISTRY[args.criterion].add_args(parser)
if hasattr(args, 'optimizer'):
OPTIMIZER_REGISTRY[args.optimizer].add_args(parser)
if hasattr(args, 'lr_scheduler'):
LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)
if hasattr(args, 'task'):
TASK_REGISTRY[args.task].add_args(parser)
# Parse a second time.
if parse_known:
args, extra = parser.parse_known_args(input_args)
else:
args = parser.parse_args(input_args)
extra = None
# Post-process args.
if hasattr(args, 'lr'):
args.lr = eval_str_list(args.lr, type=float)
if hasattr(args, 'update_freq'):
args.update_freq = eval_str_list(args.update_freq, type=int)
if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
args.max_sentences_valid = args.max_sentences
# Apply architecture configuration.
if hasattr(args, 'arch'):
ARCH_CONFIG_REGISTRY[args.arch](args)
if parse_known:
return args, extra
else:
return args
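# Usage sketch (hypothetical command line): the two-pass parsing above lets
# architecture-, criterion-, optimizer- and scheduler-specific flags live on
# one command line even though they are registered by different modules.
def _example_parse(argv):
    """`argv` is a hypothetical command line, e.g.
    ['--arch', 'fconv', '--optimizer', 'nag', '--lr', '0.25', ...] plus
    whatever dataset arguments the selected task requires."""
    parser = get_training_parser()
    args = parse_args_and_arch(parser, input_args=argv)
    return args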
def get_parser(desc, default_task='translation'):
parser = argparse.ArgumentParser(
description='Facebook AI Research Sequence-to-Sequence Toolkit -- ' + desc)
parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
help='log progress every N batches (when progress bar is disabled)')
parser.add_argument('--log-format', default=None, help='log format to use',
choices=['json', 'none', 'simple', 'tqdm'])
parser.add_argument('--seed', default=1, type=int, metavar='N',
help='pseudo random number generator seed')
parser.add_argument('--fp16', action='store_true', help='use FP16')
parser.add_argument('--profile', type=int, default=None)
# Task definitions can be found under fairseq/tasks/
parser.add_argument(
'--task', metavar='TASK', default=default_task, choices=TASK_REGISTRY.keys(),
help='task: {} (default: {})'.format(', '.join(TASK_REGISTRY.keys()), default_task)
)
return parser
def add_dataset_args(parser, train=False, gen=False):
group = parser.add_argument_group('Dataset and data loading')
group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true', help='ignore too long or too short lines in valid and test set')
group.add_argument('--max-tokens', type=int, metavar='N', help='maximum number of tokens in a batch')
group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N', help='maximum number of sentences in a batch')
group.add_argument('--source_lang', type=str, metavar='SRC', help='source language')
group.add_argument('--target_lang', type=str, metavar='TGT', help='target language')
group.add_argument('--bucket_growth_factor', type=float, metavar='F', help='bucket growth factor')
group.add_argument('--raw_text', action='store_true', help='load raw text dataset')
group.add_argument('--batching_scheme', default='reference', help='batching scheme',
choices=['v0p5', 'v0p5_better', 'v0p6', 'reference'])
group.add_argument('--batch_multiple_strategy', default='mult_of_sequences',
help='strategy for making the batch size a multiple of some number',
choices=['mult_of_sequences', 'pad_sequence_to_mult', 'dynamic'])
if train:
group.add_argument('--train-subset', default='train', metavar='SPLIT',
choices=['train', 'valid', 'test'],
help='data subset to use for training (train, valid, test)')
group.add_argument('--valid-subset', default='valid', metavar='SPLIT',
help='comma separated list of data subsets to use for validation'
' (train, valid, valid1, test, test1)')
group.add_argument('--max-sentences-valid', type=int, metavar='N',
help='maximum number of sentences in a validation batch'
' (defaults to --max-sentences)')
if gen:
group.add_argument('--gen-subset', default='test', metavar='SPLIT',
help='data subset to generate (train, valid, test)')
group.add_argument('--num-shards', default=1, type=int, metavar='N',
help='shard generation over N shards')
group.add_argument('--shard-id', default=0, type=int, metavar='ID',
help='id of the shard to generate (id < num_shards)')
return group
def add_distributed_training_args(parser):
group = parser.add_argument_group('Distributed training')
group.add_argument('--distributed-world-size', type=int, metavar='N',
default=torch.cuda.device_count(),
help='total number of GPUs across all nodes (default: all visible GPUs)')
group.add_argument('--distributed-rank', default=0, type=int,
help='rank of the current worker')
group.add_argument('--local_rank', default=int(os.getenv('LOCAL_RANK', 0)), type=int,
help='local rank of the current worker')
group.add_argument('--distributed-backend', default='nccl', type=str,
help='distributed backend')
group.add_argument('--distributed-init-method', default=None, type=str,
help='typically tcp://hostname:port that will be used to '
'establish initial connection')
group.add_argument('--distributed-port', default=-1, type=int,
help='port number (not required if using --distributed-init-method)')
group.add_argument('--device-id', default=0, type=int,
help='which GPU to use (usually configured automatically)')
group.add_argument('--enable-global-stats', action='store_true',
help='enable global reduction of logging statistics for debugging')
return group
def add_optimization_args(parser):
group = parser.add_argument_group('Optimization')
group.add_argument('--max-epoch', '--me', default=-1, type=int, metavar='N',
help='force stop training at specified epoch')
group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N',
help='force stop training at specified update')
group.add_argument('--target-bleu', default=0.0, type=float, metavar='TARGET',
help='force stop training after reaching target bleu')
group.add_argument('--clip-norm', default=25, type=float, metavar='NORM',
help='clip threshold of gradients')
group.add_argument('--sentence-avg', action='store_true',
help='normalize gradients by the number of sentences in a batch'
' (default is to normalize by number of tokens)')
group.add_argument('--update-freq', default='1', metavar='N',
help='update parameters every N_i batches, when in epoch i')
# Optimizer definitions can be found under fairseq/optim/
group.add_argument('--optimizer', default='nag', metavar='OPT',
choices=OPTIMIZER_REGISTRY.keys(),
help='optimizer: {} (default: nag)'.format(', '.join(OPTIMIZER_REGISTRY.keys())))
group.add_argument('--lr', '--learning-rate', default='0.25', metavar='LR_1,LR_2,...,LR_N',
help='learning rate for the first N epochs; all epochs >N using LR_N'
' (note: this may be interpreted differently depending on --lr-scheduler)')
group.add_argument('--momentum', default=0.99, type=float, metavar='M',
help='momentum factor')
group.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
help='weight decay')
# Distributed weight update parameters
group.add_argument('--distributed-weight-update', '--dwu', default=0, type=int, metavar='DWU',
help='select distributed weight update strategy')
group.add_argument('--dwu-group-size', '--dwugs', default=0, type=int, metavar='DWUGS',
help='distributed weight update group size. If arg is 0, defaults to one node')
group.add_argument('--dwu-num-blocks', '--dwunb', default=8, type=int, metavar='DWUNB',
help='number of blocks in dwu scheme')
group.add_argument('--dwu-num-chunks', '--dwunc', default=8, type=int, metavar='DWUNC',
help='number of chunks in dwu scheme')
group.add_argument('--dwu-num-rs-pg', '--dwurspg', default=2, type=int, metavar='DWURSPG',
help='number of reduction-scatter streams in dwu scheme')
group.add_argument('--dwu-num-ar-pg', '--dwuarpg', default=4, type=int, metavar='DWUARPG',
help='number of all-reduce streams in dwu scheme')
group.add_argument('--dwu-num-ag-pg', '--dwuagpg', default=2, type=int, metavar='DWUAGPG',
help='number of all-gather streams in dwu scheme')
group.add_argument('--dwu-full-pipeline', action='store_true',
help='whether to do full or partial pipeline')
group.add_argument('--dwu-overlap-reductions', action='store_true',
help='whether to overlap reductions with backprop')
group.add_argument('--dwu-compute-L2-grad-norm', action='store_true',
help='whether to compute L2 grad norm')
group.add_argument('--dwu-flat-mt', action='store_true',
help='whether to flatten gradients with multi tensor scale')
group.add_argument('--dwu-e5m2-allgather', action='store_true',
help='do allgather with e5m2 floats')
group.add_argument('--dwu-do-not-flatten-model', action='store_true',
help='do not flatten model parameters')
# Learning rate schedulers can be found under fairseq/optim/lr_scheduler/
group.add_argument('--lr-scheduler', default='reduce_lr_on_plateau',
help='learning rate scheduler: {} (default: reduce_lr_on_plateau)'.format(
', '.join(LR_SCHEDULER_REGISTRY.keys())))
group.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
help='learning rate shrink factor for annealing, lr_new = (lr * lr_shrink)')
group.add_argument('--min-lr', default=1e-5, type=float, metavar='LR',
help='minimum learning rate')
group.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D',
help='minimum loss scale (for FP16 training)')
# Parallel backward + all-reduce optimization
group.add_argument('--enable-parallel-backward-allred-opt', action='store_true',
help='enable all-reduce of w-gradients in parallel with backward propagation (only for FP16 training)')
group.add_argument('--parallel-backward-allred-cuda-nstreams', type=int, default=1, metavar='N',
help='num of CUDA streams used for parallel all-reduce')
group.add_argument('--parallel-backward-allred-opt-threshold', type=int, default=0, metavar='N',
help='min num of contiguous gradient elements before all-reduce is triggered')
group.add_argument('--enable-parallel-backward-allred-opt-correctness-check', action='store_true',
help='compare w-gradient values obtained doing all-reduce in parallel vs. at the end')
group.add_argument('--dataloader-num-workers', type=int, default=1, metavar='N',
help='num subprocesses for train data loader')
group.add_argument('--enable-dataloader-pin-memory', action='store_true',
help='enable pin_memory for train data loader')
return group
def add_checkpoint_args(parser):
group = parser.add_argument_group('Checkpointing')
group.add_argument('--save-dir', metavar='DIR', default='checkpoints',
help='path to save checkpoints')
group.add_argument('--restore-file', default='checkpoint_last.pt',
help='filename in save-dir from which to load checkpoint')
group.add_argument('--save-interval', type=int, default=1, metavar='N',
help='save a checkpoint every N epochs')
group.add_argument('--save-interval-updates', type=int, default=0, metavar='N',
help='save a checkpoint (and validate) every N updates')
group.add_argument('--keep-interval-updates', type=int, default=-1, metavar='N',
help='keep last N checkpoints saved with --save-interval-updates')
group.add_argument('--no-save', action='store_true',
help='don\'t save models or checkpoints')
group.add_argument('--no-epoch-checkpoints', action='store_true',
help='only store last and best checkpoints')
group.add_argument('--validate-interval', type=int, default=1, metavar='N',
help='validate every N epochs')
return group
def add_common_eval_args(group):
group.add_argument('--path', metavar='FILE',
help='path(s) to model file(s), colon separated')
group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
help='remove BPE tokens before scoring')
group.add_argument('--cpu', action='store_true', help='generate on CPU')
group.add_argument('--quiet', action='store_true',
help='only print final scores')
def add_eval_lm_args(parser):
group = parser.add_argument_group('LM Evaluation')
add_common_eval_args(group)
group.add_argument('--output-word-probs', action='store_true',
help='if set, outputs words and their predicted log probabilities to standard output')
def add_generation_args(parser):
group = parser.add_argument_group('Generation')
add_common_eval_args(group)
group.add_argument('--beam', default=4, type=int, metavar='N',
help='beam size')
group.add_argument('--nbest', default=1, type=int, metavar='N',
help='number of hypotheses to output')
group.add_argument('--max-len-a', default=0, type=float, metavar='N',
help=('generate sequences of maximum length ax + b, '
'where x is the source length'))
group.add_argument('--max-len-b', default=200, type=int, metavar='N',
help=('generate sequences of maximum length ax + b, '
'where x is the source length'))
group.add_argument('--min-len', default=1, type=float, metavar='N',
help=('minimum generation length'))
group.add_argument('--no-early-stop', action='store_true',
help=('continue searching even after finalizing k=beam '
'hypotheses; this is more correct, but increases '
'generation time by 50%%'))
group.add_argument('--unnormalized', action='store_true',
help='compare unnormalized hypothesis scores')
group.add_argument('--no-beamable-mm', action='store_true',
help='don\'t use BeamableMM in attention layers')
group.add_argument('--lenpen', default=1, type=float,
help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
group.add_argument('--unkpen', default=0, type=float,
help='unknown word penalty: <0 produces more unks, >0 produces fewer')
group.add_argument('--replace-unk', nargs='?', const=True, default=None,
help='perform unknown replacement (optionally with alignment dictionary)')
group.add_argument('--score-reference', action='store_true',
help='just score the reference translation')
group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
help='initialize generation by target prefix of given length')
group.add_argument('--sampling', action='store_true',
help='sample hypotheses instead of using beam search')
group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS',
help='sample from top K likely next words instead of all words')
group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
help='temperature for random sampling')
group.add_argument('--print-alignment', action='store_true',
help='if set, uses attention feedback to compute and print alignment to source tokens')
group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
help='a dictionary used to override model args at generation that were used during model training')
group.add_argument('--online-eval', action='store_true',
help='score model at the end of epoch')
group.add_argument('--log-translations', action='store_true',
help='save translations generated by online eval')
group.add_argument('--ignore-case', action='store_true',
help='ignore case during online eval')
return group
def add_interactive_args(parser):
group = parser.add_argument_group('Interactive')
group.add_argument('--buffer-size', default=0, type=int, metavar='N',
help='read this many sentences into a buffer before processing them')
def add_model_args(parser):
group = parser.add_argument_group('Model configuration')
# Model definitions can be found under fairseq/models/
#
# The model architecture can be specified in several ways.
# In increasing order of priority:
# 1) model defaults (lowest priority)
# 2) --arch argument
# 3) --encoder/decoder-* arguments (highest priority)
group.add_argument(
'--arch', '-a', default='fconv', metavar='ARCH', required=True,
choices=ARCH_MODEL_REGISTRY.keys(),
help='model architecture: {} (default: fconv)'.format(
', '.join(ARCH_MODEL_REGISTRY.keys())),
)
# Criterion definitions can be found under fairseq/criterions/
group.add_argument(
'--criterion', default='cross_entropy', metavar='CRIT',
choices=CRITERION_REGISTRY.keys(),
help='training criterion: {} (default: cross_entropy)'.format(
', '.join(CRITERION_REGISTRY.keys())),
)
return group
def add_perf_args(parser):
group = parser.add_argument_group('Performance')
group.add_argument('--multihead-attn-impl', default='default',
choices=['default', 'fast', 'fast_with_lyrnrm_and_dropoutadd'],
help='multihead attention implementation')
group.add_argument('--time-step', action='store_true', help='time the performance of a step')
return group
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Wrapper around various loggers and progress bars (e.g., tqdm).
"""
from collections import OrderedDict
import json
from numbers import Number
import sys
from tqdm import tqdm
from fairseq.meters import AverageMeter
def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm', no_progress_bar='none'):
if args.log_format is None:
args.log_format = no_progress_bar if args.no_progress_bar else default
if args.log_format == 'tqdm' and not sys.stderr.isatty():
args.log_format = 'simple'
if args.log_format == 'json':
bar = json_progress_bar(iterator, epoch, prefix, args.log_interval)
elif args.log_format == 'none':
bar = noop_progress_bar(iterator, epoch, prefix)
elif args.log_format == 'simple':
bar = simple_progress_bar(iterator, epoch, prefix, args.log_interval)
elif args.log_format == 'tqdm':
bar = tqdm_progress_bar(iterator, epoch, prefix)
else:
raise ValueError('Unknown log format: {}'.format(args.log_format))
return bar
class progress_bar(object):
"""Abstract class for progress bars."""
def __init__(self, iterable, epoch=None, prefix=None):
self.iterable = iterable
self.epoch = epoch
self.prefix = ''
if epoch is not None:
self.prefix += '| epoch {:03d}'.format(epoch)
if prefix is not None:
self.prefix += ' | {}'.format(prefix)
def __enter__(self):
return self
def __exit__(self, *exc):
return False
def __iter__(self):
raise NotImplementedError
def log(self, stats):
"""Log intermediate stats according to log_interval."""
raise NotImplementedError
def print(self, stats):
"""Print end-of-epoch stats."""
raise NotImplementedError
def _str_commas(self, stats):
return ', '.join(key + '=' + stats[key].strip()
for key in stats.keys())
def _str_pipes(self, stats):
return ' | '.join(key + ' ' + stats[key].strip()
for key in stats.keys())
def _format_stats(self, stats):
postfix = OrderedDict(stats)
# Preprocess stats according to datatype
for key in postfix.keys():
# Number: limit the length of the string
if isinstance(postfix[key], Number):
postfix[key] = '{:g}'.format(postfix[key])
# Meter: display both current and average value
elif isinstance(postfix[key], AverageMeter):
postfix[key] = '{:.2f} ({:.2f})'.format(
postfix[key].val, postfix[key].avg)
# Else for any other type, try to get the string conversion
elif not isinstance(postfix[key], str):
postfix[key] = str(postfix[key])
# Else if it's a string, don't need to preprocess anything
return postfix
class json_progress_bar(progress_bar):
"""Log output in JSON format."""
def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
super().__init__(iterable, epoch, prefix)
self.log_interval = log_interval
self.stats = None
def __iter__(self):
size = float(len(self.iterable))
for i, obj in enumerate(self.iterable):
yield obj
if self.stats is not None and i > 0 and \
self.log_interval is not None and i % self.log_interval == 0:
update = self.epoch - 1 + float(i / size) if self.epoch is not None else None
stats = self._format_stats(self.stats, epoch=self.epoch, update=update)
print(json.dumps(stats), flush=True)
def log(self, stats):
"""Log intermediate stats according to log_interval."""
self.stats = stats
def print(self, stats):
"""Print end-of-epoch stats."""
self.stats = stats
stats = self._format_stats(self.stats, epoch=self.epoch)
print(json.dumps(stats), flush=True)
def _format_stats(self, stats, epoch=None, update=None):
postfix = OrderedDict()
if epoch is not None:
postfix['epoch'] = epoch
if update is not None:
postfix['update'] = update
# Preprocess stats according to datatype
for key in stats.keys():
# Meter: display both current and average value
if isinstance(stats[key], AverageMeter):
postfix[key] = stats[key].val
postfix[key + '_avg'] = stats[key].avg
else:
postfix[key] = stats[key]
return postfix
class noop_progress_bar(progress_bar):
"""No logging."""
def __init__(self, iterable, epoch=None, prefix=None):
super().__init__(iterable, epoch, prefix)
def __iter__(self):
for obj in self.iterable:
yield obj
def log(self, stats):
"""Log intermediate stats according to log_interval."""
pass
def print(self, stats):
"""Print end-of-epoch stats."""
pass
class simple_progress_bar(progress_bar):
"""A minimal logger for non-TTY environments."""
def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
super().__init__(iterable, epoch, prefix)
self.log_interval = log_interval
self.stats = None
def __iter__(self):
size = len(self.iterable)
for i, obj in enumerate(self.iterable):
yield obj
if self.stats is not None and i > 0 and \
self.log_interval is not None and i % self.log_interval == 0:
postfix = self._str_commas(self.stats)
print('{}: {:5d} / {:d} {}'.format(self.prefix, i, size, postfix),
flush=True)
def log(self, stats):
"""Log intermediate stats according to log_interval."""
self.stats = self._format_stats(stats)
def print(self, stats):
"""Print end-of-epoch stats."""
postfix = self._str_pipes(self._format_stats(stats))
print('{} | {}'.format(self.prefix, postfix), flush=True)
class tqdm_progress_bar(progress_bar):
"""Log to tqdm."""
def __init__(self, iterable, epoch=None, prefix=None):
super().__init__(iterable, epoch, prefix)
self.tqdm = tqdm(iterable, self.prefix, leave=False)
def __iter__(self):
return iter(self.tqdm)
def log(self, stats):
"""Log intermediate stats according to log_interval."""
self.tqdm.set_postfix(self._format_stats(stats), refresh=False)
def print(self, stats):
"""Print end-of-epoch stats."""
postfix = self._str_pipes(self._format_stats(stats))
self.tqdm.write('{} | {}'.format(self.tqdm.desc, postfix))
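def _example_progress(n=5000):
    """Usage sketch (hypothetical args namespace): wrap any sized iterable and
    feed log()/print() dicts of numbers or AverageMeters."""
    import argparse
    args = argparse.Namespace(log_format='simple', no_progress_bar=False, log_interval=1000)
    bar = build_progress_bar(args, range(n), epoch=1, prefix='train')
    for i in bar:
        bar.log({'loss': 2.5, 'updates': i})
    bar.print({'loss': 2.5})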
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch
from fairseq import utils
from fairseq.models import FairseqIncrementalDecoder
class SequenceGenerator(object):
def __init__(
self, models, tgt_dict, beam_size=1, minlen=1, maxlen=None, stop_early=True,
normalize_scores=True, len_penalty=1, retain_dropout=False,
sampling=False, sampling_topk=-1, sampling_temperature=1,
):
"""Generates translations of a given source sentence.
Args:
min/maxlen: The length of the generated output will be bounded by
minlen and maxlen (not including the end-of-sentence marker).
stop_early: Stop generation immediately after we finalize beam_size
hypotheses, even though longer hypotheses might have better
normalized scores.
normalize_scores: Normalize scores by the length of the output.
"""
self.models = models
self.pad = tgt_dict.pad()
self.eos = tgt_dict.eos()
self.vocab_size = len(tgt_dict)
self.beam_size = beam_size
self.minlen = minlen
max_decoder_len = min(m.max_decoder_positions() for m in self.models)
max_decoder_len -= 1 # we define maxlen not including the EOS marker
self.maxlen = max_decoder_len if maxlen is None else min(maxlen, max_decoder_len)
self.stop_early = stop_early
self.normalize_scores = normalize_scores
self.len_penalty = len_penalty
self.retain_dropout = retain_dropout
self.sampling = sampling
self.sampling_topk = sampling_topk
self.sampling_temperature = sampling_temperature
def cuda(self):
for model in self.models:
model.cuda()
return self
def generate_batched_itr(
self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
cuda=False, timer=None, prefix_size=0,
):
"""Iterate over a batched dataset and yield individual translations.
Args:
maxlen_a/b: generate sequences of maximum length ax + b,
where x is the source sentence length.
cuda: use GPU for generation
timer: StopwatchMeter for timing generations.
"""
if maxlen_b is None:
maxlen_b = self.maxlen
for sample in data_itr:
s = utils.move_to_cuda(sample) if cuda else sample
if 'net_input' not in s:
continue
input = s['net_input']
srclen = input['src_tokens'].size(1)
if timer is not None:
timer.start()
with torch.no_grad():
hypos = self.generate(
input['src_tokens'],
input['src_lengths'],
beam_size=beam_size,
maxlen=int(maxlen_a*srclen + maxlen_b),
prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
)
if timer is not None:
timer.stop(sum(len(h[0]['tokens']) for h in hypos))
for i, id in enumerate(s['id'].data):
# remove padding
src = utils.strip_pad(input['src_tokens'].data[i, :], self.pad)
ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
yield id, src, ref, hypos[i]
def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
"""Generate a batch of translations."""
with torch.no_grad():
return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)
def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
bsz, srclen = src_tokens.size()
maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
# the max beam size is the dictionary size - 1, since we never select pad
beam_size = beam_size if beam_size is not None else self.beam_size
beam_size = min(beam_size, self.vocab_size - 1)
encoder_outs = []
incremental_states = {}
for model in self.models:
if not self.retain_dropout:
model.eval()
if isinstance(model.decoder, FairseqIncrementalDecoder):
incremental_states[model] = {}
else:
incremental_states[model] = None
# compute the encoder output for each beam
encoder_out = model.encoder(
src_tokens.repeat(1, beam_size).view(-1, srclen),
src_lengths.expand(beam_size, src_lengths.numel()).t().contiguous().view(-1),
)
encoder_outs.append(encoder_out)
# initialize buffers
scores = src_tokens.data.new(bsz * beam_size, maxlen + 1).float().fill_(0)
scores_buf = scores.clone()
tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
tokens_buf = tokens.clone()
tokens[:, 0] = self.eos
attn, attn_buf = None, None
nonpad_idxs = None
# list of completed sentences
finalized = [[] for i in range(bsz)]
finished = [False for i in range(bsz)]
worst_finalized = [{'idx': None, 'score': -math.inf} for i in range(bsz)]
num_remaining_sent = bsz
# number of candidate hypos per step
cand_size = 2 * beam_size # 2 x beam size in case half are EOS
# offset arrays for converting between different indexing schemes
bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
cand_offsets = torch.arange(0, cand_size).type_as(tokens)
# helper function for allocating buffers on the fly
buffers = {}
def buffer(name, type_of=tokens): # noqa
if name not in buffers:
buffers[name] = type_of.new()
return buffers[name]
def is_finished(sent, step, unfinalized_scores=None):
"""
Check whether we've finished generation for a given sentence, by
comparing the worst score among finalized hypotheses to the best
possible score among unfinalized hypotheses.
"""
assert len(finalized[sent]) <= beam_size
if len(finalized[sent]) == beam_size:
if self.stop_early or step == maxlen or unfinalized_scores is None:
return True
# stop if the best unfinalized score is worse than the worst
# finalized one
best_unfinalized_score = unfinalized_scores[sent].max()
if self.normalize_scores:
best_unfinalized_score /= ((maxlen + 5) / 6) ** self.len_penalty
if worst_finalized[sent]['score'] >= best_unfinalized_score:
return True
return False
def finalize_hypos(step, bbsz_idx, eos_scores, unfinalized_scores=None):
"""
Finalize the given hypotheses at this step, while keeping the total
number of finalized hypotheses per sentence <= beam_size.
Note: the input must be in the desired finalization order, so that
hypotheses that appear earlier in the input are preferred to those
that appear later.
Args:
step: current time step
bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
indicating which hypotheses to finalize
eos_scores: A vector of the same size as bbsz_idx containing
scores for each hypothesis
unfinalized_scores: A vector containing scores for all
unfinalized hypotheses
"""
assert bbsz_idx.numel() == eos_scores.numel()
# clone relevant token and attention tensors
tokens_clone = tokens.index_select(0, bbsz_idx)
tokens_clone = tokens_clone[:, 1:step + 2] # skip the first index, which is EOS
tokens_clone[:, step] = self.eos
attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2] if attn is not None else None
# compute scores per token position
pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1]
pos_scores[:, step] = eos_scores
# convert from cumulative to per-position scores
pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
# normalize sentence-level scores by the GNMT length penalty ((5 + len) / 6) ** alpha
if self.normalize_scores:
eos_scores /= (((step + 1) + 5) / 6) ** self.len_penalty
cum_unfin = []
prev = 0
for f in finished:
if f:
prev += 1
else:
cum_unfin.append(prev)
sents_seen = set()
for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())):
unfin_idx = idx // beam_size
sent = unfin_idx + cum_unfin[unfin_idx]
sents_seen.add((sent, unfin_idx))
def get_hypo():
if attn_clone is not None:
# remove padding tokens from attn scores
hypo_attn = attn_clone[i][nonpad_idxs[sent]]
_, alignment = hypo_attn.max(dim=0)
else:
hypo_attn = None
alignment = None
return {
'tokens': tokens_clone[i],
'score': score,
'attention': hypo_attn, # src_len x tgt_len
'alignment': alignment,
'positional_scores': pos_scores[i],
}
if len(finalized[sent]) < beam_size:
finalized[sent].append(get_hypo())
elif not self.stop_early and score > worst_finalized[sent]['score']:
# replace worst hypo for this sentence with new/better one
worst_idx = worst_finalized[sent]['idx']
if worst_idx is not None:
finalized[sent][worst_idx] = get_hypo()
# find new worst finalized hypo for this sentence
idx, s = min(enumerate(finalized[sent]), key=lambda r: r[1]['score'])
worst_finalized[sent] = {
'score': s['score'],
'idx': idx,
}
newly_finished = []
for sent, unfin_idx in sents_seen:
# check termination conditions for this sentence
if not finished[sent] and is_finished(sent, step, unfinalized_scores):
finished[sent] = True
newly_finished.append(unfin_idx)
return newly_finished
reorder_state = None
batch_idxs = None
for step in range(maxlen + 1): # one extra step for EOS marker
# reorder decoder internal states based on the prev choice of beams
if reorder_state is not None:
if batch_idxs is not None:
# update beam indices to take into account removed sentences
corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
for i, model in enumerate(self.models):
if isinstance(model.decoder, FairseqIncrementalDecoder):
model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)
probs, avg_attn_scores = self._decode(tokens[:, :step + 1], encoder_outs, incremental_states)
if step == 0:
# at the first step all hypotheses are equally likely, so use
# only the first beam
probs = probs.unfold(0, 1, beam_size).squeeze(2).contiguous()
scores = scores.type_as(probs)
scores_buf = scores_buf.type_as(probs)
elif not self.sampling:
# make probs contain cumulative scores for each hypothesis
probs.add_(scores[:, step - 1].view(-1, 1))
probs[:, self.pad] = -math.inf # never select pad
# Record attention scores
if avg_attn_scores is not None:
if attn is None:
attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
attn_buf = attn.clone()
nonpad_idxs = src_tokens.ne(self.pad)
attn[:, :, step + 1].copy_(avg_attn_scores)
cand_scores = buffer('cand_scores', type_of=scores)
cand_indices = buffer('cand_indices')
cand_beams = buffer('cand_beams')
eos_bbsz_idx = buffer('eos_bbsz_idx')
eos_scores = buffer('eos_scores', type_of=scores)
if step < maxlen:
if prefix_tokens is not None and step < prefix_tokens.size(1):
probs_slice = probs.view(bsz, -1, probs.size(-1))[:, 0, :]
cand_scores = torch.gather(
probs_slice, dim=1,
index=prefix_tokens[:, step].view(-1, 1).data
).expand(-1, cand_size)
cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size).data
cand_beams.resize_as_(cand_indices).fill_(0)
elif self.sampling:
assert self.pad == 1, 'sampling assumes the first two symbols can be ignored'
if self.sampling_topk > 0:
values, indices = probs[:, 2:].topk(self.sampling_topk)
exp_probs = values.div_(self.sampling_temperature).exp()
if step == 0:
torch.multinomial(exp_probs, beam_size, replacement=True, out=cand_indices)
else:
torch.multinomial(exp_probs, 1, replacement=True, out=cand_indices)
torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
torch.gather(indices, dim=1, index=cand_indices, out=cand_indices)
cand_indices.add_(2)
else:
exp_probs = probs.div_(self.sampling_temperature).exp_().view(-1, self.vocab_size)
if step == 0:
# we exclude the first two vocab items, one of which is pad
torch.multinomial(exp_probs[:, 2:], beam_size, replacement=True, out=cand_indices)
else:
torch.multinomial(exp_probs[:, 2:], 1, replacement=True, out=cand_indices)
cand_indices.add_(2)
torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
cand_scores.log_()
cand_indices = cand_indices.view(bsz, -1).repeat(1, 2)
cand_scores = cand_scores.view(bsz, -1).repeat(1, 2)
if step == 0:
cand_beams = torch.zeros(bsz, cand_size).type_as(cand_indices)
else:
cand_beams = torch.arange(0, beam_size).repeat(bsz, 2).type_as(cand_indices)
# make scores cumulative
cand_scores.add_(
torch.gather(
scores[:, step - 1].view(bsz, beam_size), dim=1,
index=cand_beams,
)
)
else:
# take the best 2 x beam_size predictions. We'll choose the first
# beam_size of these which don't predict eos to continue with.
torch.topk(
probs.view(bsz, -1),
k=min(cand_size, probs.view(bsz, -1).size(1) - 1), # -1 so we never select pad
out=(cand_scores, cand_indices),
)
torch.div(cand_indices, self.vocab_size, out=cand_beams, rounding_mode="trunc")
cand_indices.fmod_(self.vocab_size)
else:
# finalize all active hypotheses once we hit maxlen
# pick the hypothesis with the highest prob of EOS right now
torch.sort(
probs[:, self.eos],
descending=True,
out=(eos_scores, eos_bbsz_idx),
)
num_remaining_sent -= len(finalize_hypos(
step, eos_bbsz_idx, eos_scores))
assert num_remaining_sent == 0
break
# cand_bbsz_idx contains beam indices for the top candidate
# hypotheses, with a range of values: [0, bsz*beam_size),
# and dimensions: [bsz, cand_size]
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
# finalize hypotheses that end in eos
eos_mask = cand_indices.eq(self.eos)
finalized_sents = set()
if step >= self.minlen:
# only consider eos when it's among the top beam_size indices
torch.masked_select(
cand_bbsz_idx[:, :beam_size],
mask=eos_mask[:, :beam_size],
out=eos_bbsz_idx,
)
if eos_bbsz_idx.numel() > 0:
torch.masked_select(
cand_scores[:, :beam_size],
mask=eos_mask[:, :beam_size],
out=eos_scores,
)
finalized_sents = finalize_hypos(
step, eos_bbsz_idx, eos_scores, cand_scores)
num_remaining_sent -= len(finalized_sents)
assert num_remaining_sent >= 0
if num_remaining_sent == 0:
break
assert step < maxlen
if len(finalized_sents) > 0:
new_bsz = bsz - len(finalized_sents)
# construct batch_idxs which holds indices of batches to keep for the next pass
batch_mask = torch.ones(bsz).type_as(cand_indices)
batch_mask[cand_indices.new(finalized_sents)] = 0
batch_idxs = batch_mask.nonzero().squeeze(-1)
eos_mask = eos_mask[batch_idxs]
cand_beams = cand_beams[batch_idxs]
bbsz_offsets.resize_(new_bsz, 1)
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
cand_scores = cand_scores[batch_idxs]
cand_indices = cand_indices[batch_idxs]
if prefix_tokens is not None:
prefix_tokens = prefix_tokens[batch_idxs]
scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
scores_buf.resize_as_(scores)
tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
tokens_buf.resize_as_(tokens)
if attn is not None:
attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
attn_buf.resize_as_(attn)
bsz = new_bsz
else:
batch_idxs = None
# set active_mask so that values > cand_size indicate eos hypos
# and values < cand_size indicate candidate active hypos.
# After, the min values per row are the top candidate active hypos
active_mask = buffer('active_mask')
torch.add(
eos_mask.type_as(cand_offsets) * cand_size,
cand_offsets[:eos_mask.size(1)],
out=active_mask,
)
# get the top beam_size active hypotheses, which are just the hypos
# with the smallest values in active_mask
active_hypos, _ignore = buffer('active_hypos'), buffer('_ignore')
torch.topk(
active_mask, k=beam_size, dim=1, largest=False,
out=(_ignore, active_hypos)
)
active_bbsz_idx = buffer('active_bbsz_idx')
torch.gather(
cand_bbsz_idx, dim=1, index=active_hypos,
out=active_bbsz_idx,
)
active_scores = torch.gather(
cand_scores, dim=1, index=active_hypos,
out=scores[:, step].view(bsz, beam_size),
)
active_bbsz_idx = active_bbsz_idx.view(-1)
active_scores = active_scores.view(-1)
# copy tokens and scores for active hypotheses
torch.index_select(
tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
out=tokens_buf[:, :step + 1],
)
torch.gather(
cand_indices, dim=1, index=active_hypos,
out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
)
if step > 0:
torch.index_select(
scores[:, :step], dim=0, index=active_bbsz_idx,
out=scores_buf[:, :step],
)
torch.gather(
cand_scores, dim=1, index=active_hypos,
out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
)
# copy attention for active hypotheses
if attn is not None:
torch.index_select(
attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
out=attn_buf[:, :, :step + 2],
)
# swap buffers
tokens, tokens_buf = tokens_buf, tokens
scores, scores_buf = scores_buf, scores
if attn is not None:
attn, attn_buf = attn_buf, attn
# reorder incremental state in decoder
reorder_state = active_bbsz_idx
# sort by score descending
for sent in range(len(finalized)):
finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)
return finalized
def _decode(self, tokens, encoder_outs, incremental_states):
if len(self.models) == 1:
return self._decode_one(tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True)
avg_probs = None
avg_attn = None
for model, encoder_out in zip(self.models, encoder_outs):
probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=False)
if avg_probs is None:
avg_probs = probs
else:
avg_probs.add_(probs)
if attn is not None:
if avg_attn is None:
avg_attn = attn
else:
avg_attn.add_(attn)
avg_probs.div_(len(self.models))
avg_probs.log_()
if avg_attn is not None:
avg_attn.div_(len(self.models))
return avg_probs, avg_attn
def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs):
with torch.no_grad():
if incremental_states[model] is not None:
decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=incremental_states[model]))
else:
decoder_out = list(model.decoder(tokens, encoder_out))
decoder_out[0] = decoder_out[0][:, -1, :]
attn = decoder_out[1]
if attn is not None:
attn = attn[:, -1, :]
probs = model.get_normalized_probs(decoder_out, log_probs=log_probs)
return probs, attn
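def _example_beam_decode(models, tgt_dict, data_itr, use_cuda=True):
    """Usage sketch (assumes a trained ensemble `models`, its target dictionary
    `tgt_dict`, and a batched dataset iterator `data_itr`)."""
    generator = SequenceGenerator(models, tgt_dict, beam_size=4, len_penalty=1.0)
    if use_cuda:
        generator.cuda()
    for sample_id, src_tokens, ref_tokens, hypos in generator.generate_batched_itr(
            data_itr, maxlen_a=0.0, maxlen_b=200, cuda=use_cuda):
        best = hypos[0]  # hypotheses come back sorted best-first
        print(sample_id, best['score'], best['tokens'])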
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
from fairseq import utils
class SequenceScorer(object):
"""Scores the target for a given source sentence."""
def __init__(self, models, tgt_dict):
self.models = models
self.pad = tgt_dict.pad()
def cuda(self):
for model in self.models:
model.cuda()
return self
def score_batched_itr(self, data_itr, cuda=False, timer=None):
"""Iterate over a batched dataset and yield scored translations."""
for sample in data_itr:
s = utils.move_to_cuda(sample) if cuda else sample
if timer is not None:
timer.start()
pos_scores, attn = self.score(s)
for i, id in enumerate(s['id'].data):
# remove padding from ref
src = utils.strip_pad(s['net_input']['src_tokens'].data[i, :], self.pad)
ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
tgt_len = ref.numel()
pos_scores_i = pos_scores[i][:tgt_len]
score_i = pos_scores_i.sum() / tgt_len
if attn is not None:
attn_i = attn[i]
_, alignment = attn_i.max(dim=0)
else:
attn_i = alignment = None
hypos = [{
'tokens': ref,
'score': score_i,
'attention': attn_i,
'alignment': alignment,
'positional_scores': pos_scores_i,
}]
if timer is not None:
timer.stop(s['ntokens'])
# return results in the same format as SequenceGenerator
yield id, src, ref, hypos
def score(self, sample):
"""Score a batch of translations."""
net_input = sample['net_input']
# compute scores for each model in the ensemble
avg_probs = None
avg_attn = None
for model in self.models:
with torch.no_grad():
model.eval()
decoder_out = model.forward(**net_input)
attn = decoder_out[1]
probs = model.get_normalized_probs(decoder_out, log_probs=False, sample=sample).data
if avg_probs is None:
avg_probs = probs
else:
avg_probs.add_(probs)
if attn is not None:
attn = attn.data
if avg_attn is None:
avg_attn = attn
else:
avg_attn.add_(attn)
avg_probs.div_(len(self.models))
avg_probs.log_()
if avg_attn is not None:
avg_attn.div_(len(self.models))
avg_probs = avg_probs.gather(
dim=2,
index=sample['target'].data.unsqueeze(-1),
)
return avg_probs.squeeze(2), avg_attn
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import importlib
import os
from .fairseq_task import FairseqTask
TASK_REGISTRY = {}
TASK_CLASS_NAMES = set()
def setup_task(args):
return TASK_REGISTRY[args.task].setup_task(args)
def register_task(name):
"""Decorator to register a new task."""
def register_task_cls(cls):
if name in TASK_REGISTRY:
raise ValueError('Cannot register duplicate task ({})'.format(name))
if not issubclass(cls, FairseqTask):
raise ValueError('Task ({}: {}) must extend FairseqTask'.format(name, cls.__name__))
if cls.__name__ in TASK_CLASS_NAMES:
raise ValueError('Cannot register task with duplicate class name ({})'.format(cls.__name__))
TASK_REGISTRY[name] = cls
TASK_CLASS_NAMES.add(cls.__name__)
return cls
return register_task_cls
# automatically import any Python files in the tasks/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('fairseq.tasks.' + module)