Commit 7df61696 authored by Sugon_ldc

add fairseq0.10.2
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .model import * # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import options, utils
from fairseq.modules import (
AdaptiveSoftmax,
LayerNorm,
MultiheadAttention,
PositionalEmbedding,
)
EncoderOut = namedtuple(
"TransformerEncoderOut",
[
"encoder_out", # T x B x C
"encoder_padding_mask", # B x T
"encoder_embedding", # B x T x C
"encoder_states", # List[T x B x C]
],
)
class TransformerEncoderEmbedding(nn.Module):
""" Encoder Embedding + Positional Embedding """
def __init__(self, args, embed_tokens):
super().__init__()
self.dropout = args.dropout
self.max_source_positions = args.max_source_positions
self.embed_tokens = embed_tokens
if isinstance(embed_tokens, nn.ModuleList):
self.padding_idx = embed_tokens[0].padding_idx
embed_dim = sum(e.embedding_dim for e in embed_tokens)
else:
self.padding_idx = embed_tokens.padding_idx
embed_dim = embed_tokens.embedding_dim
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = (
PositionalEmbedding(
args.max_source_positions,
embed_dim,
self.padding_idx,
learned=args.encoder_learned_pos,
)
if not args.no_token_positional_embeddings
else None
)
if getattr(args, "layernorm_embedding", False):
self.layernorm_embedding = LayerNorm(embed_dim)
else:
self.layernorm_embedding = None
def forward(self, input):
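# `input` is the (src_tokens, src_lengths, prev_output_tokens) tuple fed into the pipeline;
# src_lengths (input[1]) is not needed here, and prev_output_tokens is passed through unchanged
# so that later pipeline stages (the decoder embedding) can consume it.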
# embed tokens and positions
src_tokens = input[0]
prev_output_tokens = input[2]
if isinstance(self.embed_tokens, nn.ModuleList):
x_embed_list = []
for embed_tokens_part in self.embed_tokens:
x_embed_list.append(embed_tokens_part(src_tokens))
embedded = torch.cat(x_embed_list, dim=-1)
else:
embedded = self.embed_tokens(src_tokens)
x = embed = self.embed_scale * embedded
if self.embed_positions is not None:
x = embed + self.embed_positions(src_tokens)
if self.layernorm_embedding:
x = self.layernorm_embedding(x)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# compute padding mask
encoder_padding_mask = src_tokens.eq(self.padding_idx)
return (x, encoder_padding_mask, prev_output_tokens)
class TransformerEncoderLayerNorm(nn.Module):
"""
Layer norm at the end of all encoder layers if
args.encoder_normalize_before = True
"""
def __init__(self, args, embed_dim):
super().__init__()
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
def forward(self, input):
x = input[0]
encoder_padding_mask = input[1]
prev_output_tokens = input[2]
if self.layer_norm:
x = self.layer_norm(x)
# keeping track of the incremental_state is not supported yet
return (x, encoder_padding_mask, prev_output_tokens)
class TransformerDecoderEmbedding(nn.Module):
""" Decoder Embedding + Positional Embedding """
def __init__(self, args, embed_tokens):
super().__init__()
self.dropout = args.dropout
self.share_input_output_embed = args.share_decoder_input_output_embed
input_embed_dim = (
sum(e.embedding_dim for e in embed_tokens)
if isinstance(embed_tokens, nn.ModuleList)
else embed_tokens.embedding_dim
)
embed_dim = args.decoder_embed_dim
self.output_embed_dim = args.decoder_output_dim
padding_idx = (
embed_tokens[0].padding_idx
if isinstance(embed_tokens, nn.ModuleList)
else embed_tokens.padding_idx
)
self.max_target_positions = args.max_target_positions
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim
self.project_in_dim = (
Linear(input_embed_dim, embed_dim, bias=False)
if embed_dim != input_embed_dim
else None
)
self.embed_positions = (
PositionalEmbedding(
args.max_target_positions,
embed_dim,
padding_idx,
learned=args.decoder_learned_pos,
)
if not args.no_token_positional_embeddings
else None
)
def forward(self, input):
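# `input` is either the (encoder_out, encoder_padding_mask, prev_output_tokens) tuple produced
# by the encoder stages (machine translation), some other tuple whose first element is
# prev_output_tokens, or a bare prev_output_tokens tensor (e.g. language modeling).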
mt_task = False
if isinstance(input, tuple):
if len(input) == 3:
encoder_out = input[0]
encoder_padding_mask = input[1]
prev_output_tokens = input[2]
incremental_state = None # Hardcoding to avoid passing of None objects
mt_task = True
else:
# HACK for now, need to fix (TODO sidgoyal)
prev_output_tokens = input[0]
# discard "src_lengths"
encoder_out = None
encoder_padding_mask = None
incremental_state = None
else:
prev_output_tokens = input
encoder_out = None
encoder_padding_mask = None
incremental_state = None
positions = (
self.embed_positions(
prev_output_tokens,
incremental_state=incremental_state,
)
if self.embed_positions is not None
else None
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# embed tokens and positions
if isinstance(self.embed_tokens, nn.ModuleList):
x_embed_list = []
for embed_tokens_part in self.embed_tokens:
x_embed_list.append(embed_tokens_part(prev_output_tokens))
x = self.embed_scale * torch.cat(x_embed_list, dim=-1)
else:
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if positions is not None:
x += positions
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
if mt_task:
return (x, encoder_out, encoder_padding_mask)
return x
class TransformerDecoderOutputLayer(nn.Module):
def __init__(self, args, embed_tokens, dictionary):
super().__init__()
self.share_input_output_embed = args.share_decoder_input_output_embed
self.embed_tokens = embed_tokens
self.output_embed_dim = args.decoder_output_dim
embed_dim = args.decoder_embed_dim
self.project_out_dim = (
Linear(embed_dim, self.output_embed_dim, bias=False)
if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights
else None
)
self.adaptive_softmax = None
if args.adaptive_softmax_cutoff is not None:
assert not isinstance(embed_tokens, nn.ModuleList)
self.adaptive_softmax = AdaptiveSoftmax(
len(dictionary),
self.output_embed_dim,
options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
dropout=args.adaptive_softmax_dropout,
adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
factor=args.adaptive_softmax_factor,
tie_proj=args.tie_adaptive_proj,
)
elif not self.share_input_output_embed:
self.embed_tokens = nn.Parameter(
torch.Tensor(len(dictionary), self.output_embed_dim)
)
nn.init.normal_(
self.embed_tokens, mean=0, std=self.output_embed_dim ** -0.5
)
if args.decoder_normalize_before and not getattr(
args, "no_decoder_final_norm", False
):
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
def forward(self, input, apply_final_proj=True):
if isinstance(input, tuple):
x = input[0]
else:
x = input
if self.layer_norm:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
if apply_final_proj:
x = self.output_layer(x)
return x
def output_layer(self, features, **kwargs):
"""Project features to the vocabulary size."""
if self.adaptive_softmax is None:
# project back to size of vocabulary
if self.share_input_output_embed:
if isinstance(self.embed_tokens, nn.ModuleList):
output = None
for i, emb in enumerate(self.embed_tokens):
sidx = i * emb.embedding_dim
eidx = (i + 1) * emb.embedding_dim
if output is None:
output = F.linear(features[:, :, sidx:eidx], emb.weight)
else:
output += F.linear(features[:, :, sidx:eidx], emb.weight)
return output
else:
return F.linear(features, self.embed_tokens.weight)
else:
return F.linear(features, self.embed_tokens)
else:
return features
class TransformerEncoderLayer(nn.Module):
"""Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is
postprocessed with: `dropout -> add residual -> layernorm`. In the
tensor2tensor code they suggest that learning is more robust when
preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*args.encoder_normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
"""
def __init__(self, args):
super().__init__()
self.embed_dim = args.encoder_embed_dim
self.self_attn = MultiheadAttention(
self.embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
)
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout = args.dropout
self.activation_fn = utils.get_activation_fn(
activation=getattr(args, "activation_fn", "relu")
)
self.activation_dropout = getattr(args, "activation_dropout", 0)
if self.activation_dropout == 0:
# for backwards compatibility with models that use args.relu_dropout
self.activation_dropout = getattr(args, "relu_dropout", 0)
self.normalize_before = args.encoder_normalize_before
self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
self.final_layer_norm = LayerNorm(self.embed_dim)
def upgrade_state_dict_named(self, state_dict, name):
"""
Rename layer norm states from `...layer_norms.0.weight` to
`...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
`...final_layer_norm.weight`
"""
layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
for old, new in layer_norm_map.items():
for m in ("weight", "bias"):
k = "{}.layer_norms.{}.{}".format(name, old, m)
if k in state_dict:
state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
del state_dict[k]
def forward(self, input):
"""
Args:
input (Tuple):
input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
input[1] (ByteTensor/FloatTensor): encoder padding mask -
binary ByteTensor of shape `(batch, src_len)` where padding elements
are indicated by ``1``.
input[2] (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
Returns:
output (Tuple):
output[0] (Tensor): encoded output of shape `(seq_len, batch, embed_dim)`
output[1] (ByteTensor/FloatTensor): encoder padding mask
output[2] (LongTensor): previous decoder outputs
"""
x = input[0]
encoder_padding_mask = input[1]
prev_output_tokens = input[2]
residual = x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
x, _ = self.self_attn(
query=x, key=x, value=x, key_padding_mask=encoder_padding_mask
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
residual = x
x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
return (x, encoder_padding_mask, prev_output_tokens)
def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
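# Applies `layer_norm` either before the sublayer (pre-norm, normalize_before=True) or after
# it (post-norm, normalize_before=False); the XOR below selects exactly one of the two calls.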
assert before ^ after
if after ^ self.normalize_before:
return layer_norm(x)
else:
return x
class TransformerDecoderLayer(nn.Module):
"""Decoder layer block.
In the original paper each operation (multi-head attention, encoder
attention or FFN) is postprocessed with: `dropout -> add residual ->
layernorm`. In the tensor2tensor code they suggest that learning is more
robust when preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*args.decoder_normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(
self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
):
super().__init__()
self.embed_dim = args.decoder_embed_dim
self.self_attn = MultiheadAttention(
embed_dim=self.embed_dim,
num_heads=args.decoder_attention_heads,
dropout=args.attention_dropout,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
self_attention=True,
)
self.dropout = args.dropout
self.activation_fn = utils.get_activation_fn(
activation=getattr(args, "activation_fn", "relu")
)
self.activation_dropout = getattr(args, "activation_dropout", 0)
if self.activation_dropout == 0:
# for backwards compatibility with models that use args.relu_dropout
self.activation_dropout = getattr(args, "relu_dropout", 0)
self.normalize_before = args.decoder_normalize_before
# use layerNorm rather than FusedLayerNorm for exporting.
# char_inputs can be used to determine this.
# TODO remove this once we update apex with the fix
export = getattr(args, "char_inputs", False)
self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
if no_encoder_attn:
self.encoder_attn = None
self.encoder_attn_layer_norm = None
else:
self.encoder_attn = MultiheadAttention(
self.embed_dim,
args.decoder_attention_heads,
kdim=getattr(args, "encoder_embed_dim", None),
vdim=getattr(args, "encoder_embed_dim", None),
dropout=args.attention_dropout,
encoder_decoder_attention=True,
)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
self.need_attn = True
self.onnx_trace = False
def prepare_for_onnx_export_(self):
self.onnx_trace = True
def forward(self, input):
"""
Args:
input (Tuple):
input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
input[1] (Tensor): encoder output of shape `(batch, src_len, embed_dim)`
input[2] (ByteTensor/FloatTensor): encoder padding mask -
binary ByteTensor of shape `(batch, src_len)` where padding elements
are indicated by ``1``.
Returns:
output (Tuple):
output[0] (Tensor): decoded output of shape `(seq_len, batch, embed_dim)`
output[1] (Tensor): encoder output
output[2] (ByteTensor/FloatTensor): encoder padding mask
"""
# Note: incremental state is not yet supported
mt_task = False
if isinstance(input, tuple):
x = input[0]
encoder_out = input[1]
encoder_padding_mask = input[2]
incremental_state = None
mt_task = True
else:
x = input
encoder_out = None
encoder_padding_mask = None
incremental_state = None
if incremental_state is None:
self_attn_mask = self.buffered_future_mask(x)
else:
self_attn_mask = None
# TODO: add back prev_self_attn_state, prev_attn_state,
# self_attn_padding_mask
prev_self_attn_state = None
prev_attn_state = None
self_attn_padding_mask = None
residual = x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
if prev_self_attn_state is not None:
if incremental_state is None:
incremental_state = {}
prev_key, prev_value = prev_self_attn_state
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
self.self_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
attn_mask=self_attn_mask,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
if self.encoder_attn is not None:
residual = x
x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
if prev_attn_state is not None:
if incremental_state is None:
incremental_state = {}
prev_key, prev_value = prev_attn_state
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=(not self.training and self.need_attn),
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)
residual = x
x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
if mt_task:
return (x, encoder_out, encoder_padding_mask)
return x
def buffered_future_mask(self, tensor):
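# Builds (and caches) an upper-triangular matrix of -inf used as a causal mask, so that each
# target position can only attend to itself and earlier positions during self-attention.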
dim = tensor.size(0)
if (
not hasattr(self, "_future_mask")
or self._future_mask is None
or self._future_mask.device != tensor.device
):
self._future_mask = torch.triu(
utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
)
if self._future_mask.size(0) < dim:
self._future_mask = torch.triu(
utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
)
return self._future_mask[:dim, :dim]
def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
assert before ^ after
if after ^ self.normalize_before:
return layer_norm(x)
else:
return x
def make_generation_fast_(self, need_attn=False, **kwargs):
self.need_attn = need_attn
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
return m
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)
if bias:
nn.init.constant_(m.bias, 0.0)
return m
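# Illustrative sketch (not part of fairseq): how the tuple-in/tuple-out modules above are meant
# to be chained. Because every encoder stage consumes and returns a
# (x, encoder_padding_mask, prev_output_tokens) tuple, the whole encoder can be expressed as a
# plain nn.Sequential, which is what lets fairscale's Pipe partition it across devices.
# `args` and `embed_tokens` are assumed to come from a normal fairseq setup; the helper name is
# hypothetical and the function is never called here.
def _example_encoder_stack(args, embed_tokens):
    embedding = TransformerEncoderEmbedding(args, embed_tokens)
    layers = [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
    if isinstance(embed_tokens, nn.ModuleList):
        emb_dim = sum(e.embedding_dim for e in embed_tokens)
    else:
        emb_dim = embed_tokens.embedding_dim
    final_norm = TransformerEncoderLayerNorm(args, emb_dim)
    # The resulting nn.Sequential mirrors what TransformerEncoder builds internally.
    return nn.Sequential(embedding, *layers, final_norm)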
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.model_parallel.models.pipeline_parallel_transformer.layers import (
Embedding,
TransformerDecoderEmbedding,
TransformerDecoderLayer,
TransformerDecoderOutputLayer,
TransformerEncoderEmbedding,
TransformerEncoderLayer,
TransformerEncoderLayerNorm,
)
from fairseq.models import (
BaseFairseqModel,
FairseqDecoder,
FairseqEncoder,
register_model,
register_model_architecture,
)
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import (
base_architecture,
transformer_iwslt_de_en,
transformer_wmt_en_de_big,
)
from fairseq.modules import SinusoidalPositionalEmbedding
logger = logging.getLogger(__name__)
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
@register_model("pipeline_parallel_transformer")
class PipelineParallelTransformerModel(BaseFairseqModel):
def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint):
try:
from fairscale.nn import Pipe
except ImportError:
raise ImportError("Please install fairscale with: pip install fairscale")
super().__init__()
assert isinstance(encoder, FairseqEncoder)
assert isinstance(decoder, FairseqDecoder)
encoder_module_list = (
[encoder.embedding_layer]
+ list(encoder.encoder_layers)
+ [encoder.final_layer_norm]
)
self.num_encoder_modules = len(encoder_module_list)
decoder_module_list = (
[decoder.embedding_layer]
+ list(decoder.decoder_layers)
+ [decoder.decoder_output_layer]
)
self.num_decoder_modules = len(decoder_module_list)
module_list = encoder_module_list + decoder_module_list
self.devices = devices
self.model = Pipe(
nn.Sequential(*module_list),
balance=balance,
devices=devices,
chunks=chunks,
checkpoint=checkpoint,
)
self.encoder_max_positions = self.max_positions_helper(
encoder.embedding_layer, "max_source_positions"
)
self.decoder_max_positions = self.max_positions_helper(
decoder.embedding_layer, "max_target_positions"
)
self.adaptive_softmax = getattr(decoder, "adaptive_softmax", None)
# Note: To be populated during inference
self.encoder = None
self.decoder = None
def forward(self, src_tokens, src_lengths, prev_output_tokens):
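# During training the whole model runs as a single fairscale Pipe; at inference time the
# partitions must first be unpacked into self.encoder/self.decoder via prepare_for_inference_().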
if self.training:
input_lst = [src_tokens, src_lengths, prev_output_tokens]
input = tuple(i.to(self.devices[0], non_blocking=True) for i in input_lst)
return self.model(input)
else:
assert self.encoder is not None and self.decoder is not None, (
"encoder and decoder need to be initialized by "
+ "calling the `prepare_for_inference_()` method"
)
encoder_output_tuple = self.encoder(src_tokens, src_lengths)
return self.decoder(encoder_output_tuple)
def prepare_for_inference_(self, args):
if self.encoder is not None and self.decoder is not None:
logger.info("Encoder and Decoder already initialized")
return
encoder_module_list = []
decoder_module_list = []
module_count = 0
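# Walk the Pipe partitions in order: the first num_encoder_modules modules form the encoder,
# the remaining ones form the decoder.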
for partition in self.model.partitions:
for module in partition:
if module_count < self.num_encoder_modules:
encoder_module_list.append(module)
else:
decoder_module_list.append(module)
module_count += 1
self.model = None
self.encoder = TransformerEncoder(args, None, None, encoder_module_list)
self.decoder = TransformerDecoder(
args, None, None, decoder_module_list=decoder_module_list
)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--activation-fn',
choices=utils.get_available_activation_fns(),
help='activation function to use')
parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', type=float, metavar='D',
help='dropout probability for attention weights')
parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
help='dropout probability after activation in FFN.')
parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
help='path to pre-trained encoder embedding')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
help='encoder embedding dimension for FFN')
parser.add_argument('--encoder-layers', type=int, metavar='N',
help='num encoder layers')
parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
help='num encoder attention heads')
parser.add_argument('--encoder-normalize-before', action='store_true',
help='apply layernorm before each encoder block')
parser.add_argument('--encoder-learned-pos', action='store_true',
help='use learned positional embeddings in the encoder')
parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
help='path to pre-trained decoder embedding')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
help='decoder embedding dimension for FFN')
parser.add_argument('--decoder-layers', type=int, metavar='N',
help='num decoder layers')
parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
help='num decoder attention heads')
parser.add_argument('--decoder-learned-pos', action='store_true',
help='use learned positional embeddings in the decoder')
parser.add_argument('--decoder-normalize-before', action='store_true',
help='apply layernorm before each decoder block')
parser.add_argument('--share-decoder-input-output-embed', action='store_true',
help='share decoder input and output embeddings')
parser.add_argument('--share-all-embeddings', action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
help='if set, disables positional embeddings (outside self attention)')
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion')
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
parser.add_argument('--num-embedding-chunks', type=int, metavar='N', default=1,
help='Number of embedding layer chunks (enables more even distribution '
'of optimizer states across data parallel nodes '
'when using optimizer state sharding and '
'a big embedding vocabulary)')
# fmt: on
@classmethod
def build_model_base(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_architecture(args)
if not hasattr(args, "max_source_positions"):
args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
if not hasattr(args, "max_target_positions"):
args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
def build_embedding(dictionary, embed_dim, path=None, num_embed_chunks=1):
assert embed_dim % num_embed_chunks == 0, (
f"Number of embedding chunks = {num_embed_chunks} should be "
+ f"divisible by the embedding dimension = {embed_dim}"
)
assert path is None or num_embed_chunks == 1, (
"Loading embedding from a path with number of embedding chunks > 1"
+ " is not yet supported"
)
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
# if provided, load from preloaded dictionaries
if path:
emb = Embedding(num_embeddings, embed_dim, padding_idx)
embed_dict = utils.parse_embedding(path)
utils.load_embedding(embed_dict, dictionary, emb)
else:
embed_chunk_dim = embed_dim // num_embed_chunks
emb = nn.ModuleList()
for i in range(num_embed_chunks):
emb.append(Embedding(num_embeddings, embed_chunk_dim, padding_idx))
return emb
num_embed_chunks = args.num_embedding_chunks
if args.share_all_embeddings:
if src_dict != tgt_dict:
raise ValueError("--share-all-embeddings requires a joined dictionary")
if args.encoder_embed_dim != args.decoder_embed_dim:
raise ValueError(
"--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
)
if args.decoder_embed_path and (
args.decoder_embed_path != args.encoder_embed_path
):
raise ValueError(
"--share-all-embeddings not compatible with --decoder-embed-path"
)
encoder_embed_tokens = build_embedding(
src_dict,
args.encoder_embed_dim,
args.encoder_embed_path,
num_embed_chunks,
)
decoder_embed_tokens = encoder_embed_tokens
args.share_decoder_input_output_embed = True
else:
assert args.share_decoder_input_output_embed or num_embed_chunks == 1, (
"Not sharing decoder I/O embeddings is not yet supported with number of "
+ "embedding chunks > 1"
)
encoder_embed_tokens = build_embedding(
src_dict,
args.encoder_embed_dim,
args.encoder_embed_path,
num_embed_chunks,
)
decoder_embed_tokens = build_embedding(
tgt_dict,
args.decoder_embed_dim,
args.decoder_embed_path,
num_embed_chunks,
)
encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
return (encoder, decoder)
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TransformerEncoder(args, src_dict, embed_tokens)
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
return TransformerDecoder(args, tgt_dict, embed_tokens)
@classmethod
def build_model(cls, args, task):
encoder, decoder = cls.build_model_base(args, task)
return PipelineParallelTransformerModel(
encoder=encoder,
decoder=decoder,
balance=utils.eval_str_list(args.pipeline_balance, type=int),
devices=utils.eval_str_list(args.pipeline_devices, type=int),
chunks=args.pipeline_chunks,
checkpoint=args.pipeline_checkpoint,
)
def output_layer(self, features, **kwargs):
"""Project features to the default output size (typically vocabulary size)."""
return self.decoder.output_layer(features, **kwargs)
def max_positions(self):
"""Maximum length supported by the model."""
return (self.encoder_max_positions, self.decoder_max_positions)
def max_positions_helper(
self, embedding_layer, max_positions_field="max_source_positions"
):
"""Maximum input length supported by the encoder or decoder."""
if embedding_layer.embed_positions is None:
return getattr(embedding_layer, max_positions_field)
return min(
getattr(embedding_layer, max_positions_field),
embedding_layer.embed_positions.max_positions,
)
def get_normalized_probs(self, net_output, log_probs, sample=None):
"""Get normalized probabilities (or log probs) from a net's output."""
if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
if sample is not None:
assert "target" in sample
target = sample["target"]
else:
target = None
out = self.adaptive_softmax.get_log_prob(net_output, target=target)
return out.exp_() if not log_probs else out
# A Pipe() module returns a tuple of tensors as the output.
# In this case, the tuple has one element - the output tensor of logits
logits = net_output if isinstance(net_output, torch.Tensor) else net_output[0]
if log_probs:
return utils.log_softmax(logits, dim=-1, onnx_trace=False)
else:
return utils.softmax(logits, dim=-1, onnx_trace=False)
def max_decoder_positions(self):
"""Maximum length supported by the decoder."""
return self.decoder_max_positions
def load_state_dict(self, state_dict, strict=True, args=None):
"""Copies parameters and buffers from *state_dict* into this module and
its descendants.
Overrides the method in :class:`nn.Module`. Compared with that method
this additionally "upgrades" *state_dicts* from old checkpoints.
"""
self.upgrade_state_dict(state_dict)
is_regular_transformer = not any("model.partitions" in k for k in state_dict)
if is_regular_transformer:
state_dict = self.convert_to_pipeline_parallel_state_dict(state_dict)
return super().load_state_dict(state_dict, strict)
def convert_to_pipeline_parallel_state_dict(self, state_dict):
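# Remaps the keys of a regular (non-pipelined) transformer checkpoint, e.g.
# 'encoder.layers.0.self_attn.k_proj.weight', onto the 'model.partitions.{pid}.{mid}.*'
# layout used by the Pipe-wrapped module list.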
new_state_dict = self.state_dict()
encoder_layer_idx = 0
decoder_layer_idx = 0
encoder_key_suffixes = [
"self_attn.k_proj.weight",
"self_attn.k_proj.bias",
"self_attn.v_proj.weight",
"self_attn.v_proj.bias",
"self_attn.q_proj.weight",
"self_attn.q_proj.bias",
"self_attn.out_proj.weight",
"self_attn.out_proj.bias",
"self_attn_layer_norm.weight",
"self_attn_layer_norm.bias",
"fc1.weight",
"fc1.bias",
"fc2.weight",
"fc2.bias",
"final_layer_norm.weight",
"final_layer_norm.bias",
]
decoder_key_suffixes = [
"self_attn.k_proj.weight",
"self_attn.k_proj.bias",
"self_attn.v_proj.weight",
"self_attn.v_proj.bias",
"self_attn.q_proj.weight",
"self_attn.q_proj.bias",
"self_attn.out_proj.weight",
"self_attn.out_proj.bias",
"self_attn_layer_norm.weight",
"self_attn_layer_norm.bias",
"encoder_attn.k_proj.weight",
"encoder_attn.k_proj.bias",
"encoder_attn.v_proj.weight",
"encoder_attn.v_proj.bias",
"encoder_attn.q_proj.weight",
"encoder_attn.q_proj.bias",
"encoder_attn.out_proj.weight",
"encoder_attn.out_proj.bias",
"encoder_attn_layer_norm.weight",
"encoder_attn_layer_norm.bias",
"fc1.weight",
"fc1.bias",
"fc2.weight",
"fc2.bias",
"final_layer_norm.weight",
"final_layer_norm.bias",
]
for pid, partition in enumerate(self.model.partitions):
logger.info(f"Begin Partition {pid}")
for mid, module in enumerate(partition):
# fmt: off
if isinstance(module, TransformerEncoderEmbedding):
new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight']
new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['encoder.embed_positions._float_tensor']
if isinstance(module, TransformerEncoderLayer):
for suffix in encoder_key_suffixes:
new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'encoder.layers.{encoder_layer_idx}.{suffix}']
encoder_layer_idx += 1
if isinstance(module, TransformerDecoderLayer):
for suffix in decoder_key_suffixes:
new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'decoder.layers.{decoder_layer_idx}.{suffix}']
decoder_layer_idx += 1
if isinstance(module, TransformerEncoderLayerNorm):
if 'encoder.layer_norm.weight' in state_dict:
new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.weight'] = state_dict['encoder.layer_norm.weight']
new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.bias'] = state_dict['encoder.layer_norm.bias']
if isinstance(module, TransformerDecoderEmbedding):
new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['decoder.embed_tokens.weight']
new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['decoder.embed_positions._float_tensor']
if isinstance(module, TransformerDecoderOutputLayer):
new_state_dict[f'model.partitions.{pid}.{mid}.output_projection.weight'] = state_dict['decoder.output_projection.weight']
# fmt: on
return new_state_dict
class TransformerEncoder(FairseqEncoder):
"""
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
is a :class:`TransformerEncoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): encoding dictionary
embed_tokens (torch.nn.Embedding): input embedding
"""
def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None):
super().__init__(dictionary)
self.register_buffer("version", torch.Tensor([3]))
try:
from fairscale.nn import Pipe
except ImportError:
raise ImportError("Please install fairscale with: pip install fairscale")
if encoder_module_list is None:
embedding_layer = TransformerEncoderEmbedding(args, embed_tokens)
layers = [TransformerEncoderLayer(args) for i in range(args.encoder_layers)]
if isinstance(embed_tokens, nn.ModuleList):
emb_dim = sum(e.embedding_dim for e in embed_tokens)
else:
emb_dim = embed_tokens.embedding_dim
final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim)
encoder_module_list = [embedding_layer] + layers + [final_layer_norm]
self.use_pipeline = getattr(args, "pipeline_encoder_balance", None) is not None
if self.use_pipeline:
encoder_balance = utils.eval_str_list(
args.pipeline_encoder_balance, type=int
)
encoder_devices = utils.eval_str_list(
args.pipeline_encoder_devices, type=int
)
assert sum(encoder_balance) == len(encoder_module_list), (
f"Sum of encoder_balance={encoder_balance} is not equal "
+ f"to num_encoder_modules={len(encoder_module_list)}"
)
self.model = Pipe(
module=nn.Sequential(*encoder_module_list),
balance=encoder_balance,
devices=encoder_devices,
chunks=args.pipeline_chunks,
checkpoint=args.pipeline_checkpoint,
)
else:
self.embedding_layer = encoder_module_list[0]
self.encoder_layers = nn.Sequential(*encoder_module_list[1:-1])
self.final_layer_norm = encoder_module_list[-1]
def forward(self, src_tokens, src_lengths):
"""
Args:
input_tuple(
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
)
Returns:
output_tuple(
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- prev_output_tokens
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
)
"""
dummy_prev_output_tokens = torch.zeros(
1, dtype=src_tokens.dtype, device=src_tokens.device
)
input_tuple = (src_tokens, src_lengths, dummy_prev_output_tokens)
if self.use_pipeline:
input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple)
encoder_out = self.model(input_tuple)
else:
encoder_embed_output_tuple = self.embedding_layer(input_tuple)
encoder_layers_output = self.encoder_layers(encoder_embed_output_tuple)
encoder_out = self.final_layer_norm(encoder_layers_output)
# first element is the encoder output
# second element is the encoder padding mask
# the remaining elements of EncoderOut are not computed by
# the PipelineParallelTransformer
return EncoderOut(encoder_out[0], encoder_out[1], None, None, None, None)
def reorder_encoder_out(self, encoder_out, new_order):
"""
Reorder encoder output according to *new_order*.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
*encoder_out* rearranged according to *new_order*
"""
if encoder_out.encoder_out is not None:
encoder_out = encoder_out._replace(
encoder_out=encoder_out.encoder_out.index_select(1, new_order)
)
if encoder_out.encoder_padding_mask is not None:
encoder_out = encoder_out._replace(
encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(
0, new_order
)
)
if encoder_out.encoder_embedding is not None:
encoder_out = encoder_out._replace(
encoder_embedding=encoder_out.encoder_embedding.index_select(
0, new_order
)
)
if encoder_out.encoder_states is not None:
for idx, state in enumerate(encoder_out.encoder_states):
encoder_out.encoder_states[idx] = state.index_select(1, new_order)
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
if self.embedding_layer.embed_positions is None:
return self.embedding_layer.max_source_positions
return min(
self.embedding_layer.max_source_positions,
self.embedding_layer.embed_positions.max_positions,
)
class TransformerDecoder(FairseqDecoder):
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(
self,
args,
dictionary,
embed_tokens,
no_encoder_attn=False,
decoder_module_list=None,
):
super().__init__(dictionary)
self.register_buffer("version", torch.Tensor([3]))
try:
from fairscale.nn import Pipe
except ImportError:
raise ImportError("Please install fairscale with: pip install fairscale")
if decoder_module_list is None:
embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
layers = [
TransformerDecoderLayer(args, no_encoder_attn)
for _ in range(args.decoder_layers)
]
decoder_output_layer = TransformerDecoderOutputLayer(
args, embed_tokens, dictionary
)
decoder_module_list = [embedding_layer] + layers + [decoder_output_layer]
self.use_pipeline = getattr(args, "pipeline_decoder_balance", None) is not None
if self.use_pipeline:
decoder_balance = utils.eval_str_list(
args.pipeline_decoder_balance, type=int
)
decoder_devices = utils.eval_str_list(
args.pipeline_decoder_devices, type=int
)
assert sum(decoder_balance) == len(decoder_module_list), (
f"Sum of decoder_balance={decoder_balance} is not equal "
+ f"to num_decoder_modules={len(decoder_module_list)}"
)
self.model = Pipe(
module=nn.Sequential(*decoder_module_list),
balance=decoder_balance,
devices=decoder_devices,
chunks=args.pipeline_chunks,
checkpoint=args.pipeline_checkpoint,
)
else:
self.embedding_layer = decoder_module_list[0]
self.decoder_layers = nn.Sequential(*decoder_module_list[1:-1])
self.decoder_output_layer = decoder_module_list[-1]
def forward(
self,
prev_output_tokens,
encoder_out=None,
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
applying output layer (default: False).
Returns:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
input_tuple = (
encoder_out.encoder_out,
encoder_out.encoder_padding_mask,
prev_output_tokens,
)
if self.use_pipeline:
input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple)
return (self.model(input_tuple),)
else:
embed_layer_output = self.embedding_layer(input_tuple)
state = self.decoder_layers(embed_layer_output)
return (self.decoder_output_layer(state),)
def output_layer(self, features, **kwargs):
"""Project features to the vocabulary size."""
if self.adaptive_softmax is None:
# project back to size of vocabulary
if self.share_input_output_embed:
return F.linear(features, self.embed_tokens.weight)
else:
return F.linear(features, self.embed_out)
else:
return features
def max_positions(self):
"""Maximum output length supported by the decoder."""
if self.embedding_layer.embed_positions is None:
return self.embedding_layer.max_target_positions
return min(
self.embedding_layer.max_target_positions,
self.embedding_layer.embed_positions.max_positions,
)
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if (
not hasattr(self, "_future_mask")
or self._future_mask is None
or self._future_mask.device != tensor.device
or self._future_mask.size(0) < dim
):
self._future_mask = torch.triu(
utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
)
return self._future_mask[:dim, :dim]
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
weights_key = "{}.embed_positions.weights".format(name)
if weights_key in state_dict:
del state_dict[weights_key]
state_dict[
"{}.embed_positions._float_tensor".format(name)
] = torch.FloatTensor(1)
for i in range(len(self.layers)):
# update layer norms
layer_norm_map = {
"0": "self_attn_layer_norm",
"1": "encoder_attn_layer_norm",
"2": "final_layer_norm",
}
for old, new in layer_norm_map.items():
for m in ("weight", "bias"):
k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
if k in state_dict:
state_dict[
"{}.layers.{}.{}.{}".format(name, i, new, m)
] = state_dict[k]
del state_dict[k]
version_key = "{}.version".format(name)
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
# earlier checkpoints did not normalize after the stack of layers
self.layer_norm = None
self.normalize = False
state_dict[version_key] = torch.Tensor([1])
return state_dict
@register_model_architecture(
"pipeline_parallel_transformer", "transformer_iwslt_de_en_pipeline_parallel"
)
def transformer_iwslt_de_en_dist(args):
transformer_iwslt_de_en(args)
@register_model_architecture(
"pipeline_parallel_transformer", "transformer_wmt_en_de_big_pipeline_parallel"
)
def transformer_wmt_en_de_big_dist(args):
transformer_wmt_en_de_big(args)
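# The two architectures registered above are selected with fairseq's --arch flag; the pipeline
# itself is configured through args.pipeline_balance, args.pipeline_devices, args.pipeline_chunks
# and args.pipeline_checkpoint, which build_model() reads above.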
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .model import * # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
RoBERTa: A Robustly Optimized BERT Pretraining Approach.
"""
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.model_parallel.modules import ModelParallelTransformerSentenceEncoder
from fairseq.models import FairseqEncoder, register_model, register_model_architecture
from fairseq.models.roberta import (
RobertaClassificationHead,
RobertaEncoder,
RobertaLMHead,
RobertaModel,
)
from fairseq.modules import LayerNorm, TransformerSentenceEncoder
from fairseq.modules.transformer_sentence_encoder import init_bert_params
try:
from fairseq.model_parallel.megatron.mpu import (
copy_to_model_parallel_region,
gather_from_model_parallel_region,
ColumnParallelLinear,
RowParallelLinear,
)
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
logger = logging.getLogger(__name__)
@register_model("model_parallel_roberta")
class ModelParallelRobertaModel(RobertaModel):
def __init__(self, args, encoder):
super().__init__(args, encoder)
self.classification_heads = nn.ModuleDict()
@staticmethod
def add_args(parser):
super(ModelParallelRobertaModel, ModelParallelRobertaModel).add_args(parser)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present
base_architecture(args)
task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8)
task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8)
if not hasattr(args, "max_positions"):
args.max_positions = args.tokens_per_sample
if getattr(args, "untie_weights_roberta", False):
raise NotImplementedError(
"--untie-weights-roberta is not supported in model parallel mode"
)
encoder = ModelParallelRobertaEncoder(args, task.source_dictionary)
return cls(args, encoder)
def forward(
self,
src_tokens,
features_only=False,
return_all_hiddens=False,
classification_head_name=None,
**kwargs
):
if classification_head_name is not None:
features_only = True
x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs)
if classification_head_name is not None:
x = self.classification_heads[classification_head_name](x)
return x, extra
def register_classification_head(
self, name, num_classes=None, inner_dim=None, **kwargs
):
"""Register a classification head."""
if name in self.classification_heads:
prev_num_classes = self.classification_heads[name].out_proj.out_features
prev_inner_dim = self.classification_heads[name].dense.out_features
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
logger.warning(
're-registering head "{}" with num_classes {} (prev: {}) '
"and inner_dim {} (prev: {})".format(
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
)
)
self.classification_heads[name] = ModelParallelRobertaClassificationHead(
self.args.encoder_embed_dim,
inner_dim or self.args.encoder_embed_dim,
num_classes,
self.args.pooler_activation_fn,
self.args.pooler_dropout,
)
class ModelParallelRobertaLMHead(nn.Module):
"""Head for masked language modeling."""
def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
super().__init__()
self.dense = ColumnParallelLinear(embed_dim, embed_dim, gather_output=True)
self.activation_fn = utils.get_activation_fn(activation_fn)
self.layer_norm = LayerNorm(embed_dim)
if weight is None:
weight = nn.Linear(embed_dim, output_dim, bias=False).weight
self.weight = weight
self.bias = nn.Parameter(torch.zeros(output_dim))
def forward(self, features, masked_tokens=None, **kwargs):
# Only project the unmasked tokens while training,
# saves both memory and computation
if masked_tokens is not None:
features = features[masked_tokens, :]
x = self.dense(features)
x = self.activation_fn(x)
x = self.layer_norm(x)
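# Model-parallel vocab projection: broadcast the activations to every model-parallel rank,
# project against this rank's shard of the output embedding weight, then gather the vocabulary
# dimension back before adding the (full-size) bias.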
x = copy_to_model_parallel_region(x)
# project back to size of vocabulary with bias
x = F.linear(x, self.weight)
x = gather_from_model_parallel_region(x).contiguous()
x = x + self.bias
return x
class ModelParallelRobertaClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(
self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout
):
super().__init__()
self.dense = ColumnParallelLinear(input_dim, inner_dim, gather_output=True)
self.activation_fn = utils.get_activation_fn(activation_fn)
self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes)
def forward(self, features, **kwargs):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
x = self.dropout(x)
x = self.dense(x)
x = self.activation_fn(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
class ModelParallelRobertaEncoder(FairseqEncoder):
"""RoBERTa encoder.
Implements the :class:`~fairseq.models.FairseqDecoder` interface required
by :class:`~fairseq.models.FairseqLanguageModel`.
"""
def __init__(self, args, dictionary):
super().__init__(dictionary)
self.args = args
# RoBERTa is a sentence encoder model, so users will intuitively trim
# encoder layers. However, the implementation uses the fairseq decoder,
# so we fix here.
if args.encoder_layers_to_keep:
args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
args.decoder_layers_to_keep = args.encoder_layers_to_keep
args.encoder_layers_to_keep = None
self.sentence_encoder = ModelParallelTransformerSentenceEncoder(
padding_idx=dictionary.pad(),
vocab_size=len(dictionary),
num_encoder_layers=args.encoder_layers,
embedding_dim=args.encoder_embed_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=args.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
layerdrop=args.encoder_layerdrop,
max_seq_len=args.max_positions,
num_segments=0,
encoder_normalize_before=False,
apply_bert_init=False,
activation_fn=args.activation_fn,
)
self.lm_head = ModelParallelRobertaLMHead(
embed_dim=args.encoder_embed_dim,
output_dim=len(dictionary),
activation_fn=args.activation_fn,
weight=self.sentence_encoder.embed_tokens.weight,
)
def forward(
self,
src_tokens,
features_only=False,
return_all_hiddens=False,
masked_tokens=None,
**unused
):
"""
Args:
src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
features_only (bool, optional): skip LM head and just return
features. If True, the output will be of shape
`(batch, src_len, embed_dim)`.
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
Returns:
tuple:
- the LM output of shape `(batch, src_len, vocab)`
- a dictionary of additional data, where 'inner_states'
is a list of hidden states. Note that the hidden
states have shape `(src_len, batch, embed_dim)`.
"""
x, extra = self.extract_features(
src_tokens, return_all_hiddens=return_all_hiddens
)
if not features_only:
x = self.output_layer(x, masked_tokens=masked_tokens)
return x, extra
def extract_features(self, src_tokens, return_all_hiddens=False, **unused):
inner_states, _ = self.sentence_encoder(
src_tokens,
last_state_only=not return_all_hiddens,
)
features = inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C
return features, {"inner_states": inner_states if return_all_hiddens else None}
def output_layer(self, features, masked_tokens=None, **unused):
return self.lm_head(features, masked_tokens)
def max_positions(self):
"""Maximum output length supported by the encoder."""
return self.args.max_positions
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta")
def base_architecture(args):
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
args.activation_fn = getattr(args, "activation_fn", "gelu")
args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_base")
def roberta_base_architecture(args):
base_architecture(args)
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_large")
def roberta_large_architecture(args):
args.encoder_layers = getattr(args, "encoder_layers", 24)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
base_architecture(args)
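# Illustrative sketch (not part of fairseq): registering and invoking a classification head on
# ModelParallelRobertaModel. `model` is assumed to be an already-built instance and `src_tokens`
# a LongTensor of shape (batch, src_len); the helper and head name below are hypothetical.
def _example_sentence_classification(model, src_tokens, num_classes=2):
    model.register_classification_head("example_head", num_classes=num_classes)
    # Passing classification_head_name switches forward() into features_only mode and applies the head.
    logits, _ = model(src_tokens, classification_head_name="example_head")
    return logits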
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch.nn as nn
import torch.nn.functional as F
from fairseq.model_parallel.modules import (
ModelParallelTransformerDecoderLayer,
ModelParallelTransformerEncoderLayer,
)
from fairseq.models import register_model
from fairseq.models.transformer import (
TransformerDecoder,
TransformerEncoder,
TransformerModel,
)
try:
from fairseq.model_parallel.megatron.mpu import (
copy_to_model_parallel_region,
gather_from_model_parallel_region,
VocabParallelEmbedding,
)
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
logger = logging.getLogger(__name__)
@register_model("model_parallel_transformer")
class ModelParallelTransformerModel(TransformerModel):
"""
Model parallel Transformer model.
"""
@classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None):
if not has_megatron_submodule:
raise ImportError(
"\n\nPlease install the megatron submodule:"
"\n\n git submodule update --init "
"fairseq/model_parallel/megatron"
)
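# Padding the vocabulary to a multiple of model_parallel_size * 8 ensures the embedding rows
# can be split evenly across the model-parallel ranks.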
dictionary.pad_to_multiple_(args.model_parallel_size * 8)
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
def _vocab_init(tensor, **kwargs):
nn.init.normal_(tensor, mean=0, std=num_embeddings ** -0.5)
nn.init.constant_(tensor[1], 0)
emb = VocabParallelEmbedding(
num_embeddings, embed_dim, padding_idx, init_method=_vocab_init
)
# if provided, load from preloaded dictionaries
if path:
raise NotImplementedError(
"Loading of embedding from path is not supported for model parallel"
)
return emb
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return ModelParallelTransformerEncoder(args, src_dict, embed_tokens)
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
return ModelParallelTransformerDecoder(
args,
tgt_dict,
embed_tokens,
no_encoder_attn=getattr(args, "no_cross_attention", False),
)
class ModelParallelTransformerEncoder(TransformerEncoder):
"""
Model parallel Transformer encoder consisting of *args.encoder_layers* layers. Each layer
is a :class:`ModelParallelTransformerEncoderLayer`.
"""
def build_encoder_layer(self, args):
return ModelParallelTransformerEncoderLayer(args)
class ModelParallelTransformerDecoder(TransformerDecoder):
"""
Model Parallel Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`ModelParallelTransformerDecoderLayer`.
"""
def build_decoder_layer(self, args, no_encoder_attn=False):
return ModelParallelTransformerDecoderLayer(args, no_encoder_attn)
def output_layer(self, features, **kwargs):
"""Project features to the vocabulary size."""
if not self.share_input_output_embed:
raise NotImplementedError(
"Model parallel training currently requires --share-decoder-input-output-embed"
)
features = copy_to_model_parallel_region(features)
# project back to size of vocabulary
x = self.output_projection(features)
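# When training with the vocab_parallel_cross_entropy criterion the loss is computed directly
# on each rank's shard of the logits, so gathering across model-parallel ranks is skipped.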
if getattr(self.args, "criterion") != "vocab_parallel_cross_entropy":
x = gather_from_model_parallel_region(x).contiguous()
return x
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
from fairseq.model_parallel.models.transformer import ModelParallelTransformerDecoder
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer_lm import TransformerLanguageModel
try:
from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
DEFAULT_MAX_TARGET_POSITIONS = 1024
@register_model("model_parallel_transformer_lm")
class ModelParallelTransformerLanguageModel(TransformerLanguageModel):
@staticmethod
def add_args(parser):
TransformerLanguageModel.add_args(parser)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
if not has_megatron_submodule:
raise ImportError(
"\n\nPlease install the megatron submodule:"
"\n\n git submodule update --init "
"fairseq/model_parallel/megatron"
)
# make sure all arguments are present in older models
base_lm_architecture(args)
task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8)
task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8)
if args.decoder_layers_to_keep:
args.decoder_layers = len(args.decoder_layers_to_keep.split(","))
if getattr(args, "max_target_positions", None) is None:
args.max_target_positions = getattr(
args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
)
if args.character_embeddings:
raise NotImplementedError(
"Character embeddings is not supported for model parallel"
)
elif args.adaptive_input:
raise NotImplementedError(
"Adaptive input is not supported for model parallel"
)
else:
embed_tokens = cls.build_embedding(
args, task.source_dictionary, args.decoder_input_dim
)
decoder = ModelParallelTransformerDecoder(
args,
task.target_dictionary,
embed_tokens,
no_encoder_attn=True,
)
return cls(decoder)
@classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None):
def _vocab_init(tensor, **kwargs):
nn.init.normal_(tensor, mean=0, std=embed_dim ** -0.5)
nn.init.constant_(tensor[1], 0)
embed_tokens = VocabParallelEmbedding(
len(dictionary), embed_dim, dictionary.pad(), init_method=_vocab_init
)
return embed_tokens
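# --- Hedged illustration (not part of fairseq): a minimal single-device sketch
# of what VocabParallelEmbedding does conceptually. The dictionaries are padded
# to a multiple of `model_parallel_size * 8` in `build_model` above so that the
# vocabulary splits evenly; each rank owns a contiguous range of embedding rows,
# embeds only the token ids inside its range, and an all-reduce (here: a plain
# sum) combines the per-rank results. The `_demo_*` name is hypothetical.
def _demo_vocab_parallel_embedding(vocab_size=32, embed_dim=8, world_size=4):
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    assert vocab_size % world_size == 0
    per_rank = vocab_size // world_size
    full_weight = torch.randn(vocab_size, embed_dim)
    tokens = torch.randint(0, vocab_size, (2, 7))       # B x T token ids
    partial_sum = torch.zeros(tokens.size(0), tokens.size(1), embed_dim)
    for rank in range(world_size):
        lo, hi = rank * per_rank, (rank + 1) * per_rank
        shard = full_weight[lo:hi]                      # rows owned by this rank
        in_range = (tokens >= lo) & (tokens < hi)
        local_ids = (tokens - lo).clamp(0, per_rank - 1)
        local_out = F.embedding(local_ids, shard)
        local_out = local_out * in_range.unsqueeze(-1).type_as(local_out)
        partial_sum = partial_sum + local_out           # stands in for the all-reduce
    assert torch.allclose(partial_sum, F.embedding(tokens, full_weight))
    return partial_sum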
def base_lm_architecture(args):
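    """Fill in transformer_lm architecture hyper-parameters that were not set
    on the command line (values already present on ``args`` are generally kept;
    ``decoder_normalize_before`` is forced to True for training stability)."""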
# backward compatibility for older model checkpoints
if hasattr(args, "no_tie_adaptive_proj"):
# previous models defined --no-tie-adaptive-proj, so use the existence of
# that option to determine if this is an "old" model checkpoint
args.no_decoder_final_norm = True # old models always set this to True
if args.no_tie_adaptive_proj is False:
args.tie_adaptive_proj = True
if hasattr(args, "decoder_final_norm"):
args.no_decoder_final_norm = not args.decoder_final_norm
args.activation_fn = getattr(args, "activation_fn", "relu")
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
args.relu_dropout = getattr(args, "relu_dropout", 0.0)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
# Model training is not stable without this
args.decoder_normalize_before = True
args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False)
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.character_embeddings = getattr(args, "character_embeddings", False)
args.character_filters = getattr(
args,
"character_filters",
"[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]",
)
args.character_embedding_dim = getattr(args, "character_embedding_dim", 4)
args.char_embedder_highway_layers = getattr(args, "char_embedder_highway_layers", 2)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4)
args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0.0)
args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0.0)
args.add_bos_token = getattr(args, "add_bos_token", False)
@register_model_architecture("model_parallel_transformer_lm", "transformer_lm_megatron")
def transformer_lm_megatron(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 4)
args.decoder_layers = getattr(args, "decoder_layers", 72)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
args.activation_fn = getattr(args, "activation_fn", "gelu")
base_lm_architecture(args)
@register_model_architecture(
"model_parallel_transformer_lm", "transformer_lm_megatron_11b"
)
def transformer_lm_megatron_11b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 6)
args.decoder_layers = getattr(args, "decoder_layers", 72)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
args.activation_fn = getattr(args, "activation_fn", "gelu")
base_lm_architecture(args)
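# --- Hedged illustration (not part of fairseq): how the getattr-default idiom
# used by `base_lm_architecture` and the registered architectures above behaves.
# A value is only filled in when the attribute is missing, so values already on
# `args` (e.g. from the command line) are kept, and because
# `transformer_lm_megatron` applies its own defaults before delegating to
# `base_lm_architecture`, its larger dimensions take precedence. The `_demo_*`
# helper below is hypothetical.
def _demo_architecture_defaults():
    from argparse import Namespace

    args = Namespace(decoder_layers=4)      # pretend the user passed --decoder-layers 4
    transformer_lm_megatron(args)           # fills in every missing hyper-parameter
    assert args.decoder_layers == 4         # explicit value is preserved
    assert args.decoder_embed_dim == 3072   # megatron default applied
    assert args.activation_fn == "gelu"     # megatron default wins over the base "relu"
    return args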
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
from .multihead_attention import ModelParallelMultiheadAttention
from .transformer_layer import (
ModelParallelTransformerEncoderLayer,
ModelParallelTransformerDecoderLayer,
)
from .transformer_sentence_encoder_layer import (
ModelParallelTransformerSentenceEncoderLayer,
)
from .transformer_sentence_encoder import ModelParallelTransformerSentenceEncoder
__all__ = [
"ModelParallelMultiheadAttention",
"ModelParallelTransformerEncoderLayer",
"ModelParallelTransformerDecoderLayer",
"ModelParallelTransformerSentenceEncoder",
"ModelParallelTransformerSentenceEncoderLayer",
]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout
from torch import Tensor, nn
try:
from fairseq.model_parallel.megatron.mpu import (
get_cuda_rng_tracker,
get_model_parallel_world_size,
ColumnParallelLinear,
RowParallelLinear,
)
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
@with_incremental_state
class ModelParallelMultiheadAttention(nn.Module):
"""Model parallel Multi-headed attention.
This performs the Multi-headed attention over multiple gpus.
See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
"""
def __init__(
self,
embed_dim,
num_heads,
kdim=None,
vdim=None,
dropout=0.0,
bias=True,
self_attention=False,
encoder_decoder_attention=False,
):
super().__init__()
if not has_megatron_submodule:
raise ImportError(
"\n\nPlease install the megatron submodule:"
"\n\n git submodule update --init "
"fairseq/model_parallel/megatron"
)
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.model_parallel_size = get_model_parallel_world_size()
self.num_heads_partition = num_heads // self.model_parallel_size
assert (
self.num_heads_partition * self.model_parallel_size == num_heads
), "Number of heads must be divisible by model parallel size"
self.dropout_module = FairseqDropout(
dropout, module_name=self.__class__.__name__
)
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.self_attention = self_attention
self.encoder_decoder_attention = encoder_decoder_attention
assert (
not self.self_attention or self.qkv_same_dim
), "Self-attention requires query, key and value to be of the same size"
self.k_proj = ColumnParallelLinear(
self.kdim, embed_dim, bias=bias, gather_output=False
)
self.v_proj = ColumnParallelLinear(
self.vdim, embed_dim, bias=bias, gather_output=False
)
self.q_proj = ColumnParallelLinear(
embed_dim, embed_dim, bias=bias, gather_output=False
)
self.out_proj = RowParallelLinear(
embed_dim, embed_dim, bias=bias, input_is_parallel=True
)
self.tpu = False
def prepare_for_tpu_(self, **kwargs):
self.tpu = True
def forward(
self,
query,
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
**unused_kwargs,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Input shape: Time x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
"""
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if saved_state is not None and "prev_key" in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert self.encoder_decoder_attention and not self.self_attention
key = value = None
else:
saved_state = None
if self.self_attention:
q = self.q_proj(query)
k = self.k_proj(query)
v = self.v_proj(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.q_proj(query)
if key is None:
assert value is None
k = v = None
else:
k = self.k_proj(key)
v = self.v_proj(key)
else:
assert key is not None and value is not None
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q *= self.scaling
q = (
q.contiguous()
.view(tgt_len, bsz * self.num_heads_partition, self.head_dim)
.transpose(0, 1)
)
if k is not None:
k = (
k.contiguous()
.view(-1, bsz * self.num_heads_partition, self.head_dim)
.transpose(0, 1)
)
if v is not None:
v = (
v.contiguous()
.view(-1, bsz * self.num_heads_partition, self.head_dim)
.transpose(0, 1)
)
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads_partition, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
prev_key = _prev_key.view(
bsz * self.num_heads_partition, -1, self.head_dim
)
if static_kv:
k = prev_key
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
prev_value = _prev_value.view(
bsz * self.num_heads_partition, -1, self.head_dim
)
if static_kv:
v = prev_value
else:
assert v is not None
v = torch.cat([prev_value, v], dim=1)
prev_key_padding_mask: Optional[Tensor] = None
if "prev_key_padding_mask" in saved_state:
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
assert k is not None and v is not None
key_padding_mask = (
ModelParallelMultiheadAttention._append_prev_key_padding_mask(
key_padding_mask=key_padding_mask,
prev_key_padding_mask=prev_key_padding_mask,
batch_size=bsz,
src_len=k.size(1),
static_kv=static_kv,
)
)
saved_state["prev_key"] = k.view(
bsz, self.num_heads_partition, -1, self.head_dim
)
saved_state["prev_value"] = v.view(
bsz, self.num_heads_partition, -1, self.head_dim
)
saved_state["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)
assert k is not None
src_len = k.size(1)
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
attn_weights = torch.bmm(q, k.transpose(1, 2))
assert list(attn_weights.size()) == [
bsz * self.num_heads_partition,
tgt_len,
src_len,
]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
attn_weights += attn_mask
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(
bsz, self.num_heads_partition, tgt_len, src_len
)
if not self.tpu:
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
else:
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.view(
bsz * self.num_heads_partition, tgt_len, src_len
)
attn_weights_float = utils.softmax(attn_weights, dim=-1)
attn_weights = attn_weights_float.type_as(attn_weights)
with get_cuda_rng_tracker().fork():
attn_probs = self.dropout_module(attn_weights)
assert v is not None
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [
bsz * self.num_heads_partition,
tgt_len,
self.head_dim,
]
embed_dim_partition = embed_dim // self.model_parallel_size
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim_partition)
attn = self.out_proj(attn)
        # Return attn_weights as None to keep the return type the same as the
        # single-GPU multihead attention; this will be deprecated.
attn_weights: Optional[Tensor] = None
return attn, attn_weights
@staticmethod
def _append_prev_key_padding_mask(
key_padding_mask: Optional[Tensor],
prev_key_padding_mask: Optional[Tensor],
batch_size: int,
src_len: int,
static_kv: bool,
) -> Optional[Tensor]:
# saved key padding masks have shape (bsz, seq_len)
if prev_key_padding_mask is not None and static_kv:
new_key_padding_mask = prev_key_padding_mask
elif prev_key_padding_mask is not None and key_padding_mask is not None:
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
)
# During incremental decoding, as the padding token enters and
# leaves the frame, there will be a time when prev or current
# is None
elif prev_key_padding_mask is not None:
filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
if prev_key_padding_mask.is_cuda:
filler = filler.cuda()
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), filler.float()], dim=1
)
elif key_padding_mask is not None:
filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
if key_padding_mask.is_cuda:
filler = filler.cuda()
new_key_padding_mask = torch.cat(
[filler.float(), key_padding_mask.float()], dim=1
)
else:
new_key_padding_mask = prev_key_padding_mask
return new_key_padding_mask
def reorder_incremental_state(
self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order
):
"""Reorder buffered internal state (for incremental generation)."""
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is not None:
for k in input_buffer.keys():
if input_buffer[k] is not None:
input_buffer[k] = input_buffer[k].index_select(0, new_order)
incremental_state = self._set_input_buffer(incremental_state, input_buffer)
return incremental_state
def _get_input_buffer(
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
) -> Dict[str, Optional[Tensor]]:
result = self.get_incremental_state(incremental_state, "attn_state")
if result is not None:
return result
else:
empty_result: Dict[str, Optional[Tensor]] = {}
return empty_result
def _set_input_buffer(
self,
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
buffer: Dict[str, Optional[Tensor]],
):
return self.set_incremental_state(incremental_state, "attn_state", buffer)
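# --- Hedged illustration (not part of fairseq, not the Megatron mpu API): a
# minimal single-device sketch of how ModelParallelMultiheadAttention splits
# the heads. ColumnParallelLinear gives each rank the Q/K/V features for its
# own subset of heads, the attention math runs independently per rank, and
# RowParallelLinear followed by an all-reduce (here: a plain sum) recombines
# the partial outputs. The `_demo_*` name is hypothetical.
def _demo_model_parallel_attention(seq=5, bsz=2, embed_dim=16, heads=4, world_size=2):
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    head_dim, part = embed_dim // heads, embed_dim // world_size
    x = torch.randn(seq, bsz, embed_dim)                        # T x B x C
    wq, wk, wv, wo = (torch.randn(embed_dim, embed_dim) for _ in range(4))

    def attend(q, k, v, n_heads):
        # Standard scaled dot-product attention over n_heads heads.
        q = q.view(seq, bsz * n_heads, head_dim).transpose(0, 1) * head_dim ** -0.5
        k = k.view(seq, bsz * n_heads, head_dim).transpose(0, 1)
        v = v.view(seq, bsz * n_heads, head_dim).transpose(0, 1)
        probs = F.softmax(torch.bmm(q, k.transpose(1, 2)), dim=-1)
        return torch.bmm(probs, v).transpose(0, 1).contiguous().view(seq, bsz, -1)

    # Reference: all heads computed on a single device.
    full = F.linear(attend(F.linear(x, wq), F.linear(x, wk), F.linear(x, wv), heads), wo)

    # Model parallel: each rank owns a slice of the projection weights.
    out = torch.zeros_like(full)
    for rank in range(world_size):
        rows = slice(rank * part, (rank + 1) * part)
        q, k, v = (F.linear(x, w[rows]) for w in (wq, wk, wv))  # column-parallel Q/K/V
        attn_r = attend(q, k, v, heads // world_size)
        out = out + F.linear(attn_r, wo[:, rows])               # row-parallel out_proj + all-reduce
    assert torch.allclose(out, full, atol=1e-4)
    return out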