Commit c0f05c10 authored by hepj's avatar hepj
Browse files

更新transformer代码

parent c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Train a network across multiple GPUs.
"""
import math
from collections import defaultdict
from itertools import chain
import torch
import torch.nn.functional as F
from torch.cuda import amp
from apex.parallel import DistributedDataParallel as DDP
from fairseq import distributed_utils, optim, utils
from fairseq.optim import lr_scheduler
from fairseq.meters import TimeMeter, AverageMeter
from fairseq.criterions import CRITERION_REGISTRY
import dllogger as DLLogger
class DDPTrainer():
"""Main class for data parallel training.
This class supports data parallel training, where multiple workers each
have a full model replica and gradients are accumulated synchronously via
torch.distributed.all_reduce.
"""
def __init__(self, args, model):
if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported')
self.args = args
self.model = model.cuda()
self.criterion = CRITERION_REGISTRY[args.criterion](args).cuda()
self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
self.scaler = amp.GradScaler(enabled=self.args.amp, init_scale=2**15)
if self.args.distributed_world_size > 1:
self.model = DDP(model)
self._buffered_stats = defaultdict(lambda: [])
self._num_updates = 0
self._optim_history = None
self.throughput_meter = TimeMeter()
self.avg_loss_meter = AverageMeter()
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
if distributed_utils.is_master(self.args): # only save one checkpoint
utils.save_state(
filename, self.args, self.get_model(), self.criterion, self.optimizer,
self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
)
def load_checkpoint(self, filename, load_optim=True):
"""Load all training state from a checkpoint file."""
extra_state, optim_history, last_optim_state = \
utils.load_model_state(filename, self.get_model())
if last_optim_state is not None:
# rebuild optimizer after loading model, since params may have changed
#self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
if load_optim:
self._optim_history = optim_history
# only reload optimizer and lr_scheduler if they match
last_optim = self._optim_history[-1]
if last_optim['criterion_name'] == self.criterion.__class__.__name__:
self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
if last_optim['optimizer_name'] == self.optimizer.__class__.__name__:
self.optimizer.load_state_dict(last_optim_state)
self._num_updates = last_optim['num_updates']
return extra_state
def train_step(self, sample, update_params=True, last_step=False):
"""Do forward, backward and parameter update."""
# Set seed based on args.seed and the update number so that we get
# reproducible results when resuming from checkpoints
seed = self.args.seed + self.get_num_updates()
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
self.model.train()
if isinstance(self.model, DDP):
if last_step:
self.model.disable_allreduce()
else:
self.model.enable_allreduce()
# forward and backward pass
sample = self._prepare_sample(sample)
loss, oom_fwd = self._forward(sample)
# If this is a last batch forward pass is skipped on some workers
# Batch with sample_size 0 is not accounted for in weighted loss
logging_output = {
'ntokens': sample['ntokens'] if sample is not None else 0,
'nsentences': sample['target'].size(0) if sample is not None else 0,
'loss': utils.item(loss.data) if loss is not None else 0,
}
sample_size = sample['ntokens'] if sample is not None else 0
oom_bwd = self._backward(loss)
# buffer stats and logging outputs
self._buffered_stats['sample_sizes'].append(sample_size)
self._buffered_stats['logging_outputs'].append(logging_output)
self._buffered_stats['ooms_fwd'].append(oom_fwd)
self._buffered_stats['ooms_bwd'].append(oom_bwd)
# update parameters
if update_params and not last_step:
# gather logging outputs from all replicas
sample_sizes = self._buffered_stats['sample_sizes']
logging_outputs = self._buffered_stats['logging_outputs']
ooms_fwd = self._buffered_stats['ooms_fwd']
ooms_bwd = self._buffered_stats['ooms_bwd']
if self.args.distributed_world_size > 1:
sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
lambda l: list(chain.from_iterable(l)),
zip(*distributed_utils.all_gather_list(
(sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
))
)
ooms_fwd = sum(ooms_fwd)
ooms_bwd = sum(ooms_bwd)
ooms = ooms_fwd + ooms_bwd # this is always <= distributed_world_size
if ooms == self.args.distributed_world_size:
print('| WARNING: OOM in all workers, skipping batch')
self.zero_grad()
return
# aggregate stats and logging outputs
grad_denom = sum(sample_sizes)
for p in self.model.parameters():
if p.requires_grad and p.grad is not None:
p.grad /= grad_denom
self._opt()
# Handle logging
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
self.throughput_meter.update(ntokens)
info_log_data = {
'tokens/s': self.throughput_meter.avg,
'tokens': ntokens,
'loss': sum(log.get('loss', 0) for log in logging_outputs) / ntokens / math.log(2)
}
self.avg_loss_meter.update(info_log_data['loss'])
debug_log_data = {
'batch_size': sum(log.get('nsentences', 0) for log in logging_outputs),
'lr': self.get_lr(),
'grad_denom': grad_denom,
'updates': 1
}
DLLogger.log(step=self._num_updates, data=info_log_data, verbosity=0)
DLLogger.log(step=self._num_updates, data=debug_log_data, verbosity=1)
self.clear_buffered_stats()
def _forward(self, sample):
loss = None
oom = 0
try:
if sample is not None:
with amp.autocast(enabled=self.args.amp):
# calculate loss and sample size
logits, _ = self.model(**sample['net_input'])
target = sample['target']
probs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
loss = self.criterion(probs, target)
except RuntimeError as e:
if 'out of memory' in str(e):
print('| WARNING: ran out of memory in worker {}, skipping batch'.format(
self.args.distributed_rank), force=True)
oom = 1
loss = None
else:
raise e
return loss, oom
def _backward(self, loss):
oom = 0
if loss is not None:
try:
self.scaler.scale(loss).backward()
except RuntimeError as e:
if 'out of memory' in str(e):
print('| WARNING: ran out of memory in worker {}, skipping batch'.format(
self.args.distributed_rank), force=True)
oom = 1
self.zero_grad()
else:
raise e
return oom
def _opt(self):
# take an optimization step
self.scaler.step(self.optimizer.optimizer)
self.scaler.update()
self.zero_grad()
self._num_updates += 1
# update learning rate
self.lr_scheduler.step_update(self._num_updates)
def valid_step(self, sample):
"""Do forward pass in evaluation mode."""
self.model.eval()
# forward pass
sample = self._prepare_sample(sample)
with torch.no_grad():
loss, oom_fwd = self._forward(sample)
logging_output = {
'ntokens': sample['ntokens'] if sample is not None else 0,
'nsentences': sample['target'].size(0) if sample is not None else 0,
}
loss = loss.item() if loss is not None else 0
assert not oom_fwd, 'Ran out of memory during validation'
# gather logging outputs from all GPUs
if self.args.distributed_world_size > 1:
losses, logging_outputs = zip(*distributed_utils.all_gather_list(
(loss, logging_output)
))
else:
losses = [loss]
logging_outputs = [logging_output]
weight = sum(log.get('ntokens', 0) for log in logging_outputs)
scaled_loss = sum(losses) / weight / math.log(2)
return scaled_loss
def dummy_train_step(self, dummy_batch):
"""Dummy training step for warming caching allocator."""
self.train_step(dummy_batch, update_params=False)
self.zero_grad()
self.clear_buffered_stats()
def zero_grad(self):
self.optimizer.zero_grad()
def clear_buffered_stats(self):
self._buffered_stats.clear()
def lr_step(self, epoch, val_loss=None):
"""Adjust the learning rate based on the validation loss."""
return self.lr_scheduler.step(epoch, val_loss)
def lr_step_update(self, num_updates):
"""Update the learning rate after each update."""
return self.lr_scheduler.step_update(num_updates)
def get_lr(self):
"""Get the current learning rate."""
return self.optimizer.get_lr()
def get_throughput_meter(self):
"""Get the throughput meter"""
return self.throughput_meter
def get_model(self):
"""Get the model replica."""
return self.model.module if isinstance(self.model, DDP) else self.model
def get_num_updates(self):
"""Get the number of parameters updates."""
return self._num_updates
def _prepare_sample(self, sample):
if not sample:
return None
return utils.move_to_cuda(sample)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import os
import socket
import torch.distributed
from fairseq import utils
def is_master(args):
return args.distributed_rank == 0
def distributed_init(args):
if args.distributed_world_size == 1:
raise ValueError('Cannot initialize distributed with distributed_world_size=1')
print('| distributed init (rank {}): {}'.format(
args.distributed_rank, args.distributed_init_method), flush=True)
print("| distributed env init. MASTER_ADDR: " + os.environ['MASTER_ADDR'] +
", MASTER_PORT: " + os.environ['MASTER_PORT'] +
", WORLD_SIZE: " + os.environ['WORLD_SIZE'] + ", RANK: " + os.environ['RANK'], flush=True)
torch.distributed.init_process_group(
backend=args.distributed_backend, init_method='env://')
print("| distributed init done!", flush=True)
args.distributed_world_size = int(os.environ['WORLD_SIZE'])
args.distributed_rank = torch.distributed.get_rank()
args.device_id = int(os.environ.get('LOCAL_RANK', args.local_rank))
suppress_output(args)
print('| initialized host {} as rank {} and device id {}'
.format(socket.gethostname(), args.distributed_rank, args.device_id))
return args.distributed_rank
def suppress_output(main_args):
"""Suppress printing on the current device. Force printing with `force=True`."""
import builtins as __builtin__
builtin_print = __builtin__.print
def print_master(*args, **kwargs):
if 'force' in kwargs:
kwargs.pop('force')
builtin_print(*args, **kwargs)
def print(*args, **kwargs):
if 'force' in kwargs:
force = kwargs.pop('force')
if force:
builtin_print(*args, **kwargs)
if is_master(main_args):
__builtin__.print = print_master
else:
__builtin__.print = print
def all_gather_list(data, max_size=16384):
"""Gathers arbitrary data from all nodes into a list."""
world_size = torch.distributed.get_world_size()
if not hasattr(all_gather_list, '_in_buffer') or \
max_size != len(all_gather_list._in_buffer):
all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
all_gather_list._out_buffers = [
torch.cuda.ByteTensor(max_size)
for i in range(world_size)
]
in_buffer = all_gather_list._in_buffer
out_buffers = all_gather_list._out_buffers
enc = pickle.dumps(data)
enc_size = len(enc)
if enc_size + 2 > max_size:
raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2))
assert max_size < 255 * 256
in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k
in_buffer[1] = enc_size % 255
in_buffer[2:enc_size + 2] = torch.ByteTensor(list(enc))
torch.distributed.all_gather(out_buffers, in_buffer.cuda())
result = []
for i in range(world_size):
out_buffer = out_buffers[i]
size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1])
result.append(
pickle.loads(bytes(out_buffer[2:size + 2].tolist()))
)
return result
import os
import atexit
import time
import itertools
from collections import OrderedDict
import dllogger
from dllogger import Backend, JSONStreamBackend
from tensorboardX import SummaryWriter
class AverageMeter():
def __init__(self):
self.reset()
def reset(self):
self.updated = False
self.avg = 0
self.sum = 0
self.count = 0
def update(self, value):
self.updated = True
if isinstance(value, (tuple, list)):
val = value[0]
n = value[1]
else:
val = value
n = 1
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
@property
def value(self):
return self.avg
class PerformanceMeter():
def __init__(self):
self.reset()
def reset(self):
self.updated = False
self.start = time.time()
self.n = 0
def update(self, val=1):
self.updated = True
self.n += val
@property
def value(self):
return self.n / self.elapsed_time
@property
def elapsed_time(self):
return time.time() - self.start
METRIC = {'average': AverageMeter, 'performance': PerformanceMeter}
class AggregatorBackend(Backend):
def __init__(self, verbosity, agg_dict):
super().__init__(verbosity=verbosity)
agg_dict = OrderedDict({k: v if isinstance(v, (tuple, list)) else (v,) for k, v in agg_dict.items()})
self.metrics = OrderedDict({k: [METRIC[x]() for x in v] for k, v in agg_dict.items()})
self.metrics.flushed = True
self.step = 0
self.epoch = 0
self.start_time = time.time()
@property
def log_level(self):
return self._log_level
def metadata(self, timestamp, elapsedtime, metric, metadata):
pass
def _reset_perf_meter(self, name):
for agg in self.metrics[name]:
if isinstance(agg, PerformanceMeter):
agg.reset()
def reset_perf_meters(self):
for name in self.metrics.keys():
self._reset_perf_meter(name)
def log(self, timestamp, elapsedtime, step, data):
self.step = step
if 'epoch' in data.keys():
self.epoch = data['epoch']
for k, v in data.items():
if k not in self.metrics.keys():
continue
self.metrics.flushed = False
for ag in self.metrics[k]:
ag.update(v)
def flush(self):
if self.metrics.flushed:
return
result_string = 'Transformer | epoch {} | step {} |'.format(self.epoch, self.step)
for name, aggregators in self.metrics.items():
for agg in aggregators:
if not agg.updated:
continue
if isinstance(agg, AverageMeter):
_name = 'avg ' + name
elif isinstance(agg, PerformanceMeter):
_name = name + '/s'
result_string += _name + ' {:.3f} |'.format(agg.value)
agg.reset()
result_string += 'walltime {:.3f} |'.format(time.time() - self.start_time)
self.metrics.flushed = True
print(result_string)
class TensorBoardBackend(Backend):
def __init__(self, verbosity, log_dir):
super().__init__(verbosity=verbosity)
self.summary_writer = SummaryWriter(log_dir=os.path.join(log_dir, 'TB_summary'),
flush_secs=120,
max_queue=200
)
atexit.register(self.summary_writer.close)
@property
def log_level(self):
return self._log_level
def metadata(self, timestamp, elapsedtime, metric, metadata):
pass
def log(self, timestamp, elapsedtime, step, data):
if not isinstance(step, int):
return
for k, v in data.items():
self.summary_writer.add_scalar(k, v, step)
def flush(self):
pass
def setup_logger(args):
aggregator_dict = OrderedDict([
('loss', 'average'),
('weighted_loss', 'average'),
('tokens', ('average', 'performance')),
('updates', 'performance'),
('gnorm', 'average')
])
os.makedirs(args.save_dir, exist_ok=True)
log_path = os.path.join(args.save_dir, args.stat_file)
if os.path.exists(log_path):
for i in itertools.count():
s_fname = args.stat_file.split('.')
fname = '.'.join(s_fname[:-1]) + f'_{i}.' + s_fname[-1] if len(s_fname) > 1 else args.stat_file + f'.{i}'
log_path = os.path.join(args.save_dir, fname)
if not os.path.exists(log_path):
break
if not args.distributed_world_size > 1 or args.distributed_rank == 0:
dllogger.init(backends=[JSONStreamBackend(verbosity=1, filename=log_path),
AggregatorBackend(verbosity=0, agg_dict=aggregator_dict),
TensorBoardBackend(verbosity=1, log_dir=args.save_dir)])
else:
dllogger.init(backends=[])
for k, v in vars(args).items():
dllogger.log(step='PARAMETER', data={k: v}, verbosity=0)
container_setup_info = get_framework_env_vars()
dllogger.log(step='PARAMETER', data=container_setup_info, verbosity=0)
dllogger.metadata('loss', {'unit': 'nat', 'GOAL': 'MINIMIZE', 'STAGE': 'TRAIN'})
dllogger.metadata('val_loss', {'unit': 'nat', 'GOAL': 'MINIMIZE', 'STAGE': 'VAL'})
dllogger.metadata('speed', {'unit': 'tokens/s', 'format': ':.3f', 'GOAL': 'MAXIMIZE', 'STAGE': 'TRAIN'})
dllogger.metadata('accuracy', {'unit': 'bleu', 'format': ':.2f', 'GOAL': 'MAXIMIZE', 'STAGE': 'VAL'})
def get_framework_env_vars():
return {
'NVIDIA_PYTORCH_VERSION': os.environ.get('NVIDIA_PYTORCH_VERSION'),
'PYTORCH_VERSION': os.environ.get('PYTORCH_VERSION'),
'CUBLAS_VERSION': os.environ.get('CUBLAS_VERSION'),
'NCCL_VERSION': os.environ.get('NCCL_VERSION'),
'CUDA_DRIVER_VERSION': os.environ.get('CUDA_DRIVER_VERSION'),
'CUDNN_VERSION': os.environ.get('CUDNN_VERSION'),
'CUDA_VERSION': os.environ.get('CUDA_VERSION'),
'NVIDIA_PIPELINE_ID': os.environ.get('NVIDIA_PIPELINE_ID'),
'NVIDIA_BUILD_ID': os.environ.get('NVIDIA_BUILD_ID'),
'NVIDIA_TF32_OVERRIDE': os.environ.get('NVIDIA_TF32_OVERRIDE'),
}
def reset_perf_meters():
for backend in dllogger.GLOBAL_LOGGER.backends:
if isinstance(backend, AggregatorBackend):
backend.reset_perf_meters()
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import time
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
class TimeMeter(object):
"""Computes the average occurrence of some event per second"""
def __init__(self, init=0):
self.reset(init)
def reset(self, init=0):
self.init = init
self.start = time.time()
self.n = 0
self.last_update = time.time()
def update(self, val=1):
self.n += val
self.last_update = time.time()
@property
def avg(self):
return self.n / self.elapsed_time
@property
def elapsed_time(self):
return self.init + (time.time() - self.start)
@property
def u_avg(self):
return self.n / (self.last_update - self.start)
class StopwatchMeter(object):
"""Computes the sum/avg duration of some event in seconds"""
def __init__(self):
self.reset()
self.intervals = []
def start(self):
self.start_time = time.time()
def stop(self, n=1):
if self.start_time is not None:
delta = time.time() - self.start_time
self.intervals.append(delta)
self.sum += delta
self.n += n
self.start_time = None
def reset(self):
self.sum = 0
self.n = 0
self.start_time = None
self.intervals = []
@property
def avg(self):
return self.sum / self.n
def p(self, i):
assert i <= 100
idx = int(len(self.intervals) * i / 100)
return sorted(self.intervals)[idx]
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import importlib
import os
from .fairseq_incremental_decoder import FairseqIncrementalDecoder # noqa: F401
MODEL_REGISTRY = {}
ARCH_MODEL_REGISTRY = {}
ARCH_CONFIG_REGISTRY = {}
def build_model(args):
return ARCH_MODEL_REGISTRY[args.arch].build_model(args)
def register_model(name):
"""Decorator to register a new model (e.g., LSTM)."""
def register_model_cls(cls):
if name in MODEL_REGISTRY:
raise ValueError('Cannot register duplicate model ({})'.format(name))
MODEL_REGISTRY[name] = cls
return cls
return register_model_cls
def register_model_architecture(model_name, arch_name):
"""Decorator to register a new model architecture (e.g., lstm_luong_wmt_en_de)."""
def register_model_arch_fn(fn):
if model_name not in MODEL_REGISTRY:
raise ValueError('Cannot register model architecture for unknown model type ({})'.format(model_name))
if arch_name in ARCH_MODEL_REGISTRY:
raise ValueError('Cannot register duplicate model architecture ({})'.format(arch_name))
if not callable(fn):
raise ValueError('Model architecture must be callable ({})'.format(arch_name))
ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name]
ARCH_CONFIG_REGISTRY[arch_name] = fn
return fn
return register_model_arch_fn
# automatically import any Python files in the models/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('fairseq.models.' + module)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.nn as nn
class FairseqIncrementalDecoder(nn.Module):
"""Base class for incremental decoders."""
def __init__(self):
super().__init__()
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
raise NotImplementedError
def reorder_incremental_state(self, incremental_state, new_order):
"""Reorder incremental state.
This should be called when the order of the input has changed from the
previous time step. A typical use case is beam search, where the input
order changes between time steps based on the selection of beams.
"""
def apply_reorder_incremental_state(module):
if module != self and hasattr(module, 'reorder_incremental_state'):
module.reorder_incremental_state(
incremental_state,
new_order,
)
self.apply(apply_reorder_incremental_state)
def set_beam_size(self, beam_size):
"""Sets the beam size in the decoder and all children."""
if getattr(self, '_beam_size', -1) != beam_size:
def apply_set_beam_size(module):
if module != self and hasattr(module, 'set_beam_size'):
module.set_beam_size(beam_size)
self.apply(apply_set_beam_size)
self._beam_size = beam_size
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.import math
import math
import torch
import numbers
from torch.nn.parameter import Parameter
from torch.nn import init
import fused_layer_norm_cuda
class FusedLayerNormAffineFunction(torch.autograd.Function):
def __init__(self, normalized_shape, eps=1e-6):
self.normalized_shape = normalized_shape
self.eps = eps
def forward(self, input, weight, bias):
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_layer_norm_cuda.forward_affine(
input_, self.normalized_shape, weight_, bias_, self.eps)
self.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
def backward(self, grad_output):
input_, weight_, bias_, mean, invvar = self.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, self.normalized_shape,
weight_, bias_, self.eps)
return grad_input, grad_weight, grad_bias;
class FusedLayerNormFunction(torch.autograd.Function):
def __init__(self, normalized_shape, eps=1e-6):
self.normalized_shape = normalized_shape
self.eps = eps
def forward(self, input):
input_ = input.contiguous()
output, mean, invvar = fused_layer_norm_cuda.forward(
input_, self.normalized_shape, self.eps)
self.save_for_backward(input_, mean, invvar)
return output
def backward(self, grad_output):
input_, mean, invvar = self.saved_tensors
grad_input = None
grad_input = fused_layer_norm_cuda.backward(
grad_output.contiguous(), mean, invvar,
input_, self.normalized_shape,
self.eps)
return grad_input
def fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):
return FusedLayerNormAffineFunction(normalized_shape,eps)(input, weight, bias)
def fused_layer_norm(input, normalized_shape, eps=1e-6):
return FusedLayerNormFunction(normalized_shape,eps)(input)
class FusedLayerNorm(torch.nn.Module):
r"""Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization`_ .
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated separately over the last
certain number dimensions which have to be of the shape specified by
:attr:`normalized_shape`.
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
:attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
.. note::
Unlike Batch Normalization and Instance Normalization, which applies
scalar scale and bias for each entire channel/plane with the
:attr:`affine` option, Layer Normalization applies per-element scale and
bias with :attr:`elementwise_affine`.
This layer uses statistics computed from input data in both training and
evaluation modes.
Args:
normalized_shape (int or list or torch.Size): input shape from an expected input
of size
.. math::
[* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
\times \ldots \times \text{normalized\_shape}[-1]]
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps: a value added to the denominator for numerical stability. Default: 1e-5
elementwise_affine: a boolean value that when set to ``True``, this module
has learnable per-element affine parameters initialized to ones (for weights)
and zeros (for biases). Default: ``True``.
Shape:
- Input: :math:`(N, *)`
- Output: :math:`(N, *)` (same shape as input)
Examples::
>>> input = torch.randn(20, 5, 10, 10)
>>> # With Learnable Parameters
>>> m = nn.LayerNorm(input.size()[1:])
>>> # Without Learnable Parameters
>>> m = nn.LayerNorm(input.size()[1:], elementwise_affine=False)
>>> # Normalize over last two dimensions
>>> m = nn.LayerNorm([10, 10])
>>> # Normalize over last dimension of size 10
>>> m = nn.LayerNorm(10)
>>> # Activating the module
>>> output = m(input)
.. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
"""
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
super(FusedLayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = Parameter(torch.Tensor(*normalized_shape))
self.bias = Parameter(torch.Tensor(*normalized_shape))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
if self.elementwise_affine:
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, input):
if self.elementwise_affine:
return FusedLayerNormAffineFunction(self.normalized_shape,self.eps)(
input, self.weight, self.bias)
else:
return FusedLayerNormFunction(self.normalized_shape,self.eps)(
input)
def extra_repr(self):
return '{normalized_shape}, eps={eps}, ' \
'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional, Dict
from fairseq.modules import (
LearnedPositionalEmbedding, MultiheadAttention, SinusoidalPositionalEmbedding
)
from . import (
FairseqIncrementalDecoder, register_model,
register_model_architecture,
)
from apex.normalization.fused_layer_norm import FusedLayerNorm
torch.set_printoptions(threshold=500000000, linewidth=1024)
@torch.jit.script
def jit_dropout_add(x, residual, prob, is_training):
# type: (Tensor, Tensor, float, bool) -> Tensor
out = F.dropout(x, p=prob, training=is_training)
out = residual + out
return out
@torch.jit.script
def jit_relu_dropout(x, prob, is_training):
# type: (Tensor, float, bool) -> Tensor
out = F.threshold(x, 0., 0.)
out = F.dropout(out, p=prob, training=is_training)
return out
@register_model('transformer')
class TransformerModel(nn.Module):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', type=float, metavar='D',
help='dropout probability for attention weights')
parser.add_argument('--relu-dropout', type=float, metavar='D',
help='dropout probability after ReLU in FFN')
parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
help='path to pre-trained encoder embedding')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
help='encoder embedding dimension for FFN')
parser.add_argument('--encoder-layers', type=int, metavar='N',
help='num encoder layers')
parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
help='num encoder attention heads')
parser.add_argument('--encoder-normalize-before', action='store_true',
help='apply layernorm before each encoder block')
parser.add_argument('--encoder-learned-pos', action='store_true',
help='use learned positional embeddings in the encoder')
parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
help='path to pre-trained decoder embedding')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
help='decoder embedding dimension for FFN')
parser.add_argument('--decoder-layers', type=int, metavar='N',
help='num decoder layers')
parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
help='num decoder attention heads')
parser.add_argument('--decoder-learned-pos', action='store_true',
help='use learned positional embeddings in the decoder')
parser.add_argument('--decoder-normalize-before', action='store_true',
help='apply layernorm before each decoder block')
parser.add_argument('--share-decoder-input-output-embed', action='store_true',
help='share decoder input and output embeddings')
parser.add_argument('--share-all-embeddings', action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
def __init__(self, encoder, decoder):
super().__init__()
self._is_generation_fast = False
self.encoder = encoder
self.decoder = decoder
@classmethod
def build_model(cls, args):
# make sure all arguments are present in older models
base_architecture(args)
if not hasattr(args, 'max_source_positions'):
args.max_source_positions = 1024
if not hasattr(args, 'max_target_positions'):
args.max_target_positions = 1024
if args.share_all_embeddings:
if args.src_vocab_size != args.tgt_vocab_size:
raise RuntimeError('--share-all-embeddings requires a joined dictionary')
if args.encoder_embed_dim != args.decoder_embed_dim:
raise RuntimeError(
'--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
if args.decoder_embed_path and (
args.decoder_embed_path != args.encoder_embed_path):
raise RuntimeError('--share-all-embeddings not compatible with --decoder-embed-path')
encoder_embed_tokens = Embedding(args.src_vocab_size, args.encoder_embed_dim, args.padding_idx)
decoder_embed_tokens = encoder_embed_tokens
args.share_decoder_input_output_embed = True
else:
encoder_embed_tokens = Embedding(args.src_vocab_size, args.encoder_embed_dim, args.padding_idx)
decoder_embed_tokens = Embedding(args.tgt_vocab_size, args.decoder_embed_dim, args.padding_idx)
encoder = TransformerEncoder(args, encoder_embed_tokens)
decoder = TransformerDecoder(args, decoder_embed_tokens)
return TransformerModel(encoder, decoder)
def make_generation_fast_(self, **kwargs):
"""Optimize model for faster generation."""
if self._is_generation_fast:
return # only apply once
self._is_generation_fast = True
# remove weight norm from all modules in the network
def apply_remove_weight_norm(module):
try:
nn.utils.remove_weight_norm(module)
except ValueError: # this module didn't have weight norm
return
self.apply(apply_remove_weight_norm)
def apply_make_generation_fast_(module):
if module != self and hasattr(module, 'make_generation_fast_'):
module.make_generation_fast_(**kwargs)
self.apply(apply_make_generation_fast_)
def train(mode):
if mode:
raise RuntimeError('cannot train after make_generation_fast')
# this model should no longer be used for training
self.eval()
self.train = train
def forward(self, src_tokens, src_lengths, prev_output_tokens):
encoder_out, padding_mask = self.encoder(src_tokens, src_lengths)
decoder_out = self.decoder(prev_output_tokens, encoder_out, padding_mask)
return decoder_out
class TransformerEncoder(nn.Module):
"""Transformer encoder."""
def __init__(self, args, embed_tokens, left_pad=True):
super().__init__()
self.dropout = args.dropout
self.fuse_dropout_add = args.fuse_dropout_add
self.fuse_relu_dropout = args.fuse_relu_dropout
embed_dim = embed_tokens.embedding_dim
self.padding_idx = embed_tokens.padding_idx
self.max_source_positions = args.max_source_positions
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
args.max_source_positions, embed_dim, self.padding_idx,
left_pad=left_pad,
learned=args.encoder_learned_pos,
) if not args.no_token_positional_embeddings else None
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerEncoderLayer(args)
for i in range(args.encoder_layers)
])
self.normalize = args.encoder_normalize_before
if self.normalize:
self.layer_norm = FusedLayerNorm(embed_dim) if args.fuse_layer_norm else nn.LayerNorm(embed_dim)
def forward(self, src_tokens, src_lengths):
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(src_tokens)
if self.embed_positions is not None:
x += self.embed_positions(src_tokens)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
# The tensor needs to copy transposed because
# fused dropout is not capable of handing strided data
if self.fuse_dropout_add:
x = x.transpose(0, 1).contiguous()
else:
x = x.transpose(0, 1)
# compute padding mask
encoder_padding_mask = src_tokens.eq(self.padding_idx)
if not encoder_padding_mask.any():
_encoder_padding_mask = None
else:
_encoder_padding_mask = encoder_padding_mask
# encoder layers
for layer in self.layers:
x = layer(x, _encoder_padding_mask)
if self.normalize:
x = self.layer_norm(x)
return x, encoder_padding_mask # x.shape == T x B x C, encoder_padding_mask.shape == B x T
def reorder_encoder_out(self, encoder_out, encoder_padding_mask, new_order):
if encoder_out is not None:
encoder_out = encoder_out.index_select(1, new_order)
if encoder_padding_mask is not None:
encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)
return encoder_out, encoder_padding_mask
class TransformerDecoder(FairseqIncrementalDecoder):
"""Transformer decoder."""
def __init__(self, args, embed_tokens, no_encoder_attn=False, left_pad=False):
super().__init__()
self.dropout = args.dropout
self.share_input_output_embed = args.share_decoder_input_output_embed
self.fuse_dropout_add = args.fuse_dropout_add
self.fuse_relu_dropout = args.fuse_relu_dropout
embed_dim = embed_tokens.embedding_dim
padding_idx = embed_tokens.padding_idx
self.max_target_positions = args.max_target_positions
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
args.max_target_positions, embed_dim, padding_idx,
left_pad=left_pad,
learned=args.decoder_learned_pos,
) if not args.no_token_positional_embeddings else None
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerDecoderLayer(args, no_encoder_attn)
for _ in range(args.decoder_layers)
])
if not self.share_input_output_embed:
self.embed_out = nn.Parameter(torch.Tensor(args.tgt_vocab_size, embed_dim))
nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
else:
self.embed_out = self.embed_tokens.weight
self.normalize = args.decoder_normalize_before
if self.normalize:
self.layer_norm = FusedLayerNorm(embed_dim) if args.fuse_layer_norm else nn.LayerNorm(embed_dim)
def forward(self,
prev_output_tokens: Tensor,
encoder_out: Tensor,
encoder_padding_mask: Tensor,
incremental_state: Optional[Dict[str, Dict[str, Tensor]]]=None):
# embed positions
positions = self.embed_positions(
prev_output_tokens,
incremental_state=incremental_state,
) if self.embed_positions is not None else None
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if positions is not None:
x += positions
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
# The tensor needs to copy transposed because
# fused dropout is not capable of handing strided data
if self.fuse_dropout_add:
x = x.transpose(0, 1).contiguous()
else:
x = x.transpose(0, 1)
attn = None
# decoder layers
for layer in self.layers:
x, attn = layer(
x,
encoder_out,
encoder_padding_mask if encoder_padding_mask.any() else None,
incremental_state,
)
if self.normalize:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
# project back to size of vocabulary
x = F.linear(x, self.embed_out)
return x, attn
class TransformerEncoderLayer(nn.Module):
"""Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is
postprocessed with: dropout -> add residual -> layernorm.
In the tensor2tensor code they suggest that learning is more robust when
preprocessing each layer with layernorm and postprocessing with:
dropout -> add residual.
We default to the approach in the paper, but the tensor2tensor approach can
be enabled by setting `normalize_before=True`.
"""
def __init__(self, args):
super().__init__()
self.embed_dim = args.encoder_embed_dim
self.self_attn = MultiheadAttention(
self.embed_dim, args.encoder_attention_heads,
dropout=args.attention_dropout,
)
self.dropout = args.dropout
self.relu_dropout = args.relu_dropout
self.fuse_dropout_add = args.fuse_dropout_add
self.fuse_relu_dropout = args.fuse_relu_dropout
self.normalize_before = args.encoder_normalize_before
self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
self.maybe_ln1 = MaybeLayerNorm(self.embed_dim, self.normalize_before, fuse=args.fuse_layer_norm)
self.maybe_ln2 = MaybeLayerNorm(self.embed_dim, self.normalize_before, fuse=args.fuse_layer_norm)
def forward(self, x: Tensor, encoder_padding_mask: Optional[Tensor]):
residual = x
x = self.maybe_ln1(x, before=True)
x, _ = self.self_attn(query=x, key=x, value=x,
mask_future_timesteps=False,
key_padding_mask=encoder_padding_mask,
incremental_state=None,
need_weights=False,
static_kv=False)
if self.fuse_dropout_add and self.training:
x = jit_dropout_add(x, residual, self.dropout, self.training)
else:
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_ln1(x, after=True)
residual = x
x = self.maybe_ln2(x, before=True)
if self.fuse_relu_dropout:
x = jit_relu_dropout(self.fc1(x), self.relu_dropout, self.training)
else:
x = F.threshold(self.fc1(x), 0.0, 0.0)
x = F.dropout(x, p=self.relu_dropout, training=self.training)
x = self.fc2(x)
if self.fuse_dropout_add and self.training:
x = jit_dropout_add(x, residual, self.dropout, self.training)
else:
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_ln2(x, after=True)
return x
class TransformerDecoderLayer(nn.Module):
"""Decoder layer block."""
def __init__(self, args, no_encoder_attn=False):
super().__init__()
self.embed_dim = args.decoder_embed_dim
self.self_attn = MultiheadAttention(
self.embed_dim, args.decoder_attention_heads,
dropout=args.attention_dropout,
)
self.dropout = args.dropout
self.relu_dropout = args.relu_dropout
self.normalize_before = args.decoder_normalize_before
self.fuse_dropout_add = args.fuse_dropout_add
self.fuse_relu_dropout = args.fuse_relu_dropout
self.self_attn_layer_norm = MaybeLayerNorm(
self.embed_dim, self.normalize_before, fuse=args.fuse_layer_norm)
if no_encoder_attn:
self.encoder_attn = None
self.encoder_attn_layer_norm = None
else:
self.encoder_attn = MultiheadAttention(
self.embed_dim, args.decoder_attention_heads,
dropout=args.attention_dropout,
)
self.encoder_attn_layer_norm = MaybeLayerNorm(
self.embed_dim, self.normalize_before, fuse=args.fuse_layer_norm)
self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
self.final_layer_norm = MaybeLayerNorm(
self.embed_dim, self.normalize_before, fuse=args.fuse_layer_norm)
self.need_attn = True
def forward(self,
x: Tensor,
encoder_out: Tensor,
encoder_padding_mask: Optional[Tensor],
incremental_state: Optional[Dict[str, Dict[str, Tensor]]]):
residual = x
x = self.self_attn_layer_norm(x, before=True)
x, _ = self.self_attn(
query=x,
key=x,
value=x,
mask_future_timesteps=True,
key_padding_mask=None,
incremental_state=incremental_state,
need_weights=False,
static_kv=False
)
if self.fuse_dropout_add and self.training:
x = jit_dropout_add(x, residual, self.dropout, self.training)
else:
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.self_attn_layer_norm(x, after=True)
attn = None
if self.encoder_attn is not None:
residual = x
x = self.encoder_attn_layer_norm(x, before=True)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
mask_future_timesteps=False,
need_weights=(not self.training and self.need_attn),
)
if self.fuse_dropout_add and self.training:
x = jit_dropout_add(x, residual, self.dropout, self.training)
else:
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.encoder_attn_layer_norm(x, after=True)
residual = x
x = self.final_layer_norm(x, before=True)
if self.fuse_relu_dropout:
x = jit_relu_dropout(self.fc1(x), self.relu_dropout, self.training)
else:
x = F.threshold(self.fc1(x), 0.0, 0.0)
x = F.dropout(x, p=self.relu_dropout, training=self.training)
x = self.fc2(x)
if self.fuse_dropout_add and self.training:
x = jit_dropout_add(x, residual, self.dropout, self.training)
else:
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.final_layer_norm(x, after=True)
return x, attn
def make_generation_fast_(self, need_attn=False, **kwargs):
self.need_attn = need_attn
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
return m
class MaybeLayerNorm(nn.Module):
def __init__(self, embed_dim, normalize_before, fuse=True):
super().__init__()
self.embed_dim = embed_dim
self.normalize_before = normalize_before
self.ln = FusedLayerNorm(embed_dim) if fuse else nn.LayerNorm(embed_dim)
def forward(self, x: Tensor, before: bool = False, after: bool = False):
assert before ^ after
if after ^ self.normalize_before:
return self.ln(x)
else:
return x
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)
nn.init.constant_(m.bias, 0.)
return m
def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
if learned:
m = LearnedPositionalEmbedding(num_embeddings + padding_idx + 1, embedding_dim, padding_idx, left_pad)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
else:
m = SinusoidalPositionalEmbedding(
embedding_dim, padding_idx, left_pad, num_embeddings + padding_idx + 1)
return m
@register_model_architecture('transformer', 'transformer')
def base_architecture(args):
args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 2048)
args.encoder_layers = getattr(args, 'encoder_layers', 6)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', False)
args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
args.decoder_layers = getattr(args, 'decoder_layers', 6)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', False)
args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
args.attention_dropout = getattr(args, 'attention_dropout', 0.)
args.relu_dropout = getattr(args, 'relu_dropout', 0.)
args.dropout = getattr(args, 'dropout', 0.1)
args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
@register_model_architecture('transformer', 'transformer_iwslt_de_en')
def transformer_iwslt_de_en(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
args.encoder_layers = getattr(args, 'encoder_layers', 6)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
args.decoder_layers = getattr(args, 'decoder_layers', 6)
base_architecture(args)
@register_model_architecture('transformer', 'transformer_wmt_en_de')
def transformer_wmt_en_de(args):
base_architecture(args)
# parameters used in the "Attention Is All You Need" paper (Vaswani, et al, 2017)
@register_model_architecture('transformer', 'transformer_vaswani_wmt_en_de_big')
def transformer_vaswani_wmt_en_de_big(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
args.dropout = getattr(args, 'dropout', 0.3)
base_architecture(args)
@register_model_architecture('transformer', 'transformer_vaswani_wmt_en_fr_big')
def transformer_vaswani_wmt_en_fr_big(args):
args.dropout = getattr(args, 'dropout', 0.1)
transformer_vaswani_wmt_en_de_big(args)
@register_model_architecture('transformer', 'transformer_wmt_en_de_big')
def transformer_wmt_en_de_big(args):
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
transformer_vaswani_wmt_en_de_big(args)
# default parameters used in tensor2tensor implementation
@register_model_architecture('transformer', 'transformer_wmt_en_de_big_t2t')
def transformer_wmt_en_de_big_t2t(args):
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', True)
args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', True)
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
args.relu_dropout = getattr(args, 'relu_dropout', 0.1)
transformer_vaswani_wmt_en_de_big(args)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from .beamable_mm import BeamableMM
from .learned_positional_embedding import LearnedPositionalEmbedding
from .multihead_attention import MultiheadAttention
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
__all__ = [
'BeamableMM',
'LearnedPositionalEmbedding',
'MultiheadAttention',
'SinusoidalPositionalEmbedding',
]
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.nn as nn
from fairseq import utils
class LearnedPositionalEmbedding(nn.Embedding):
"""This module learns positional embeddings up to a fixed maximum size.
Padding symbols are ignored, but it is necessary to specify whether padding
is added on the left side (left_pad=True) or right side (left_pad=False).
"""
def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.left_pad = left_pad
def forward(self, input, incremental_state=None):
"""Input is expected to be of size [bsz x seqlen]."""
if incremental_state is not None:
# positions is the same for every token when decoding a single step
positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
else:
positions = utils.make_positions(input.data, self.padding_idx, self.left_pad)
return super().forward(positions)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional
import torch
from torch import nn, Tensor
from torch.nn import Parameter
import torch.nn.functional as F
from torch.cuda import amp
from torch.autograd.variable import Variable
import strided_batched_gemm
from fairseq import utils
class QueryLinear(torch.autograd.Function):
@staticmethod
@amp.custom_fwd(cast_inputs=torch.half)
def forward(ctx, input, weights_q, scale):
s = Variable(torch.tensor([scale]))
ctx.save_for_backward(input, weights_q, s)
q = torch.addmm(input.view(input.size(0) * input.size(1), input.size(2)),
input.view(input.size(0) * input.size(1), input.size(2)),
weights_q, beta=0.0, alpha=s[0])
q = q.view(input.size(0), input.size(1), input.size(2))
return q.detach()
@staticmethod
@amp.custom_bwd
def backward(ctx, q_grad):
input, weights_q, s = ctx.saved_tensors
input = input.view(input.size(0) * input.size(1), input.size(2)).transpose(0, 1)
q = torch.addmm(q_grad.view(q_grad.size(0) * q_grad.size(1), q_grad.size(2)),
q_grad.view(q_grad.size(0) * q_grad.size(1), q_grad.size(2)),
weights_q.transpose(0, 1), beta=0.0, alpha=s[0])
q = q.view(q_grad.size(0), q_grad.size(1), q_grad.size(2))
q_grad = q_grad.view(q_grad.size(0) * q_grad.size(1), q_grad.size(2))
weights_q_grad = torch.addmm(weights_q, input, q_grad, beta=0.0, alpha=s[0])
return q, weights_q_grad, None
class KeyValueLinears(torch.autograd.Function):
@staticmethod
@amp.custom_fwd(cast_inputs=torch.half)
def forward(ctx, input, weights_k, weights_v):
ctx.save_for_backward(input, weights_k, weights_v)
k = torch.addmm(input.view(input.size(0) * input.size(1), input.size(2)),
input.view(input.size(0) * input.size(1), input.size(2)),
weights_k, beta=0.0, alpha=1.0)
k = k.view(input.size(0), input.size(1), input.size(2))
v = torch.addmm(input.view(input.size(0) * input.size(1), input.size(2)),
input.view(input.size(0) * input.size(1), input.size(2)),
weights_v, beta=0.0, alpha=1.0)
v = v.view(input.size(0), input.size(1), input.size(2))
return k.detach(), v.detach()
@staticmethod
@amp.custom_bwd
def backward(ctx, k_grad, v_grad):
input, weights_k, weights_v = ctx.saved_tensors
input = input.view(input.size(0) * input.size(1), input.size(2)).transpose(0, 1)
k = torch.addmm(k_grad.view(k_grad.size(0) * k_grad.size(1), k_grad.size(2)),
k_grad.view(k_grad.size(0) * k_grad.size(1), k_grad.size(2)),
weights_k.transpose(0, 1), beta=0.0)
k_grad = k_grad.view(k_grad.size(0) * k_grad.size(1), k_grad.size(2))
weights_k_grad = torch.mm(input, k_grad)
v = k.addmm_(v_grad.view(v_grad.size(0) * v_grad.size(1), v_grad.size(2)),
weights_v.transpose(0, 1), beta=1.0)
v = v.view(v_grad.size(0), v_grad.size(1), v_grad.size(2))
v_grad = v_grad.view(v_grad.size(0) * v_grad.size(1), v_grad.size(2))
weights_v_grad = torch.mm(input, v_grad)
return v, weights_k_grad, weights_v_grad
class SelfAttentionLinears(torch.autograd.Function):
@staticmethod
@amp.custom_fwd(cast_inputs=torch.half)
def forward(ctx, input, weights_q, weights_k, weights_v, scale):
s = Variable(torch.tensor([scale]))
ctx.save_for_backward(input, weights_q, weights_k, weights_v, s)
q = torch.addmm(input.view(input.size(0) * input.size(1), input.size(2)),
input.view(input.size(0) * input.size(1), input.size(2)),
weights_q, beta=0.0, alpha=s[0])
q = q.view(input.size(0), input.size(1), input.size(2))
k = torch.addmm(input.view(input.size(0) * input.size(1), input.size(2)),
input.view(input.size(0) * input.size(1), input.size(2)),
weights_k, beta=0.0, alpha=1.0)
k = k.view(input.size(0), input.size(1), input.size(2))
v = torch.addmm(input.view(input.size(0) * input.size(1), input.size(2)),
input.view(input.size(0) * input.size(1), input.size(2)),
weights_v, beta=0.0, alpha=1.0)
v = v.view(input.size(0), input.size(1), input.size(2))
return q.detach(), k.detach(), v.detach()
@staticmethod
@amp.custom_bwd
def backward(ctx, q_grad, k_grad, v_grad):
input, weights_q, weights_k, weights_v, s = ctx.saved_tensors
input = input.view(input.size(0) * input.size(1), input.size(2)).transpose(0, 1)
q = torch.addmm(q_grad.view(q_grad.size(0) * q_grad.size(1), q_grad.size(2)),
q_grad.view(q_grad.size(0) * q_grad.size(1), q_grad.size(2)),
weights_q.transpose(0, 1), beta=0.0, alpha=s[0])
q_grad = q_grad.view(q_grad.size(0) * q_grad.size(1), q_grad.size(2))
weights_q_grad = torch.addmm(weights_q, input, q_grad, beta=0.0, alpha=s[0])
k = q.addmm_(k_grad.view(k_grad.size(0) * k_grad.size(1), k_grad.size(2)),
weights_k.transpose(0, 1), beta=1.0)
k_grad = k_grad.view(k_grad.size(0) * k_grad.size(1), k_grad.size(2))
weights_k_grad = torch.mm(input, k_grad)
v = k.addmm_(v_grad.view(v_grad.size(0) * v_grad.size(1), v_grad.size(2)),
weights_v.transpose(0, 1), beta=1.0)
v = v.view(v_grad.size(0), v_grad.size(1), v_grad.size(2))
v_grad = v_grad.view(v_grad.size(0) * v_grad.size(1), v_grad.size(2))
weights_v_grad = torch.mm(input, v_grad)
return v, weights_q_grad, weights_k_grad, weights_v_grad, None
class StridedBmm1Func(torch.autograd.Function):
@staticmethod
@amp.custom_fwd(cast_inputs=torch.half)
def forward(ctx, input1, input2):
ctx.save_for_backward(input1, input2)
output = torch.empty((input1.size(0), input1.size(1), input2.size(2)),
dtype=input1.dtype, device=torch.device('cuda'))
if (input1.dtype == torch.float16) and (input2.dtype == torch.float16):
output = strided_batched_gemm.strided_batched_gemm(0.0, output, 1.0, input1, input2)
else:
output = torch.bmm(input1, input2, out=output)
return output.detach()
@staticmethod
@amp.custom_bwd
def backward(ctx, grad_output):
input1, input2 = ctx.saved_tensors
grad_input1 = torch.empty((input1.size(1), input2.size(0), input1.size(2)),
dtype=input1.dtype, device=torch.device('cuda')).transpose(1, 0)
grad_input2 = torch.empty((input2.size(2), input2.size(0), input2.size(1)),
dtype=input2.dtype, device=torch.device('cuda')).transpose(1, 0)
if (grad_output.dtype == torch.float16) and (input1.dtype == torch.float16) and (input2.dtype == torch.float16):
grad_input1 = strided_batched_gemm.strided_batched_gemm(0.0, grad_input1,
1.0, grad_output,
input2.transpose(1, 2))
grad_input2 = strided_batched_gemm.strided_batched_gemm(0.0, grad_input2,
1.0, grad_output.transpose(1, 2),
input1)
grad_input2 = grad_input2.transpose(1, 2)
else:
grad_input1 = torch.bmm(grad_output, input2.transpose(1, 2), out=grad_input1)
grad_input2 = torch.bmm(grad_output.transpose(1, 2), input1, out=grad_input2).transpose(1, 2)
return grad_input1, grad_input2
class StridedBmm2Func(torch.autograd.Function):
@staticmethod
@amp.custom_fwd(cast_inputs=torch.half)
def forward(ctx, input1, input2):
ctx.save_for_backward(input1, input2)
output = torch.empty((input1.size(1), input1.size(0), input2.size(2)),
dtype=input1.dtype, device=torch.device('cuda')).transpose(1, 0)
if (input1.dtype == torch.float16) and (input2.dtype == torch.float16):
output = strided_batched_gemm.strided_batched_gemm(0.0, output, 1.0, input1, input2)
else:
output = torch.bmm(input1, input2, out=output)
return output.detach()
@staticmethod
@amp.custom_bwd
def backward(ctx, grad_output):
input1, input2 = ctx.saved_tensors
grad_input2 = torch.empty((input2.size(1), input2.size(0), input2.size(2)),
dtype=input2.dtype, device=torch.device('cuda')).transpose(1, 0)
grad_input1 = torch.empty((input1.size(0), input1.size(1), input1.size(2)),
dtype=input2.dtype, device=torch.device('cuda'))
if (grad_output.dtype == torch.float16) and (input1.dtype == torch.float16) and (input2.dtype == torch.float16):
grad_input1 = strided_batched_gemm.strided_batched_gemm(0.0, grad_input1,
1.0, grad_output,
input2.transpose(1, 2))
grad_input2 = strided_batched_gemm.strided_batched_gemm(0.0, grad_input2,
1.0, input1.transpose(1, 2),
grad_output)
else:
grad_input1 = torch.bmm(grad_output, input2.transpose(1, 2))
grad_input2 = torch.bmm(input1.transpose(1, 2), grad_output, out=grad_input2)
return grad_input1, grad_input2
def query_linear(input: Tensor, weights_q: Tensor, scale: float):
if not torch.jit.is_scripting():
return QueryLinear.apply(input, weights_q, scale)
else:
q = scale * torch.einsum('ij,jk->ik', input.view(input.size(0) * input.size(1), -1), weights_q)
q = q.view(input.shape)
return q
def key_value_linears(input: Tensor, weights_k: Tensor, weights_v: Tensor):
if not torch.jit.is_scripting():
return KeyValueLinears.apply(input, weights_k, weights_v)
else:
k = torch.einsum('ij,jk->ik', input.view(input.size(0) * input.size(1), -1), weights_k)
k = k.view(input.shape)
v = torch.einsum('ij,jk->ik', input.view(input.size(0) * input.size(1), -1), weights_v)
v = v.view(input.shape)
return k, v
def self_attn_linears(input: Tensor, weights_q: Tensor, weights_k: Tensor, weights_v: Tensor, scale: float):
if not torch.jit.is_scripting():
return SelfAttentionLinears.apply(input, weights_q, weights_k, weights_v, scale)
else:
q = scale * torch.einsum('ij,jk->ik', input.view(input.size(0) * input.size(1), -1), weights_q)
q = q.view(input.shape)
k = torch.einsum('ij,jk->ik', input.view(input.size(0) * input.size(1), -1), weights_k)
k = k.view(input.shape)
v = torch.einsum('ij,jk->ik', input.view(input.size(0) * input.size(1), -1), weights_v)
v = v.view(input.shape)
return q, k, v
def strided_bmm1(input1: Tensor, input2: Tensor):
if not torch.jit.is_scripting():
return StridedBmm1Func.apply(input1, input2)
else:
return torch.einsum('ijk,ikn->ijn', input1, input2)
def strided_bmm2(input1: Tensor, input2: Tensor):
if not torch.jit.is_scripting():
return StridedBmm2Func.apply(input1, input2)
else:
return torch.einsum('ijk,ikn->ijn', input1, input2)
class MultiheadAttention(nn.Module):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def __init__(self, embed_dim, num_heads, dropout=0., bias=False):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim**-0.5
self._mask = torch.empty(0)
#self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim))
self.in_proj_weight_q = Parameter(torch.Tensor(embed_dim, embed_dim))
self.in_proj_weight_k = Parameter(torch.Tensor(embed_dim, embed_dim))
self.in_proj_weight_v = Parameter(torch.Tensor(embed_dim, embed_dim))
if bias:
#self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))
self.in_proj_bias_q = Parameter(torch.Tensor(embed_dim))
self.in_proj_bias_k = Parameter(torch.Tensor(embed_dim))
self.in_proj_bias_v = Parameter(torch.Tensor(embed_dim))
else:
#self.register_parameter('in_proj_bias', None)
self.register_parameter('in_proj_bias_k', None)
self.register_parameter('in_proj_bias_q', None)
self.register_parameter('in_proj_bias_v', None)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.cache_id = str(id(self))
self.reset_parameters()
def reset_parameters(self):
#nn.init.xavier_uniform_(self.in_proj_weight)
nn.init.xavier_uniform_(self.in_proj_weight_q)
nn.init.xavier_uniform_(self.in_proj_weight_k)
nn.init.xavier_uniform_(self.in_proj_weight_v)
nn.init.xavier_uniform_(self.out_proj.weight)
if self.in_proj_bias_k is not None:
#nn.init.constant_(self.in_proj_bias, 0.)
nn.init.constant_(self.in_proj_bias_q, 0.)
nn.init.constant_(self.in_proj_bias_k, 0.)
nn.init.constant_(self.in_proj_bias_v, 0.)
nn.init.constant_(self.out_proj.bias, 0.)
def forward(self, query: Tensor, key: Tensor, value: Tensor,
mask_future_timesteps: bool,
key_padding_mask: Optional[Tensor],
incremental_state: Optional[Dict[str, Dict[str, Tensor]]],
need_weights: bool,
static_kv: bool):
"""Input shape: Time x Batch x Channel
Self-attention can be implemented by passing in the same arguments for
query, key and value. Future timesteps can be masked with the
`mask_future_timesteps` argument. Padding elements can be excluded from
the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
batch x src_len, where padding elements are indicated by 1s.
"""
if torch.jit.is_scripting():
kv_same = torch.equal(key, value)
qkv_same = torch.equal(query, value) and kv_same
else:
qkv_same, kv_same = self._fast_same_check(query, key, value)
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
assert key.size() == value.size()
k = v = query.new_empty(0)
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
else:
saved_state = None
if qkv_same:
# self-attention
q, k, v = self_attn_linears(query, self.in_proj_weight_q,
self.in_proj_weight_k, self.in_proj_weight_v, self.scaling)
elif kv_same:
# encoder-decoder attention
q = query_linear(query, self.in_proj_weight_q, self.scaling)
if not(saved_state is not None and 'prev_key' in saved_state and static_kv):
k, v = key_value_linears(key, self.in_proj_weight_k, self.in_proj_weight_v)
else:
q = torch.addmm(query.view(query.size(0) * query.size(1),
query.size(2)), query.view(query.size(0) * query.size(1),
query.size(2)), self.in_proj_weight_q, beta=0.0, alpha=self.scaling)
if not(saved_state is not None and 'prev_key' in saved_state and static_kv):
k = F.linear(key, self.in_proj_weight_k, self.in_proj_bias_k)
v = F.linear(value, self.in_proj_weight_v, self.in_proj_bias_v)
if saved_state is not None:
if 'prev_key' in saved_state:
k = torch.cat((saved_state['prev_key'], k), dim=0)
if 'prev_value' in saved_state:
v = torch.cat((saved_state['prev_value'], v), dim=0)
saved_state['prev_key'] = k
saved_state['prev_value'] = v
self._set_input_buffer(incremental_state, saved_state)
src_len = k.size(0)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
k = k.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
v = v.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
attn_weights = strided_bmm1(q, k.transpose(1, 2))
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
# only apply masking at training time (when incremental state is None)
if mask_future_timesteps and incremental_state is None:
assert query.size() == key.size(), \
'mask_future_timesteps only applies to self-attention'
attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.float().masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float('-inf'),
).type_as(attn_weights) # FP16 support: cast to float and back
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
attn = strided_bmm2(attn_weights, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
if need_weights:
# average attention weights over heads
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.sum(dim=1) / self.num_heads
else:
attn_weights = attn_weights.new_empty(0) # Can't set to None because jit script reasons
return attn, attn_weights
def in_proj_qkv(self, query):
return self._in_proj(query).chunk(3, dim=-1)
def in_proj_kv(self, key):
return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)
def in_proj_q(self, query):
return self._in_proj(query, end=self.embed_dim)
def in_proj_k(self, key):
return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
def in_proj_v(self, value):
return self._in_proj(value, start=2 * self.embed_dim)
def _in_proj(self, input, start=None, end=None):
weight = self.in_proj_weight
bias = self.in_proj_bias
if end is not None:
weight = weight[:end, :]
if bias is not None:
bias = bias[:end]
if start is not None:
weight = weight[start:, :]
if bias is not None:
bias = bias[start:]
return F.linear(input, weight, bias)
def buffered_mask(self, tensor):
dim = tensor.size(-1)
if self._mask.size(0) == 0:
#TODO: try torch.new_full instead
self._mask = torch.triu(utils.fill_with_neg_inf(tensor.new_empty(dim, dim)), 1)
if self._mask.size(0) < dim:
self._mask = torch.triu(utils.fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
return self._mask[:dim, :dim]
def reorder_incremental_state(self, incremental_state, new_order):
"""Reorder buffered internal state (for incremental generation)."""
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is not None:
for k in input_buffer.keys():
input_buffer[k] = input_buffer[k].index_select(1, new_order)
self._set_input_buffer(incremental_state, input_buffer)
def _get_input_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Tensor]]]):
if incremental_state is None or self.cache_id not in incremental_state:
return {}
return incremental_state[self.cache_id]
def _set_input_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Tensor]]],
buffer: Dict[str, Tensor]):
if incremental_state is not None:
incremental_state[self.cache_id] = buffer
@torch.jit.unused
def _fast_same_check(self, q, k, v):
qkv_same = q.data_ptr() == k.data_ptr() == v.data_ptr()
kv_same = k.data_ptr() == v.data_ptr()
return qkv_same, kv_same
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <torch/torch.h>
#include <vector>
at::Tensor strided_batched_gemm_cuda(
float beta,
at::Tensor in_result,
float alpha,
at::Tensor batch1,
at::Tensor batch2);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
at::Tensor strided_batched_gemm(
float beta,
at::Tensor in_result,
float alpha,
at::Tensor batch1,
at::Tensor batch2) {
//CHECK_INPUT(in_result);
//CHECK_INPUT(batch1);
//CHECK_INPUT(batch2);
AT_ASSERTM(in_result.dim() == 3, "expected 3D tensor");
AT_ASSERTM(batch1.dim() == 3, "expected 3D tensor");
AT_ASSERTM(batch2.dim() == 3, "expected 3D tensor");
AT_ASSERTM(in_result.size(0) == batch1.size(0), "equal number of batches expected");
AT_ASSERTM(in_result.size(0) == batch2.size(0), "equal number of batches expected");
AT_ASSERTM(in_result.size(1) == batch1.size(1), "wrong matrix size");
AT_ASSERTM(in_result.size(2) == batch2.size(2), "wrong matrix size");
AT_ASSERTM(batch1.size(2) == batch2.size(1), "wrong matrix size");
AT_ASSERTM(batch1.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(batch2.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(in_result.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
return strided_batched_gemm_cuda(beta, in_result, alpha, batch1, batch2);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("strided_batched_gemm", &strided_batched_gemm, "Special strided batched gemm.");
}
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include <iostream>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "THC/THC.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/wmma_gemm_traits.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
cublasOperation_t convertTransToCublasOperation(char trans) {
if (trans == 't') return CUBLAS_OP_T;
else if (trans == 'n') return CUBLAS_OP_N;
else if (trans == 'c') return CUBLAS_OP_C;
else {
THError("trans must be one of: t, n, c");
return CUBLAS_OP_T;
}
}
void CublasGemm(THCState *state, char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, long batchCount) {
cublasOperation_t opa = convertTransToCublasOperation(transa);
cublasOperation_t opb = convertTransToCublasOperation(transb);
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
//cublasSetStream(handle, THCState_getCurrentStream(state));
float fAlpha = alpha;
float fBeta = beta;
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
THCublasCheck(cublasGemmStridedBatchedEx(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA,
b, CUDA_R_16F, (int)ldb, strideB,
(void*)&fBeta, c, CUDA_R_16F, (int)ldc, strideC,
(int)batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}
template<cutlass::MatrixLayout::Kind A_LAYOUT, cutlass::MatrixLayout::Kind B_LAYOUT, int SRC_A, int SRC_B, int DST_C>
void CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, long batchCount) {
//printf("CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC: %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f\n", ((int)A_LAYOUT == 0 ? 'T' : 'N'), ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k, SRC_A,SRC_B,DST_C, lda, ldb, ldc, strideA, strideB, strideC, alpha, beta);
typedef cutlass::gemm::WmmaGemmTraits<
A_LAYOUT,
B_LAYOUT,
cutlass::Shape<32, 16, 16>,
half,
half,
half,
cutlass::gemm::LinearScaling<float>,
float,
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
typename cutlass::Shape<16, 16, 16>,
SRC_A, //kScalarsPerLdgA_
SRC_B, //kScalarsPerLdgB_
SRC_A, //KScalarsPerLdsA_
SRC_B, //KScalarsPerLdsB_
DST_C, //kScalarsPerLdgCAndStgD_
DST_C/2, //kScalarsPerStsD_
DST_C/2 //kScalarsPerLdsD_
>
WmmaGemmTraits;
typedef cutlass::gemm::Gemm<WmmaGemmTraits> Gemm;
typename Gemm::Params params;
int result = params.initialize(
m, // M dimension for each batch
n, // N dimension for each batch
k, // K dimension for each batch
alpha, // scalar alpha
a,
lda,
strideA, // distance in memory between the first element of neighboring batch
b,
ldb,
strideB, // distance in memory between the first element of neighboring batch
beta, // scalar beta
c, // source matrix C
ldc,
strideC, // distance in memory between the first element of neighboring batch
c, // destination matrix C (may be different memory than source C matrix)
ldc,
strideC, // distance in memory between the first element of neighboring batch
batchCount
);
AT_ASSERTM(result == 0, "Failed to initialize CUTLASS Gemm::Params object.");
// Launch the CUTLASS GEMM kernel.
THCudaCheck(Gemm::launch(params));
}
void gemm_switch_fp32accum(THCState *state, char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, long batchCount) {
//cudaStream_t stream = THCState_getCurrentStream(state);
//printf("GEMM -> %c%c M: %i N: %i K: %i Alpha: %f Beta: %f\n", (transa == 't' ? 'T' : 'N'), (transb =='t' ? 'T' : 'N'), m, n, k, alpha, beta);
auto stream = c10::cuda::getCurrentCUDAStream();
if ( (transa == 't') && (transb == 'n') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else { CublasGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
} else if ( (transa == 'n') && (transb == 'n') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else { CublasGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
} else if ( (transa == 'n') && (transb == 't') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
else { CublasGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
} else {
AT_ASSERTM(false, "TransA and TransB are invalid");
}
}
void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t *lda, int64_t *ldb, int64_t *ldc)
{
int transa_ = ((transa == 't') || (transa == 'T'));
int transb_ = ((transb == 't') || (transb == 'T'));
// Note: leading dimensions generally are checked that they are > 0 and at least as big the result
// requires (even if the value won't be used).
if(n <= 1)
*ldc = std::max<int64_t>(m, 1);
if(transa_)
{
if(m <= 1)
*lda = std::max<int64_t>(k, 1);
}
else
{
if(k <= 1)
*lda = std::max<int64_t>(m, 1);
}
if(transb_)
{
if(k <= 1)
*ldb = std::max<int64_t>(n, 1);
}
else
{
if(n <= 1)
*ldb = std::max<int64_t>(k, 1);
}
}
void HgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, long batchCount)
{
if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
{
THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount"
"with the bound [val] <= %d", INT_MAX);
}
adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
//gemm_switch(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
gemm_switch_fp32accum(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
}
at::Tensor strided_batched_gemm_cuda(
float beta,
at::Tensor in_result,
float alpha,
at::Tensor batch1,
at::Tensor batch2) {
bool transpose_result;
char transpose_batch1, transpose_batch2;
int64_t lda, ldb, ldc;
at::Tensor result, input1, input2;
if (in_result.stride(1) == 1)
{
transpose_result = false;
result = in_result;
ldc = result.stride(2);
}
else if (in_result.stride(2) == 1)
{
transpose_result = true;
at::Tensor swap = batch2;
batch2 = batch1;
batch1 = swap;
result = in_result;
ldc = result.stride(1);
} else {
AT_ASSERTM(false, "result should be contiguous");
}
if (batch1.stride(transpose_result ? 2 : 1) == 1 &&
batch1.stride(transpose_result ? 1 : 2) != 0) {
transpose_batch1 = 'n';
input1 = batch1;
lda = input1.stride(transpose_result ? 1 : 2);
} else if (batch1.stride(transpose_result ? 1 : 2) == 1 &&
batch1.stride(transpose_result ? 2 : 1) != 0) {
transpose_batch1 = 't';
input1 = batch1;
lda = input1.stride(transpose_result ? 2 : 1);
} else {
AT_ASSERTM(false, "input1 should be contiguous");
}
if (batch2.stride(transpose_result ? 2 : 1) == 1 &&
batch2.stride(transpose_result ? 1 : 2) != 0) {
transpose_batch2 = 'n';
input2 = batch2;
ldb = input2.stride(transpose_result ? 1 : 2);
} else if (batch2.stride(transpose_result ? 1 : 2) == 1 &&
batch2.stride(transpose_result ? 2 : 1) != 0) {
transpose_batch2 = 't';
input2 = batch2;
ldb = input2.stride(transpose_result ? 2 : 1);
} else {
AT_ASSERTM(false, "input2 should be contiguous");
}
int64_t num_batches = result.size(0);
HgemmStridedBatched(
state,
transpose_batch1,
transpose_batch2,
result.size(transpose_result ? 2 : 1),
result.size(transpose_result ? 1 : 2),
input1.size(transpose_result ? 1 : 2),
alpha,
static_cast<const half*>(input1.data_ptr()), lda, input1.stride(0),
static_cast<const half*>(input2.data_ptr()), ldb, input2.stride(0),
beta,
static_cast<half*>(result.data_ptr()), ldc, result.stride(0),
num_batches);
return in_result;
}
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import importlib
import os
from .fairseq_optimizer import FairseqOptimizer
OPTIMIZER_REGISTRY = {}
OPTIMIZER_CLASS_NAMES = set()
def build_optimizer(args, params):
params = filter(lambda p: p.requires_grad, params)
return OPTIMIZER_REGISTRY[args.optimizer](args, params)
def register_optimizer(name):
"""Decorator to register a new optimizer."""
def register_optimizer_cls(cls):
if name in OPTIMIZER_REGISTRY:
raise ValueError('Cannot register duplicate optimizer ({})'.format(name))
if not issubclass(cls, FairseqOptimizer):
raise ValueError('Optimizer ({}: {}) must extend FairseqOptimizer'.format(name, cls.__name__))
if cls.__name__ in OPTIMIZER_CLASS_NAMES:
# We use the optimizer class name as a unique identifier in
# checkpoints, so all optimizer must have unique class names.
raise ValueError('Cannot register optimizer with duplicate class name ({})'.format(cls.__name__))
OPTIMIZER_REGISTRY[name] = cls
OPTIMIZER_CLASS_NAMES.add(cls.__name__)
return cls
return register_optimizer_cls
# automatically import any Python files in the optim/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('fairseq.optim.' + module)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import FairseqOptimizer, register_optimizer
from apex.optimizers.fused_adam import FusedAdam
@register_optimizer('adam')
class FairseqAdam(FairseqOptimizer):
def __init__(self, args, params):
super().__init__(args, params)
self._optimizer = FusedAdam(params, **self.optimizer_config)
@staticmethod
def add_args(parser):
"""Add optimizer-specific arguments to the parser."""
parser.add_argument('--adam-betas', default=(0.9, 0.999), nargs=2, type=float, metavar='B1 B2',
help='betas for Adam optimizer')
parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D',
help='epsilon for Adam optimizer')
@property
def optimizer_config(self):
"""
Return a kwarg dictionary that will be used to override optimizer
args stored in checkpoints. This allows us to load a checkpoint and
resume training using a different set of optimizer args, e.g., with a
different learning rate.
"""
return {
'lr': self.args.lr[0],
'betas': self.args.adam_betas,
'eps': self.args.adam_eps,
'weight_decay': self.args.weight_decay,
}
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.optim
class FairseqOptimizer(object):
def __init__(self, args, params):
super().__init__()
self.args = args
self.params = params
@staticmethod
def add_args(parser):
"""Add optimizer-specific arguments to the parser."""
pass
@property
def optimizer(self):
"""Return a torch.optim.optimizer.Optimizer instance."""
if not hasattr(self, '_optimizer'):
raise NotImplementedError
if not isinstance(self._optimizer, torch.optim.Optimizer):
raise ValueError('_optimizer must be an instance of torch.optim.Optimizer')
return self._optimizer
@property
def optimizer_config(self):
"""
Return a kwarg dictionary that will be used to override optimizer
args stored in checkpoints. This allows us to load a checkpoint and
resume training using a different set of optimizer args, e.g., with a
different learning rate.
"""
raise NotImplementedError
def get_lr(self):
"""Return the current learning rate."""
return self.optimizer.param_groups[0]['lr']
def set_lr(self, lr):
"""Set the learning rate."""
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
def state_dict(self):
"""Return the optimizer's state dict."""
return self.optimizer.state_dict()
def load_state_dict(self, state_dict):
"""Load an optimizer state dict.
In general we should prefer the configuration of the existing optimizer
instance (e.g., learning rate) over that found in the state_dict. This
allows us to resume training from a checkpoint using a new set of
optimizer args.
"""
self.optimizer.load_state_dict(state_dict)
# override learning rate, momentum, etc. with latest values
for group in self.optimizer.param_groups:
group.update(self.optimizer_config)
def step(self, closure=None):
"""Performs a single optimization step."""
return self.optimizer.step(closure)
def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
for group in self.optimizer.param_groups:
for p in group['params']:
p.grad = None
return self.optimizer.zero_grad()
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import importlib
import os
from .fairseq_lr_scheduler import FairseqLRScheduler
LR_SCHEDULER_REGISTRY = {}
def build_lr_scheduler(args, optimizer):
return LR_SCHEDULER_REGISTRY[args.lr_scheduler](args, optimizer)
def register_lr_scheduler(name):
"""Decorator to register a new LR scheduler."""
def register_lr_scheduler_cls(cls):
if name in LR_SCHEDULER_REGISTRY:
raise ValueError('Cannot register duplicate LR scheduler ({})'.format(name))
if not issubclass(cls, FairseqLRScheduler):
raise ValueError('LR Scheduler ({}: {}) must extend FairseqLRScheduler'.format(name, cls.__name__))
LR_SCHEDULER_REGISTRY[name] = cls
return cls
return register_lr_scheduler_cls
# automatically import any Python files in the optim/lr_scheduler/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('fairseq.optim.lr_scheduler.' + module)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from . import FairseqLRScheduler, register_lr_scheduler
@register_lr_scheduler('fixed')
class FixedSchedule(FairseqLRScheduler):
"""Decay the LR on a fixed schedule."""
def __init__(self, args, optimizer):
super().__init__(args, optimizer)
# set defaults
args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0
self.lr = args.lr[0]
if args.warmup_updates > 0:
self.warmup_factor = 1. / args.warmup_updates
else:
self.warmup_factor = 1
@staticmethod
def add_args(parser):
"""Add arguments to the parser for this LR scheduler."""
parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
help='force annealing at specified epoch')
parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
help='warmup the learning rate linearly for the first N updates')
def get_next_lr(self, epoch):
lrs = self.args.lr
if self.args.force_anneal is None or epoch < self.args.force_anneal:
# use fixed LR schedule
next_lr = lrs[min(epoch, len(lrs) - 1)]
else:
# annneal based on lr_shrink
next_lr = lrs[-1] * self.args.lr_shrink ** (epoch + 1 - self.args.force_anneal)
return next_lr
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
super().step(epoch, val_loss)
self.lr = self.get_next_lr(epoch)
self.optimizer.set_lr(self.warmup_factor * self.lr)
return self.optimizer.get_lr()
def step_update(self, num_updates):
"""Update the learning rate after each update."""
if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates:
self.warmup_factor = num_updates / float(self.args.warmup_updates)
self.optimizer.set_lr(self.warmup_factor * self.lr)
return self.optimizer.get_lr()
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.optim.lr_scheduler
from . import FairseqLRScheduler, register_lr_scheduler
@register_lr_scheduler('reduce_lr_on_plateau')
class ReduceLROnPlateau(FairseqLRScheduler):
"""Decay the LR by a factor every time the validation loss plateaus."""
def __init__(self, args, optimizer):
super().__init__(args, optimizer)
if len(args.lr) > 1:
raise ValueError(
'Cannot use a fixed learning rate schedule with reduce_lr_on_plateau.'
' Consider --lr-scheduler=fixed instead.'
)
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer.optimizer, patience=0, factor=args.lr_shrink)
def state_dict(self):
"""Return the LR scheduler state dict."""
return {
'best': self.lr_scheduler.best,
'last_epoch': self.lr_scheduler.last_epoch,
}
def load_state_dict(self, state_dict):
"""Load an LR scheduler state dict."""
self.lr_scheduler.best = state_dict['best']
if 'last_epoch' in state_dict:
self.lr_scheduler.last_epoch = state_dict['last_epoch']
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
if val_loss is not None:
self.lr_scheduler.step(val_loss, epoch)
else:
self.lr_scheduler.last_epoch = epoch
return self.optimizer.get_lr()
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import torch
from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
from fairseq.criterions import CRITERION_REGISTRY
from fairseq.optim import OPTIMIZER_REGISTRY
from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
def get_training_parser():
parser = get_parser('Trainer')
add_dataset_args(parser, train=True, gen=True)
add_distributed_training_args(parser)
add_model_args(parser)
add_optimization_args(parser)
add_checkpoint_args(parser)
add_inference_args(parser)
add_perf_args(parser)
return parser
def get_inference_parser():
parser = get_parser('Generation')
add_dataset_args(parser, gen=True)
add_inference_args(parser)
add_perf_args(parser)
return parser
def parse_args_and_arch(parser, input_args=None, parse_known=False):
# The parser doesn't know about model/criterion/optimizer-specific args, so
# we parse twice. First we parse the model/criterion/optimizer, then we
# parse a second time after adding the *-specific arguments.
# If input_args is given, we will parse those args instead of sys.argv.
args, _ = parser.parse_known_args(input_args)
# Add model-specific args to parser.
if hasattr(args, 'arch'):
model_specific_group = parser.add_argument_group(
'Model-specific configuration',
# Only include attributes which are explicitly given as command-line
# arguments or which have default values.
argument_default=argparse.SUPPRESS,
)
ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)
# Add *-specific args to parser.
if hasattr(args, 'optimizer'):
OPTIMIZER_REGISTRY[args.optimizer].add_args(parser)
if hasattr(args, 'lr_scheduler'):
LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)
# Parse a second time.
if parse_known:
args, extra = parser.parse_known_args(input_args)
else:
args = parser.parse_args(input_args)
extra = None
# Post-process args.
if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
args.max_sentences_valid = args.max_sentences
args.max_positions = (args.max_source_positions, args.max_target_positions)
if hasattr(args, 'target_bleu') and (args.online_eval or args.target_bleu) and not args.remove_bpe:
args.remove_bpe = '@@ '
# Apply architecture configuration.
if hasattr(args, 'arch'):
ARCH_CONFIG_REGISTRY[args.arch](args)
if parse_known:
return args, extra
else:
return args
def get_parser(desc):
parser = argparse.ArgumentParser(
description='Facebook AI Research Sequence-to-Sequence Toolkit -- ' + desc)
parser.add_argument('--log-interval', type=int, default=500, metavar='N',
help='print aggregated stats and flush json log every N iteration')
parser.add_argument('--seed', default=1, type=int, metavar='N',
help='pseudo random number generator seed')
parser.add_argument('--amp', action='store_true',
help='use Automatic Mixed Precision')
parser.add_argument('--stat-file', type=str, default='run_log.json',
help='Name of the file containing DLLogger output')
parser.add_argument('--save-dir', metavar='DIR', default='results',
help='path to save checkpoints and logs')
parser.add_argument('--do-sanity-check', action='store_true',
help='Perform evaluation on test set before running the training')
return parser
def add_dataset_args(parser, train=False, gen=False):
group = parser.add_argument_group('Dataset and data loading')
group.add_argument('--max-tokens', type=int, metavar='N',
help='maximum number of tokens in a batch')
group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
help='maximum number of sentences in a batch')
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
help='source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='target language')
parser.add_argument('--raw-text', action='store_true',
help='load raw text dataset')
parser.add_argument('--left-pad-source', default=True, type=bool, metavar='BOOL',
help='pad the source on the left (default: True)')
parser.add_argument('--left-pad-target', default=False, type=bool, metavar='BOOL',
help='pad the target on the left (default: False)')
parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
parser.add_argument('--pad-sequence', default=1, type=int, metavar='N',
help='Pad sequences to a multiple of N')
if train:
parser.add_argument('data', metavar='DIR', help='path to data directory')
group.add_argument('--train-subset', default='train', metavar='SPLIT',
choices=['train', 'valid', 'test'],
help='data subset to use for training (train, valid, test)')
group.add_argument('--valid-subset', default='valid', metavar='SPLIT',
help='comma separated list of data subsets to use for validation'
' (train, valid, valid1, test, test1)')
group.add_argument('--max-sentences-valid', type=int, metavar='N',
help='maximum number of sentences in a validation batch'
' (defaults to --max-sentences)')
if gen:
group.add_argument('--gen-subset', default='test', metavar='SPLIT',
help='data subset to generate (train, valid, test)')
group.add_argument('--num-shards', default=1, type=int, metavar='N',
help='shard generation over N shards')
group.add_argument('--shard-id', default=0, type=int, metavar='ID',
help='id of the shard to generate (id < num_shards)')
return group
def add_distributed_training_args(parser):
group = parser.add_argument_group('Distributed training')
group.add_argument('--distributed-world-size', type=int, metavar='N',
default=torch.cuda.device_count(),
help='total number of GPUs across all nodes (default: all visible GPUs)')
group.add_argument('--distributed-rank', default=os.getenv('LOCAL_RANK', 0), type=int,
help='rank of the current worker')
group.add_argument('--local_rank', default=0, type=int,
help='rank of the current worker')
group.add_argument('--distributed-backend', default='nccl', type=str,
help='distributed backend')
group.add_argument('--distributed-init-method', default=None, type=str,
help='typically tcp://hostname:port that will be used to '
'establish initial connetion')
group.add_argument('--distributed-port', default=-1, type=int,
help='port number (not required if using --distributed-init-method)')
group.add_argument('--device-id', default=0, type=int,
help='which GPU to use (usually configured automatically)')
return group
def add_optimization_args(parser):
group = parser.add_argument_group('Optimization')
group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N',
help='force stop training at specified epoch')
group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N',
help='force stop training at specified update')
group.add_argument('--target-bleu', default=0.0, type=float, metavar='TARGET',
help='force stop training after reaching target bleu')
group.add_argument('--clip-norm', default=25, type=float, metavar='NORM',
help='clip threshold of gradients')
group.add_argument('--update-freq', default=[1], nargs='+', type=int,
help='update parameters every N_i batches, when in epoch i')
# Optimizer definitions can be found under fairseq/optim/
group.add_argument('--optimizer', default='nag', metavar='OPT',
choices=OPTIMIZER_REGISTRY.keys(),
help='optimizer: {} (default: nag)'.format(', '.join(OPTIMIZER_REGISTRY.keys())))
group.add_argument('--lr', '--learning-rate', default=[0.25], nargs='+', type=float,
help='learning rate for the first N epochs; all epochs >N using LR_N'
' (note: this may be interpreted differently depending on --lr-scheduler)')
group.add_argument('--momentum', default=0.99, type=float, metavar='M',
help='momentum factor')
group.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
help='weight decay')
# Learning rate schedulers can be found under fairseq/optim/lr_scheduler/
group.add_argument('--lr-scheduler', default='reduce_lr_on_plateau',
help='learning rate scheduler: {} (default: reduce_lr_on_plateau)'.format(
', '.join(LR_SCHEDULER_REGISTRY.keys())))
group.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
help='learning rate shrink factor for annealing, lr_new = (lr * lr_shrink)')
group.add_argument('--min-lr', default=1e-5, type=float, metavar='LR',
help='minimum learning rate')
# Criterion args
parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
help='epsilon for label smoothing, 0 means no label smoothing')
return group
def add_checkpoint_args(parser):
group = parser.add_argument_group('Checkpointing')
group.add_argument('--restore-file', default='checkpoint_last.pt',
help='filename in save-dir from which to load checkpoint')
group.add_argument('--save-interval', type=int, default=1, metavar='N',
help='save a checkpoint every N epochs')
group.add_argument('--no-save', action='store_true',
help='don\'t save models or checkpoints')
group.add_argument('--no-epoch-checkpoints', action='store_true',
help='only store last and best checkpoints')
group.add_argument('--validate-interval', type=int, default=1, metavar='N',
help='validate every N epochs')
return group
def add_common_eval_args(group):
group.add_argument('--path', metavar='FILE',
help='path(s) to model file(s), colon separated')
group.add_argument('--file', metavar='FILE', default=None, type=str,
help='path to a file with input data for inference')
group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
help='remove BPE tokens before scoring')
group.add_argument('--cpu', action='store_true', help='generate on CPU')
group.add_argument('--quiet', action='store_true',
help='only print final scores')
def add_inference_args(parser):
group = parser.add_argument_group('Generation')
add_common_eval_args(group)
group.add_argument('--beam', default=4, type=int, metavar='N',
help='beam size')
group.add_argument('--nbest', default=1, type=int, metavar='N',
help='number of hypotheses to output')
group.add_argument('--max-len-a', default=0, type=float, metavar='N',
help=('generate sequences of maximum length ax + b, '
'where x is the source length'))
group.add_argument('--max-len-b', default=200, type=int, metavar='N',
help=('generate sequences of maximum length ax + b, '
'where x is the source length'))
group.add_argument('--min-len', default=1, type=float, metavar='N',
help=('minimum generation length'))
group.add_argument('--no-early-stop', action='store_true',
help=('continue searching even after finalizing k=beam '
'hypotheses; this is more correct, but increases '
'generation time by 50%%'))
group.add_argument('--unnormalized', action='store_true',
help='compare unnormalized hypothesis scores')
group.add_argument('--no-beamable-mm', action='store_true',
help='don\'t use BeamableMM in attention layers')
group.add_argument('--lenpen', default=1, type=float,
help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
group.add_argument('--unkpen', default=0, type=float,
help='unknown word penalty: <0 produces more unks, >0 produces fewer')
group.add_argument('--replace-unk', nargs='?', const=True, default=None,
help='perform unknown replacement (optionally with alignment dictionary)')
group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
help='initialize generation by target prefix of given length')
group.add_argument('--sampling', action='store_true',
help='sample hypotheses instead of using beam search')
group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS',
help='sample from top K likely next words instead of all words')
group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
help='temperature for random sampling')
group.add_argument('--print-alignment', action='store_true',
help='if set, uses attention feedback to compute and print alignment to source tokens')
group.add_argument('--online-eval', action='store_true',
help='score model at the end of epoch')
group.add_argument('--save-predictions', action='store_true',
help='Save predictions produced with online evaluation')
group.add_argument('--test-cased-bleu', action='store_true',
help='Use cased bleu for online eval')
group.add_argument('--bpe-codes', default=None, type=str, metavar='CODES',
help='file with bpe codes')
group.add_argument('--buffer-size', default=64, type=int, metavar='N',
help='read this many sentences into a buffer before processing them')
group.add_argument('--fp16', action='store_true', help='use fp16 precision')
return group
def add_model_args(parser):
group = parser.add_argument_group('Model configuration')
# Model definitions can be found under fairseq/models/
#
# The model architecture can be specified in several ways.
# In increasing order of priority:
# 1) model defaults (lowest priority)
# 2) --arch argument
group.add_argument(
'--arch', '-a', default='fconv', metavar='ARCH', required=True,
choices=ARCH_MODEL_REGISTRY.keys(),
help='model architecture: {} (default: fconv)'.format(
', '.join(ARCH_MODEL_REGISTRY.keys())),
)
# Criterion definitions can be found under fairseq/criterions/
group.add_argument(
'--criterion', default='cross_entropy', metavar='CRIT',
choices=CRITERION_REGISTRY.keys(),
help='training criterion: {} (default: cross_entropy)'.format(
', '.join(CRITERION_REGISTRY.keys())),
)
return group
def add_perf_args(parser):
group = parser.add_argument_group('Performance')
group.add_argument('--fuse-dropout-add', action='store_true',
help='Fuse dropout and residual adds.')
group.add_argument('--fuse-relu-dropout', action='store_true',
help='Fuse Relu and Dropout.')
group.add_argument('--fuse-layer-norm', action='store_true',
help='Use APEX\'s FusedLayerNorm instead of torch.nn.LayerNorm')
return group
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment