"vscode:/vscode.git/clone" did not exist on "fa31704420c37f2abee2acfe384d3310561a83b9"
Commit c0f05c10 authored by hepj

Update transformer code

parent c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Train a network across multiple GPUs.
"""
import math
from collections import defaultdict
from itertools import chain
import torch
import torch.nn.functional as F
from torch.cuda import amp
from apex.parallel import DistributedDataParallel as DDP
from fairseq import distributed_utils, optim, utils
from fairseq.optim import lr_scheduler
from fairseq.meters import TimeMeter, AverageMeter
from fairseq.criterions import CRITERION_REGISTRY
import dllogger as DLLogger
class DDPTrainer():
"""Main class for data parallel training.
This class supports data parallel training, where multiple workers each
have a full model replica and gradients are accumulated synchronously via
torch.distributed.all_reduce.
"""
def __init__(self, args, model):
if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported')
self.args = args
self.model = model.cuda()
self.criterion = CRITERION_REGISTRY[args.criterion](args).cuda()
self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
self.scaler = amp.GradScaler(enabled=self.args.amp, init_scale=2**15)
if self.args.distributed_world_size > 1:
self.model = DDP(model)
self._buffered_stats = defaultdict(lambda: [])
self._num_updates = 0
self._optim_history = None
self.throughput_meter = TimeMeter()
self.avg_loss_meter = AverageMeter()
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
if distributed_utils.is_master(self.args): # only save one checkpoint
utils.save_state(
filename, self.args, self.get_model(), self.criterion, self.optimizer,
self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
)
def load_checkpoint(self, filename, load_optim=True):
"""Load all training state from a checkpoint file."""
extra_state, optim_history, last_optim_state = \
utils.load_model_state(filename, self.get_model())
if last_optim_state is not None:
# rebuild optimizer after loading model, since params may have changed
#self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
if load_optim:
self._optim_history = optim_history
# only reload optimizer and lr_scheduler if they match
last_optim = self._optim_history[-1]
if last_optim['criterion_name'] == self.criterion.__class__.__name__:
self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
if last_optim['optimizer_name'] == self.optimizer.__class__.__name__:
self.optimizer.load_state_dict(last_optim_state)
self._num_updates = last_optim['num_updates']
return extra_state
def train_step(self, sample, update_params=True, last_step=False):
"""Do forward, backward and parameter update."""
# Set seed based on args.seed and the update number so that we get
# reproducible results when resuming from checkpoints
seed = self.args.seed + self.get_num_updates()
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
self.model.train()
if isinstance(self.model, DDP):
if last_step:
self.model.disable_allreduce()
else:
self.model.enable_allreduce()
# forward and backward pass
sample = self._prepare_sample(sample)
loss, oom_fwd = self._forward(sample)
# If this is the last batch, the forward pass is skipped on some workers
# A batch with sample_size 0 is not accounted for in the weighted loss
logging_output = {
'ntokens': sample['ntokens'] if sample is not None else 0,
'nsentences': sample['target'].size(0) if sample is not None else 0,
'loss': utils.item(loss.data) if loss is not None else 0,
}
sample_size = sample['ntokens'] if sample is not None else 0
oom_bwd = self._backward(loss)
# buffer stats and logging outputs
self._buffered_stats['sample_sizes'].append(sample_size)
self._buffered_stats['logging_outputs'].append(logging_output)
self._buffered_stats['ooms_fwd'].append(oom_fwd)
self._buffered_stats['ooms_bwd'].append(oom_bwd)
# update parameters
if update_params and not last_step:
# gather logging outputs from all replicas
sample_sizes = self._buffered_stats['sample_sizes']
logging_outputs = self._buffered_stats['logging_outputs']
ooms_fwd = self._buffered_stats['ooms_fwd']
ooms_bwd = self._buffered_stats['ooms_bwd']
if self.args.distributed_world_size > 1:
sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
lambda l: list(chain.from_iterable(l)),
zip(*distributed_utils.all_gather_list(
(sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
))
)
ooms_fwd = sum(ooms_fwd)
ooms_bwd = sum(ooms_bwd)
ooms = ooms_fwd + ooms_bwd # this is always <= distributed_world_size
if ooms == self.args.distributed_world_size:
print('| WARNING: OOM in all workers, skipping batch')
self.zero_grad()
return
# aggregate stats and logging outputs
grad_denom = sum(sample_sizes)
for p in self.model.parameters():
if p.requires_grad and p.grad is not None:
p.grad /= grad_denom
self._opt()
# Handle logging
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
self.throughput_meter.update(ntokens)
info_log_data = {
'tokens/s': self.throughput_meter.avg,
'tokens': ntokens,
'loss': sum(log.get('loss', 0) for log in logging_outputs) / ntokens / math.log(2)
}
self.avg_loss_meter.update(info_log_data['loss'])
debug_log_data = {
'batch_size': sum(log.get('nsentences', 0) for log in logging_outputs),
'lr': self.get_lr(),
'grad_denom': grad_denom,
'updates': 1
}
DLLogger.log(step=self._num_updates, data=info_log_data, verbosity=0)
DLLogger.log(step=self._num_updates, data=debug_log_data, verbosity=1)
self.clear_buffered_stats()
def _forward(self, sample):
loss = None
oom = 0
try:
if sample is not None:
with amp.autocast(enabled=self.args.amp):
# calculate loss and sample size
logits, _ = self.model(**sample['net_input'])
target = sample['target']
probs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
loss = self.criterion(probs, target)
except RuntimeError as e:
if 'out of memory' in str(e):
print('| WARNING: ran out of memory in worker {}, skipping batch'.format(
self.args.distributed_rank), force=True)
oom = 1
loss = None
else:
raise e
return loss, oom
def _backward(self, loss):
oom = 0
if loss is not None:
try:
self.scaler.scale(loss).backward()
except RuntimeError as e:
if 'out of memory' in str(e):
print('| WARNING: ran out of memory in worker {}, skipping batch'.format(
self.args.distributed_rank), force=True)
oom = 1
self.zero_grad()
else:
raise e
return oom
def _opt(self):
# take an optimization step
self.scaler.step(self.optimizer.optimizer)
self.scaler.update()
self.zero_grad()
self._num_updates += 1
# update learning rate
self.lr_scheduler.step_update(self._num_updates)
def valid_step(self, sample):
"""Do forward pass in evaluation mode."""
self.model.eval()
# forward pass
sample = self._prepare_sample(sample)
with torch.no_grad():
loss, oom_fwd = self._forward(sample)
logging_output = {
'ntokens': sample['ntokens'] if sample is not None else 0,
'nsentences': sample['target'].size(0) if sample is not None else 0,
}
loss = loss.item() if loss is not None else 0
assert not oom_fwd, 'Ran out of memory during validation'
# gather logging outputs from all GPUs
if self.args.distributed_world_size > 1:
losses, logging_outputs = zip(*distributed_utils.all_gather_list(
(loss, logging_output)
))
else:
losses = [loss]
logging_outputs = [logging_output]
weight = sum(log.get('ntokens', 0) for log in logging_outputs)
scaled_loss = sum(losses) / weight / math.log(2)
return scaled_loss
def dummy_train_step(self, dummy_batch):
"""Dummy training step for warming caching allocator."""
self.train_step(dummy_batch, update_params=False)
self.zero_grad()
self.clear_buffered_stats()
def zero_grad(self):
self.optimizer.zero_grad()
def clear_buffered_stats(self):
self._buffered_stats.clear()
def lr_step(self, epoch, val_loss=None):
"""Adjust the learning rate based on the validation loss."""
return self.lr_scheduler.step(epoch, val_loss)
def lr_step_update(self, num_updates):
"""Update the learning rate after each update."""
return self.lr_scheduler.step_update(num_updates)
def get_lr(self):
"""Get the current learning rate."""
return self.optimizer.get_lr()
def get_throughput_meter(self):
"""Get the throughput meter"""
return self.throughput_meter
def get_model(self):
"""Get the model replica."""
return self.model.module if isinstance(self.model, DDP) else self.model
def get_num_updates(self):
"""Get the number of parameters updates."""
return self._num_updates
def _prepare_sample(self, sample):
if not sample:
return None
return utils.move_to_cuda(sample)
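# ---------------------------------------------------------------------------
# Illustrative usage (editor's sketch, not part of the original sources).
# A minimal epoch loop driving DDPTrainer: train_step() buffers per-batch stats
# and only applies the optimizer step when update_params=True, so gradients can
# be accumulated over --update-freq micro-batches. The args, model, iterators
# and dummy_batch are assumed to be built elsewhere in the training script.
def example_training_loop(args, model, epoch, epoch_itr, valid_itr, dummy_batch):
    trainer = DDPTrainer(args, model)
    trainer.dummy_train_step(dummy_batch)  # warm up the CUDA caching allocator
    for i, sample in enumerate(epoch_itr):
        # accumulate gradients over --update-freq micro-batches before stepping
        update = (i + 1) % args.update_freq[0] == 0
        trainer.train_step(sample, update_params=update)
    val_losses = [trainer.valid_step(s) for s in valid_itr]
    val_loss = sum(val_losses) / max(len(val_losses), 1)
    trainer.lr_step(epoch, val_loss)  # e.g. reduce_lr_on_plateau consumes val_loss
    trainer.save_checkpoint('checkpoint_last.pt', extra_state={'epoch': epoch})
    return val_loss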
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import os
import socket
import torch.distributed
from fairseq import utils
def is_master(args):
return args.distributed_rank == 0
def distributed_init(args):
if args.distributed_world_size == 1:
raise ValueError('Cannot initialize distributed with distributed_world_size=1')
print('| distributed init (rank {}): {}'.format(
args.distributed_rank, args.distributed_init_method), flush=True)
print("| distributed env init. MASTER_ADDR: " + os.environ['MASTER_ADDR'] +
", MASTER_PORT: " + os.environ['MASTER_PORT'] +
", WORLD_SIZE: " + os.environ['WORLD_SIZE'] + ", RANK: " + os.environ['RANK'], flush=True)
torch.distributed.init_process_group(
backend=args.distributed_backend, init_method='env://')
print("| distributed init done!", flush=True)
args.distributed_world_size = int(os.environ['WORLD_SIZE'])
args.distributed_rank = torch.distributed.get_rank()
args.device_id = int(os.environ.get('LOCAL_RANK', args.local_rank))
suppress_output(args)
print('| initialized host {} as rank {} and device id {}'
.format(socket.gethostname(), args.distributed_rank, args.device_id))
return args.distributed_rank
def suppress_output(main_args):
"""Suppress printing on the current device. Force printing with `force=True`."""
import builtins as __builtin__
builtin_print = __builtin__.print
def print_master(*args, **kwargs):
if 'force' in kwargs:
kwargs.pop('force')
builtin_print(*args, **kwargs)
def print(*args, **kwargs):
if 'force' in kwargs:
force = kwargs.pop('force')
if force:
builtin_print(*args, **kwargs)
if is_master(main_args):
__builtin__.print = print_master
else:
__builtin__.print = print
def all_gather_list(data, max_size=16384):
"""Gathers arbitrary data from all nodes into a list."""
world_size = torch.distributed.get_world_size()
if not hasattr(all_gather_list, '_in_buffer') or \
max_size != len(all_gather_list._in_buffer):
all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
all_gather_list._out_buffers = [
torch.cuda.ByteTensor(max_size)
for i in range(world_size)
]
in_buffer = all_gather_list._in_buffer
out_buffers = all_gather_list._out_buffers
enc = pickle.dumps(data)
enc_size = len(enc)
if enc_size + 2 > max_size:
raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2))
assert max_size < 255 * 256
in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k
in_buffer[1] = enc_size % 255
in_buffer[2:enc_size + 2] = torch.ByteTensor(list(enc))
torch.distributed.all_gather(out_buffers, in_buffer.cuda())
result = []
for i in range(world_size):
out_buffer = out_buffers[i]
size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1])
result.append(
pickle.loads(bytes(out_buffer[2:size + 2].tolist()))
)
return result
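# ---------------------------------------------------------------------------
# Illustrative usage (editor's sketch, not part of the original sources).
# all_gather_list() pickles an arbitrary Python object into a fixed-size byte
# buffer and all-gathers it, so every rank receives every rank's object. It
# assumes torch.distributed has already been initialized (see distributed_init
# above) and that the pickled payload fits within max_size bytes.
def example_gather_worker_stats(local_stats, max_size=16384):
    # local_stats is any picklable object, e.g. {'ntokens': 4096, 'loss': 2.31}
    gathered = all_gather_list(local_stats, max_size=max_size)
    # gathered[i] is the object contributed by rank i; reduce however is needed
    return sum(stats.get('ntokens', 0) for stats in gathered)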
import os
import atexit
import time
import itertools
from collections import OrderedDict
import dllogger
from dllogger import Backend, JSONStreamBackend
from tensorboardX import SummaryWriter
class AverageMeter():
def __init__(self):
self.reset()
def reset(self):
self.updated = False
self.avg = 0
self.sum = 0
self.count = 0
def update(self, value):
self.updated = True
if isinstance(value, (tuple, list)):
val = value[0]
n = value[1]
else:
val = value
n = 1
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
@property
def value(self):
return self.avg
class PerformanceMeter():
def __init__(self):
self.reset()
def reset(self):
self.updated = False
self.start = time.time()
self.n = 0
def update(self, val=1):
self.updated = True
self.n += val
@property
def value(self):
return self.n / self.elapsed_time
@property
def elapsed_time(self):
return time.time() - self.start
METRIC = {'average': AverageMeter, 'performance': PerformanceMeter}
class AggregatorBackend(Backend):
def __init__(self, verbosity, agg_dict):
super().__init__(verbosity=verbosity)
agg_dict = OrderedDict({k: v if isinstance(v, (tuple, list)) else (v,) for k, v in agg_dict.items()})
self.metrics = OrderedDict({k: [METRIC[x]() for x in v] for k, v in agg_dict.items()})
self.metrics.flushed = True
self.step = 0
self.epoch = 0
self.start_time = time.time()
@property
def log_level(self):
return self._log_level
def metadata(self, timestamp, elapsedtime, metric, metadata):
pass
def _reset_perf_meter(self, name):
for agg in self.metrics[name]:
if isinstance(agg, PerformanceMeter):
agg.reset()
def reset_perf_meters(self):
for name in self.metrics.keys():
self._reset_perf_meter(name)
def log(self, timestamp, elapsedtime, step, data):
self.step = step
if 'epoch' in data.keys():
self.epoch = data['epoch']
for k, v in data.items():
if k not in self.metrics.keys():
continue
self.metrics.flushed = False
for ag in self.metrics[k]:
ag.update(v)
def flush(self):
if self.metrics.flushed:
return
result_string = 'Transformer | epoch {} | step {} |'.format(self.epoch, self.step)
for name, aggregators in self.metrics.items():
for agg in aggregators:
if not agg.updated:
continue
if isinstance(agg, AverageMeter):
_name = 'avg ' + name
elif isinstance(agg, PerformanceMeter):
_name = name + '/s'
result_string += _name + ' {:.3f} |'.format(agg.value)
agg.reset()
result_string += 'walltime {:.3f} |'.format(time.time() - self.start_time)
self.metrics.flushed = True
print(result_string)
class TensorBoardBackend(Backend):
def __init__(self, verbosity, log_dir):
super().__init__(verbosity=verbosity)
self.summary_writer = SummaryWriter(log_dir=os.path.join(log_dir, 'TB_summary'),
flush_secs=120,
max_queue=200
)
atexit.register(self.summary_writer.close)
@property
def log_level(self):
return self._log_level
def metadata(self, timestamp, elapsedtime, metric, metadata):
pass
def log(self, timestamp, elapsedtime, step, data):
if not isinstance(step, int):
return
for k, v in data.items():
self.summary_writer.add_scalar(k, v, step)
def flush(self):
pass
def setup_logger(args):
aggregator_dict = OrderedDict([
('loss', 'average'),
('weighted_loss', 'average'),
('tokens', ('average', 'performance')),
('updates', 'performance'),
('gnorm', 'average')
])
os.makedirs(args.save_dir, exist_ok=True)
log_path = os.path.join(args.save_dir, args.stat_file)
if os.path.exists(log_path):
for i in itertools.count():
s_fname = args.stat_file.split('.')
fname = '.'.join(s_fname[:-1]) + f'_{i}.' + s_fname[-1] if len(s_fname) > 1 else args.stat_file + f'.{i}'
log_path = os.path.join(args.save_dir, fname)
if not os.path.exists(log_path):
break
if not args.distributed_world_size > 1 or args.distributed_rank == 0:
dllogger.init(backends=[JSONStreamBackend(verbosity=1, filename=log_path),
AggregatorBackend(verbosity=0, agg_dict=aggregator_dict),
TensorBoardBackend(verbosity=1, log_dir=args.save_dir)])
else:
dllogger.init(backends=[])
for k, v in vars(args).items():
dllogger.log(step='PARAMETER', data={k: v}, verbosity=0)
container_setup_info = get_framework_env_vars()
dllogger.log(step='PARAMETER', data=container_setup_info, verbosity=0)
dllogger.metadata('loss', {'unit': 'nat', 'GOAL': 'MINIMIZE', 'STAGE': 'TRAIN'})
dllogger.metadata('val_loss', {'unit': 'nat', 'GOAL': 'MINIMIZE', 'STAGE': 'VAL'})
dllogger.metadata('speed', {'unit': 'tokens/s', 'format': ':.3f', 'GOAL': 'MAXIMIZE', 'STAGE': 'TRAIN'})
dllogger.metadata('accuracy', {'unit': 'bleu', 'format': ':.2f', 'GOAL': 'MAXIMIZE', 'STAGE': 'VAL'})
def get_framework_env_vars():
return {
'NVIDIA_PYTORCH_VERSION': os.environ.get('NVIDIA_PYTORCH_VERSION'),
'PYTORCH_VERSION': os.environ.get('PYTORCH_VERSION'),
'CUBLAS_VERSION': os.environ.get('CUBLAS_VERSION'),
'NCCL_VERSION': os.environ.get('NCCL_VERSION'),
'CUDA_DRIVER_VERSION': os.environ.get('CUDA_DRIVER_VERSION'),
'CUDNN_VERSION': os.environ.get('CUDNN_VERSION'),
'CUDA_VERSION': os.environ.get('CUDA_VERSION'),
'NVIDIA_PIPELINE_ID': os.environ.get('NVIDIA_PIPELINE_ID'),
'NVIDIA_BUILD_ID': os.environ.get('NVIDIA_BUILD_ID'),
'NVIDIA_TF32_OVERRIDE': os.environ.get('NVIDIA_TF32_OVERRIDE'),
}
def reset_perf_meters():
for backend in dllogger.GLOBAL_LOGGER.backends:
if isinstance(backend, AggregatorBackend):
backend.reset_perf_meters()
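# ---------------------------------------------------------------------------
# Illustrative usage (editor's sketch, not part of the original sources).
# setup_logger() wires three backends together: a JSON stream written to
# args.save_dir/args.stat_file, the console AggregatorBackend, and TensorBoard.
# Afterwards the trainer just calls dllogger.log() with a step and a dict of
# scalars; keys listed in aggregator_dict are averaged until the next flush.
# The `args` object below is assumed to be the parsed training arguments.
def example_logging(args):
    setup_logger(args)
    for step in range(1, 11):
        dllogger.log(step=step, data={'loss': 3.0 / step, 'tokens': 4096}, verbosity=0)
    dllogger.flush()  # AggregatorBackend prints one 'Transformer | ...' summary line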
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import time
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
class TimeMeter(object):
"""Computes the average occurrence of some event per second"""
def __init__(self, init=0):
self.reset(init)
def reset(self, init=0):
self.init = init
self.start = time.time()
self.n = 0
self.last_update = time.time()
def update(self, val=1):
self.n += val
self.last_update = time.time()
@property
def avg(self):
return self.n / self.elapsed_time
@property
def elapsed_time(self):
return self.init + (time.time() - self.start)
@property
def u_avg(self):
return self.n / (self.last_update - self.start)
class StopwatchMeter(object):
"""Computes the sum/avg duration of some event in seconds"""
def __init__(self):
self.reset()
self.intervals = []
def start(self):
self.start_time = time.time()
def stop(self, n=1):
if self.start_time is not None:
delta = time.time() - self.start_time
self.intervals.append(delta)
self.sum += delta
self.n += n
self.start_time = None
def reset(self):
self.sum = 0
self.n = 0
self.start_time = None
self.intervals = []
@property
def avg(self):
return self.sum / self.n
def p(self, i):
assert i <= 100
idx = int(len(self.intervals) * i / 100)
return sorted(self.intervals)[idx]
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import importlib
import os
from .fairseq_incremental_decoder import FairseqIncrementalDecoder # noqa: F401
MODEL_REGISTRY = {}
ARCH_MODEL_REGISTRY = {}
ARCH_CONFIG_REGISTRY = {}
def build_model(args):
return ARCH_MODEL_REGISTRY[args.arch].build_model(args)
def register_model(name):
"""Decorator to register a new model (e.g., LSTM)."""
def register_model_cls(cls):
if name in MODEL_REGISTRY:
raise ValueError('Cannot register duplicate model ({})'.format(name))
MODEL_REGISTRY[name] = cls
return cls
return register_model_cls
def register_model_architecture(model_name, arch_name):
"""Decorator to register a new model architecture (e.g., lstm_luong_wmt_en_de)."""
def register_model_arch_fn(fn):
if model_name not in MODEL_REGISTRY:
raise ValueError('Cannot register model architecture for unknown model type ({})'.format(model_name))
if arch_name in ARCH_MODEL_REGISTRY:
raise ValueError('Cannot register duplicate model architecture ({})'.format(arch_name))
if not callable(fn):
raise ValueError('Model architecture must be callable ({})'.format(arch_name))
ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name]
ARCH_CONFIG_REGISTRY[arch_name] = fn
return fn
return register_model_arch_fn
# automatically import any Python files in the models/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('fairseq.models.' + module)
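# ---------------------------------------------------------------------------
# Illustrative usage (editor's sketch, not part of the original sources).
# A model is registered by name, and each architecture name maps that model
# class to a function filling in default hyper-parameters on `args`. Dropping
# a file with code like this into fairseq/models/ is enough for --arch to find
# it, because the loop above imports every module in the directory. The class,
# architecture and option names below are hypothetical.
import torch.nn as nn

@register_model('example_model')
class ExampleModel(nn.Module):
    """Toy model illustrating the contract expected by build_model() above."""
    def __init__(self, hidden_dim):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, hidden_dim)

    @staticmethod
    def add_args(parser):
        # exposed via options.parse_args_and_arch() as model-specific flags
        parser.add_argument('--example-hidden-dim', type=int, default=512)

    @classmethod
    def build_model(cls, args):
        return cls(args.example_hidden_dim)

@register_model_architecture('example_model', 'example_arch')
def example_arch(args):
    # fill in defaults for any option the user did not set on the command line
    args.example_hidden_dim = getattr(args, 'example_hidden_dim', 512)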
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.nn as nn
class FairseqIncrementalDecoder(nn.Module):
"""Base class for incremental decoders."""
def __init__(self):
super().__init__()
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
raise NotImplementedError
def reorder_incremental_state(self, incremental_state, new_order):
"""Reorder incremental state.
This should be called when the order of the input has changed from the
previous time step. A typical use case is beam search, where the input
order changes between time steps based on the selection of beams.
"""
def apply_reorder_incremental_state(module):
if module != self and hasattr(module, 'reorder_incremental_state'):
module.reorder_incremental_state(
incremental_state,
new_order,
)
self.apply(apply_reorder_incremental_state)
def set_beam_size(self, beam_size):
"""Sets the beam size in the decoder and all children."""
if getattr(self, '_beam_size', -1) != beam_size:
def apply_set_beam_size(module):
if module != self and hasattr(module, 'set_beam_size'):
module.set_beam_size(beam_size)
self.apply(apply_set_beam_size)
self._beam_size = beam_size
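# ---------------------------------------------------------------------------
# Illustrative sketch (editor's addition, not part of the original sources).
# During beam search the generator keeps an incremental_state cache and calls
# reorder_incremental_state() whenever beams are reordered; the apply() walk
# above forwards the call to every child that defines the hook. The subclass
# below is hypothetical and deliberately simplified (fairseq decoders typically
# key their caches through utils.get/set_incremental_state instead).
import torch

class ExampleIncrementalDecoder(FairseqIncrementalDecoder):
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.out = nn.Linear(embed_dim, vocab_size)

    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        x = self.embed(prev_output_tokens)
        if incremental_state is not None:
            # cache the growing prefix so each step only embeds the new token
            prev = incremental_state.get('prefix')
            x = x if prev is None else torch.cat([prev, x], dim=1)
            incremental_state['prefix'] = x
        return self.out(x), None

    def reorder_incremental_state(self, incremental_state, new_order):
        # keep cached rows aligned with the beams that survived this step
        if incremental_state and 'prefix' in incremental_state:
            incremental_state['prefix'] = incremental_state['prefix'].index_select(0, new_order)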
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import numbers
from torch.nn.parameter import Parameter
from torch.nn import init
import fused_layer_norm_cuda
class FusedLayerNormAffineFunction(torch.autograd.Function):
def __init__(self, normalized_shape, eps=1e-6):
self.normalized_shape = normalized_shape
self.eps = eps
def forward(self, input, weight, bias):
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_layer_norm_cuda.forward_affine(
input_, self.normalized_shape, weight_, bias_, self.eps)
self.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
def backward(self, grad_output):
input_, weight_, bias_, mean, invvar = self.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, self.normalized_shape,
weight_, bias_, self.eps)
return grad_input, grad_weight, grad_bias
class FusedLayerNormFunction(torch.autograd.Function):
def __init__(self, normalized_shape, eps=1e-6):
self.normalized_shape = normalized_shape
self.eps = eps
def forward(self, input):
input_ = input.contiguous()
output, mean, invvar = fused_layer_norm_cuda.forward(
input_, self.normalized_shape, self.eps)
self.save_for_backward(input_, mean, invvar)
return output
def backward(self, grad_output):
input_, mean, invvar = self.saved_tensors
grad_input = None
grad_input = fused_layer_norm_cuda.backward(
grad_output.contiguous(), mean, invvar,
input_, self.normalized_shape,
self.eps)
return grad_input
def fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):
return FusedLayerNormAffineFunction(normalized_shape,eps)(input, weight, bias)
def fused_layer_norm(input, normalized_shape, eps=1e-6):
return FusedLayerNormFunction(normalized_shape,eps)(input)
class FusedLayerNorm(torch.nn.Module):
r"""Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization`_ .
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated separately over the last
certain number dimensions which have to be of the shape specified by
:attr:`normalized_shape`.
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
:attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
.. note::
Unlike Batch Normalization and Instance Normalization, which applies
scalar scale and bias for each entire channel/plane with the
:attr:`affine` option, Layer Normalization applies per-element scale and
bias with :attr:`elementwise_affine`.
This layer uses statistics computed from input data in both training and
evaluation modes.
Args:
normalized_shape (int or list or torch.Size): input shape from an expected input
of size
.. math::
[* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
\times \ldots \times \text{normalized\_shape}[-1]]
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps: a value added to the denominator for numerical stability. Default: 1e-5
elementwise_affine: a boolean value that when set to ``True``, this module
has learnable per-element affine parameters initialized to ones (for weights)
and zeros (for biases). Default: ``True``.
Shape:
- Input: :math:`(N, *)`
- Output: :math:`(N, *)` (same shape as input)
Examples::
>>> input = torch.randn(20, 5, 10, 10)
>>> # With Learnable Parameters
>>> m = nn.LayerNorm(input.size()[1:])
>>> # Without Learnable Parameters
>>> m = nn.LayerNorm(input.size()[1:], elementwise_affine=False)
>>> # Normalize over last two dimensions
>>> m = nn.LayerNorm([10, 10])
>>> # Normalize over last dimension of size 10
>>> m = nn.LayerNorm(10)
>>> # Activating the module
>>> output = m(input)
.. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
"""
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
super(FusedLayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = Parameter(torch.Tensor(*normalized_shape))
self.bias = Parameter(torch.Tensor(*normalized_shape))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
if self.elementwise_affine:
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, input):
if self.elementwise_affine:
return FusedLayerNormAffineFunction(self.normalized_shape,self.eps)(
input, self.weight, self.bias)
else:
return FusedLayerNormFunction(self.normalized_shape,self.eps)(
input)
def extra_repr(self):
return '{normalized_shape}, eps={eps}, ' \
'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from .beamable_mm import BeamableMM
from .learned_positional_embedding import LearnedPositionalEmbedding
from .multihead_attention import MultiheadAttention
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
__all__ = [
'BeamableMM',
'LearnedPositionalEmbedding',
'MultiheadAttention',
'SinusoidalPositionalEmbedding',
]
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.nn as nn
from fairseq import utils
class LearnedPositionalEmbedding(nn.Embedding):
"""This module learns positional embeddings up to a fixed maximum size.
Padding symbols are ignored, but it is necessary to specify whether padding
is added on the left side (left_pad=True) or right side (left_pad=False).
"""
def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.left_pad = left_pad
def forward(self, input, incremental_state=None):
"""Input is expected to be of size [bsz x seqlen]."""
if incremental_state is not None:
# positions is the same for every token when decoding a single step
positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
else:
positions = utils.make_positions(input.data, self.padding_idx, self.left_pad)
return super().forward(positions)
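# ---------------------------------------------------------------------------
# Illustrative usage (editor's sketch, not part of the original sources).
# The module maps token ids to position ids (skipping padding) and looks them
# up in a learned table; with incremental_state set, only the newest position
# is emitted, which is what single-step decoding needs. Sizes and values below
# are hypothetical, and the example assumes utils.make_positions accepts the
# CPU tensor used here.
import torch

def example_positional_embedding():
    padding_idx = 1
    pos_emb = LearnedPositionalEmbedding(num_embeddings=1026, embedding_dim=8,
                                         padding_idx=padding_idx, left_pad=False)
    tokens = torch.tensor([[5, 6, 7, padding_idx]])       # bsz x seqlen
    full = pos_emb(tokens)                                 # (1, 4, 8): one row per position
    step = pos_emb(tokens, incremental_state={})           # (1, 1, 8): last position only
    return full, step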
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <torch/torch.h>
#include <vector>
at::Tensor strided_batched_gemm_cuda(
float beta,
at::Tensor in_result,
float alpha,
at::Tensor batch1,
at::Tensor batch2);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
at::Tensor strided_batched_gemm(
float beta,
at::Tensor in_result,
float alpha,
at::Tensor batch1,
at::Tensor batch2) {
//CHECK_INPUT(in_result);
//CHECK_INPUT(batch1);
//CHECK_INPUT(batch2);
AT_ASSERTM(in_result.dim() == 3, "expected 3D tensor");
AT_ASSERTM(batch1.dim() == 3, "expected 3D tensor");
AT_ASSERTM(batch2.dim() == 3, "expected 3D tensor");
AT_ASSERTM(in_result.size(0) == batch1.size(0), "equal number of batches expected");
AT_ASSERTM(in_result.size(0) == batch2.size(0), "equal number of batches expected");
AT_ASSERTM(in_result.size(1) == batch1.size(1), "wrong matrix size");
AT_ASSERTM(in_result.size(2) == batch2.size(2), "wrong matrix size");
AT_ASSERTM(batch1.size(2) == batch2.size(1), "wrong matrix size");
AT_ASSERTM(batch1.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(batch2.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(in_result.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
return strided_batched_gemm_cuda(beta, in_result, alpha, batch1, batch2);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("strided_batched_gemm", &strided_batched_gemm, "Special strided batched gemm.");
}
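# ---------------------------------------------------------------------------
# Illustrative usage (editor's sketch, not part of the original sources).
# The binding above exposes strided_batched_gemm(beta, in_result, alpha,
# batch1, batch2) for 3D half-precision CUDA tensors; only the argument order
# and the shape/dtype constraints come from the asserts, while the extension
# name, source file names and the exact GEMM semantics are assumptions here.
import torch
from torch.utils.cpp_extension import load

def example_strided_batched_gemm():
    ext = load(name='strided_batched_gemm',
               sources=['strided_batched_gemm.cpp', 'strided_batched_gemm_cuda.cu'])
    b, m, k, n = 8, 64, 32, 48
    batch1 = torch.randn(b, m, k, device='cuda', dtype=torch.half)
    batch2 = torch.randn(b, k, n, device='cuda', dtype=torch.half)
    result = torch.zeros(b, m, n, device='cuda', dtype=torch.half)
    # with beta=0.0 and alpha=1.0 this presumably reduces to torch.bmm(batch1, batch2)
    return ext.strided_batched_gemm(0.0, result, 1.0, batch1, batch2)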
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import importlib
import os
from .fairseq_optimizer import FairseqOptimizer
OPTIMIZER_REGISTRY = {}
OPTIMIZER_CLASS_NAMES = set()
def build_optimizer(args, params):
params = filter(lambda p: p.requires_grad, params)
return OPTIMIZER_REGISTRY[args.optimizer](args, params)
def register_optimizer(name):
"""Decorator to register a new optimizer."""
def register_optimizer_cls(cls):
if name in OPTIMIZER_REGISTRY:
raise ValueError('Cannot register duplicate optimizer ({})'.format(name))
if not issubclass(cls, FairseqOptimizer):
raise ValueError('Optimizer ({}: {}) must extend FairseqOptimizer'.format(name, cls.__name__))
if cls.__name__ in OPTIMIZER_CLASS_NAMES:
# We use the optimizer class name as a unique identifier in
# checkpoints, so all optimizers must have unique class names.
raise ValueError('Cannot register optimizer with duplicate class name ({})'.format(cls.__name__))
OPTIMIZER_REGISTRY[name] = cls
OPTIMIZER_CLASS_NAMES.add(cls.__name__)
return cls
return register_optimizer_cls
# automatically import any Python files in the optim/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('fairseq.optim.' + module)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import FairseqOptimizer, register_optimizer
from apex.optimizers.fused_adam import FusedAdam
@register_optimizer('adam')
class FairseqAdam(FairseqOptimizer):
def __init__(self, args, params):
super().__init__(args, params)
self._optimizer = FusedAdam(params, **self.optimizer_config)
@staticmethod
def add_args(parser):
"""Add optimizer-specific arguments to the parser."""
parser.add_argument('--adam-betas', default=(0.9, 0.999), nargs=2, type=float, metavar='B1 B2',
help='betas for Adam optimizer')
parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D',
help='epsilon for Adam optimizer')
@property
def optimizer_config(self):
"""
Return a kwarg dictionary that will be used to override optimizer
args stored in checkpoints. This allows us to load a checkpoint and
resume training using a different set of optimizer args, e.g., with a
different learning rate.
"""
return {
'lr': self.args.lr[0],
'betas': self.args.adam_betas,
'eps': self.args.adam_eps,
'weight_decay': self.args.weight_decay,
}
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.optim
class FairseqOptimizer(object):
def __init__(self, args, params):
super().__init__()
self.args = args
self.params = params
@staticmethod
def add_args(parser):
"""Add optimizer-specific arguments to the parser."""
pass
@property
def optimizer(self):
"""Return a torch.optim.optimizer.Optimizer instance."""
if not hasattr(self, '_optimizer'):
raise NotImplementedError
if not isinstance(self._optimizer, torch.optim.Optimizer):
raise ValueError('_optimizer must be an instance of torch.optim.Optimizer')
return self._optimizer
@property
def optimizer_config(self):
"""
Return a kwarg dictionary that will be used to override optimizer
args stored in checkpoints. This allows us to load a checkpoint and
resume training using a different set of optimizer args, e.g., with a
different learning rate.
"""
raise NotImplementedError
def get_lr(self):
"""Return the current learning rate."""
return self.optimizer.param_groups[0]['lr']
def set_lr(self, lr):
"""Set the learning rate."""
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
def state_dict(self):
"""Return the optimizer's state dict."""
return self.optimizer.state_dict()
def load_state_dict(self, state_dict):
"""Load an optimizer state dict.
In general we should prefer the configuration of the existing optimizer
instance (e.g., learning rate) over that found in the state_dict. This
allows us to resume training from a checkpoint using a new set of
optimizer args.
"""
self.optimizer.load_state_dict(state_dict)
# override learning rate, momentum, etc. with latest values
for group in self.optimizer.param_groups:
group.update(self.optimizer_config)
def step(self, closure=None):
"""Performs a single optimization step."""
return self.optimizer.step(closure)
def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
for group in self.optimizer.param_groups:
for p in group['params']:
p.grad = None
return self.optimizer.zero_grad()
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import importlib
import os
from .fairseq_lr_scheduler import FairseqLRScheduler
LR_SCHEDULER_REGISTRY = {}
def build_lr_scheduler(args, optimizer):
return LR_SCHEDULER_REGISTRY[args.lr_scheduler](args, optimizer)
def register_lr_scheduler(name):
"""Decorator to register a new LR scheduler."""
def register_lr_scheduler_cls(cls):
if name in LR_SCHEDULER_REGISTRY:
raise ValueError('Cannot register duplicate LR scheduler ({})'.format(name))
if not issubclass(cls, FairseqLRScheduler):
raise ValueError('LR Scheduler ({}: {}) must extend FairseqLRScheduler'.format(name, cls.__name__))
LR_SCHEDULER_REGISTRY[name] = cls
return cls
return register_lr_scheduler_cls
# automatically import any Python files in the optim/lr_scheduler/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('fairseq.optim.lr_scheduler.' + module)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from . import FairseqLRScheduler, register_lr_scheduler
@register_lr_scheduler('fixed')
class FixedSchedule(FairseqLRScheduler):
"""Decay the LR on a fixed schedule."""
def __init__(self, args, optimizer):
super().__init__(args, optimizer)
# set defaults
args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0
self.lr = args.lr[0]
if args.warmup_updates > 0:
self.warmup_factor = 1. / args.warmup_updates
else:
self.warmup_factor = 1
@staticmethod
def add_args(parser):
"""Add arguments to the parser for this LR scheduler."""
parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
help='force annealing at specified epoch')
parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
help='warmup the learning rate linearly for the first N updates')
def get_next_lr(self, epoch):
lrs = self.args.lr
if self.args.force_anneal is None or epoch < self.args.force_anneal:
# use fixed LR schedule
next_lr = lrs[min(epoch, len(lrs) - 1)]
else:
# anneal based on lr_shrink
next_lr = lrs[-1] * self.args.lr_shrink ** (epoch + 1 - self.args.force_anneal)
return next_lr
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
super().step(epoch, val_loss)
self.lr = self.get_next_lr(epoch)
self.optimizer.set_lr(self.warmup_factor * self.lr)
return self.optimizer.get_lr()
def step_update(self, num_updates):
"""Update the learning rate after each update."""
if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates:
self.warmup_factor = num_updates / float(self.args.warmup_updates)
self.optimizer.set_lr(self.warmup_factor * self.lr)
return self.optimizer.get_lr()
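# ---------------------------------------------------------------------------
# Worked example (editor's addition, restating the warm-up formula above).
# With --lr 0.5 --warmup-updates 4000 the effective rate is warmup_factor * lr,
# growing linearly from 0.000125 at update 1 to 0.5 at update 4000, after which
# step_update() leaves the rate to the fixed per-epoch schedule.
def example_fixed_schedule_lr(base_lr=0.5, warmup_updates=4000, num_updates=2000):
    # mirrors FixedSchedule.step_update() during the warm-up phase
    if warmup_updates > 0 and num_updates <= warmup_updates:
        return (num_updates / float(warmup_updates)) * base_lr   # 0.25 for these values
    return base_lr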
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.optim.lr_scheduler
from . import FairseqLRScheduler, register_lr_scheduler
@register_lr_scheduler('reduce_lr_on_plateau')
class ReduceLROnPlateau(FairseqLRScheduler):
"""Decay the LR by a factor every time the validation loss plateaus."""
def __init__(self, args, optimizer):
super().__init__(args, optimizer)
if len(args.lr) > 1:
raise ValueError(
'Cannot use a fixed learning rate schedule with reduce_lr_on_plateau.'
' Consider --lr-scheduler=fixed instead.'
)
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer.optimizer, patience=0, factor=args.lr_shrink)
def state_dict(self):
"""Return the LR scheduler state dict."""
return {
'best': self.lr_scheduler.best,
'last_epoch': self.lr_scheduler.last_epoch,
}
def load_state_dict(self, state_dict):
"""Load an LR scheduler state dict."""
self.lr_scheduler.best = state_dict['best']
if 'last_epoch' in state_dict:
self.lr_scheduler.last_epoch = state_dict['last_epoch']
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
if val_loss is not None:
self.lr_scheduler.step(val_loss, epoch)
else:
self.lr_scheduler.last_epoch = epoch
return self.optimizer.get_lr()
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import torch
from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
from fairseq.criterions import CRITERION_REGISTRY
from fairseq.optim import OPTIMIZER_REGISTRY
from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
def get_training_parser():
parser = get_parser('Trainer')
add_dataset_args(parser, train=True, gen=True)
add_distributed_training_args(parser)
add_model_args(parser)
add_optimization_args(parser)
add_checkpoint_args(parser)
add_inference_args(parser)
add_perf_args(parser)
return parser
def get_inference_parser():
parser = get_parser('Generation')
add_dataset_args(parser, gen=True)
add_inference_args(parser)
add_perf_args(parser)
return parser
def parse_args_and_arch(parser, input_args=None, parse_known=False):
# The parser doesn't know about model/criterion/optimizer-specific args, so
# we parse twice. First we parse the model/criterion/optimizer, then we
# parse a second time after adding the *-specific arguments.
# If input_args is given, we will parse those args instead of sys.argv.
args, _ = parser.parse_known_args(input_args)
# Add model-specific args to parser.
if hasattr(args, 'arch'):
model_specific_group = parser.add_argument_group(
'Model-specific configuration',
# Only include attributes which are explicitly given as command-line
# arguments or which have default values.
argument_default=argparse.SUPPRESS,
)
ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)
# Add *-specific args to parser.
if hasattr(args, 'optimizer'):
OPTIMIZER_REGISTRY[args.optimizer].add_args(parser)
if hasattr(args, 'lr_scheduler'):
LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)
# Parse a second time.
if parse_known:
args, extra = parser.parse_known_args(input_args)
else:
args = parser.parse_args(input_args)
extra = None
# Post-process args.
if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
args.max_sentences_valid = args.max_sentences
args.max_positions = (args.max_source_positions, args.max_target_positions)
if hasattr(args, 'target_bleu') and (args.online_eval or args.target_bleu) and not args.remove_bpe:
args.remove_bpe = '@@ '
# Apply architecture configuration.
if hasattr(args, 'arch'):
ARCH_CONFIG_REGISTRY[args.arch](args)
if parse_known:
return args, extra
else:
return args
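# ---------------------------------------------------------------------------
# Illustrative usage (editor's sketch, not part of the original sources).
# The two-pass parse means architecture-, optimizer- and scheduler-specific
# flags only exist after the first pass has read --arch, --optimizer and
# --lr-scheduler. The command line below is hypothetical; the --arch value must
# match whatever the (collapsed) model file actually registers.
def example_parse():
    parser = get_training_parser()
    return parse_args_and_arch(parser, input_args=[
        'data-bin/wmt14_en_de',                     # positional data directory
        '--arch', 'transformer_wmt_en_de_big_t2t',  # must be a registered architecture
        '--optimizer', 'adam', '--adam-betas', '0.9', '0.997',
        '--lr-scheduler', 'fixed', '--warmup-updates', '4000',
        '--lr', '0.0006', '--max-tokens', '5120', '--amp',
    ])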
def get_parser(desc):
parser = argparse.ArgumentParser(
description='Facebook AI Research Sequence-to-Sequence Toolkit -- ' + desc)
parser.add_argument('--log-interval', type=int, default=500, metavar='N',
help='print aggregated stats and flush json log every N iterations')
parser.add_argument('--seed', default=1, type=int, metavar='N',
help='pseudo random number generator seed')
parser.add_argument('--amp', action='store_true',
help='use Automatic Mixed Precision')
parser.add_argument('--stat-file', type=str, default='run_log.json',
help='Name of the file containing DLLogger output')
parser.add_argument('--save-dir', metavar='DIR', default='results',
help='path to save checkpoints and logs')
parser.add_argument('--do-sanity-check', action='store_true',
help='Perform evaluation on test set before running the training')
return parser
def add_dataset_args(parser, train=False, gen=False):
group = parser.add_argument_group('Dataset and data loading')
group.add_argument('--max-tokens', type=int, metavar='N',
help='maximum number of tokens in a batch')
group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
help='maximum number of sentences in a batch')
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
help='source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='target language')
parser.add_argument('--raw-text', action='store_true',
help='load raw text dataset')
parser.add_argument('--left-pad-source', default=True, type=bool, metavar='BOOL',
help='pad the source on the left (default: True)')
parser.add_argument('--left-pad-target', default=False, type=bool, metavar='BOOL',
help='pad the target on the left (default: False)')
parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
parser.add_argument('--pad-sequence', default=1, type=int, metavar='N',
help='Pad sequences to a multiple of N')
if train:
parser.add_argument('data', metavar='DIR', help='path to data directory')
group.add_argument('--train-subset', default='train', metavar='SPLIT',
choices=['train', 'valid', 'test'],
help='data subset to use for training (train, valid, test)')
group.add_argument('--valid-subset', default='valid', metavar='SPLIT',
help='comma separated list of data subsets to use for validation'
' (train, valid, valid1, test, test1)')
group.add_argument('--max-sentences-valid', type=int, metavar='N',
help='maximum number of sentences in a validation batch'
' (defaults to --max-sentences)')
if gen:
group.add_argument('--gen-subset', default='test', metavar='SPLIT',
help='data subset to generate (train, valid, test)')
group.add_argument('--num-shards', default=1, type=int, metavar='N',
help='shard generation over N shards')
group.add_argument('--shard-id', default=0, type=int, metavar='ID',
help='id of the shard to generate (id < num_shards)')
return group
def add_distributed_training_args(parser):
group = parser.add_argument_group('Distributed training')
group.add_argument('--distributed-world-size', type=int, metavar='N',
default=torch.cuda.device_count(),
help='total number of GPUs across all nodes (default: all visible GPUs)')
group.add_argument('--distributed-rank', default=os.getenv('LOCAL_RANK', 0), type=int,
help='rank of the current worker')
group.add_argument('--local_rank', default=0, type=int,
help='rank of the current worker')
group.add_argument('--distributed-backend', default='nccl', type=str,
help='distributed backend')
group.add_argument('--distributed-init-method', default=None, type=str,
help='typically tcp://hostname:port that will be used to '
'establish initial connection')
group.add_argument('--distributed-port', default=-1, type=int,
help='port number (not required if using --distributed-init-method)')
group.add_argument('--device-id', default=0, type=int,
help='which GPU to use (usually configured automatically)')
return group
def add_optimization_args(parser):
group = parser.add_argument_group('Optimization')
group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N',
help='force stop training at specified epoch')
group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N',
help='force stop training at specified update')
group.add_argument('--target-bleu', default=0.0, type=float, metavar='TARGET',
help='force stop training after reaching target bleu')
group.add_argument('--clip-norm', default=25, type=float, metavar='NORM',
help='clip threshold of gradients')
group.add_argument('--update-freq', default=[1], nargs='+', type=int,
help='update parameters every N_i batches, when in epoch i')
# Optimizer definitions can be found under fairseq/optim/
group.add_argument('--optimizer', default='nag', metavar='OPT',
choices=OPTIMIZER_REGISTRY.keys(),
help='optimizer: {} (default: nag)'.format(', '.join(OPTIMIZER_REGISTRY.keys())))
group.add_argument('--lr', '--learning-rate', default=[0.25], nargs='+', type=float,
help='learning rate for the first N epochs; all epochs >N using LR_N'
' (note: this may be interpreted differently depending on --lr-scheduler)')
group.add_argument('--momentum', default=0.99, type=float, metavar='M',
help='momentum factor')
group.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
help='weight decay')
# Learning rate schedulers can be found under fairseq/optim/lr_scheduler/
group.add_argument('--lr-scheduler', default='reduce_lr_on_plateau',
help='learning rate scheduler: {} (default: reduce_lr_on_plateau)'.format(
', '.join(LR_SCHEDULER_REGISTRY.keys())))
group.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
help='learning rate shrink factor for annealing, lr_new = (lr * lr_shrink)')
group.add_argument('--min-lr', default=1e-5, type=float, metavar='LR',
help='minimum learning rate')
# Criterion args
parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
help='epsilon for label smoothing, 0 means no label smoothing')
return group
def add_checkpoint_args(parser):
group = parser.add_argument_group('Checkpointing')
group.add_argument('--restore-file', default='checkpoint_last.pt',
help='filename in save-dir from which to load checkpoint')
group.add_argument('--save-interval', type=int, default=1, metavar='N',
help='save a checkpoint every N epochs')
group.add_argument('--no-save', action='store_true',
help='don\'t save models or checkpoints')
group.add_argument('--no-epoch-checkpoints', action='store_true',
help='only store last and best checkpoints')
group.add_argument('--validate-interval', type=int, default=1, metavar='N',
help='validate every N epochs')
return group
def add_common_eval_args(group):
group.add_argument('--path', metavar='FILE',
help='path(s) to model file(s), colon separated')
group.add_argument('--file', metavar='FILE', default=None, type=str,
help='path to a file with input data for inference')
group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
help='remove BPE tokens before scoring')
group.add_argument('--cpu', action='store_true', help='generate on CPU')
group.add_argument('--quiet', action='store_true',
help='only print final scores')
def add_inference_args(parser):
group = parser.add_argument_group('Generation')
add_common_eval_args(group)
group.add_argument('--beam', default=4, type=int, metavar='N',
help='beam size')
group.add_argument('--nbest', default=1, type=int, metavar='N',
help='number of hypotheses to output')
group.add_argument('--max-len-a', default=0, type=float, metavar='N',
help=('generate sequences of maximum length ax + b, '
'where x is the source length'))
group.add_argument('--max-len-b', default=200, type=int, metavar='N',
help=('generate sequences of maximum length ax + b, '
'where x is the source length'))
group.add_argument('--min-len', default=1, type=float, metavar='N',
help=('minimum generation length'))
group.add_argument('--no-early-stop', action='store_true',
help=('continue searching even after finalizing k=beam '
'hypotheses; this is more correct, but increases '
'generation time by 50%%'))
group.add_argument('--unnormalized', action='store_true',
help='compare unnormalized hypothesis scores')
group.add_argument('--no-beamable-mm', action='store_true',
help='don\'t use BeamableMM in attention layers')
group.add_argument('--lenpen', default=1, type=float,
help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
group.add_argument('--unkpen', default=0, type=float,
help='unknown word penalty: <0 produces more unks, >0 produces fewer')
group.add_argument('--replace-unk', nargs='?', const=True, default=None,
help='perform unknown replacement (optionally with alignment dictionary)')
group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
help='initialize generation by target prefix of given length')
group.add_argument('--sampling', action='store_true',
help='sample hypotheses instead of using beam search')
group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS',
help='sample from top K likely next words instead of all words')
group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
help='temperature for random sampling')
group.add_argument('--print-alignment', action='store_true',
help='if set, uses attention feedback to compute and print alignment to source tokens')
group.add_argument('--online-eval', action='store_true',
help='score model at the end of epoch')
group.add_argument('--save-predictions', action='store_true',
help='Save predictions produced with online evaluation')
group.add_argument('--test-cased-bleu', action='store_true',
help='Use cased bleu for online eval')
group.add_argument('--bpe-codes', default=None, type=str, metavar='CODES',
help='file with bpe codes')
group.add_argument('--buffer-size', default=64, type=int, metavar='N',
help='read this many sentences into a buffer before processing them')
group.add_argument('--fp16', action='store_true', help='use fp16 precision')
return group
def add_model_args(parser):
group = parser.add_argument_group('Model configuration')
# Model definitions can be found under fairseq/models/
#
# The model architecture can be specified in several ways.
# In increasing order of priority:
# 1) model defaults (lowest priority)
# 2) --arch argument
group.add_argument(
'--arch', '-a', default='fconv', metavar='ARCH', required=True,
choices=ARCH_MODEL_REGISTRY.keys(),
help='model architecture: {} (default: fconv)'.format(
', '.join(ARCH_MODEL_REGISTRY.keys())),
)
# Criterion definitions can be found under fairseq/criterions/
group.add_argument(
'--criterion', default='cross_entropy', metavar='CRIT',
choices=CRITERION_REGISTRY.keys(),
help='training criterion: {} (default: cross_entropy)'.format(
', '.join(CRITERION_REGISTRY.keys())),
)
return group
def add_perf_args(parser):
group = parser.add_argument_group('Performance')
group.add_argument('--fuse-dropout-add', action='store_true',
help='Fuse dropout and residual adds.')
group.add_argument('--fuse-relu-dropout', action='store_true',
help='Fuse Relu and Dropout.')
group.add_argument('--fuse-layer-norm', action='store_true',
help='Use APEX\'s FusedLayerNorm instead of torch.nn.LayerNorm')
return group