Commit c394d7d1 authored by change

init
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq import checkpoint_utils
from fairseq.models import (
register_model,
register_model_architecture,
)
from fairseq.models.speech_to_text import (
ConvTransformerModel,
convtransformer_espnet,
ConvTransformerEncoder,
)
from fairseq.models.speech_to_text.modules.augmented_memory_attention import (
augmented_memory,
SequenceEncoder,
AugmentedMemoryConvTransformerEncoder,
)
from torch import nn, Tensor
from typing import Dict, List
from fairseq.models.speech_to_text.modules.emformer import NoSegAugmentedMemoryTransformerEncoderLayer
@register_model("convtransformer_simul_trans")
class SimulConvTransformerModel(ConvTransformerModel):
"""
Implementation of the paper:
SimulMT to SimulST: Adapting Simultaneous Text Translation to
End-to-End Simultaneous Speech Translation
https://www.aclweb.org/anthology/2020.aacl-main.58.pdf
"""
@staticmethod
def add_args(parser):
super(SimulConvTransformerModel, SimulConvTransformerModel).add_args(parser)
parser.add_argument(
"--train-monotonic-only",
action="store_true",
default=False,
help="Only train monotonic attention",
)
@classmethod
def build_decoder(cls, args, task, embed_tokens):
tgt_dict = task.tgt_dict
from examples.simultaneous_translation.models.transformer_monotonic_attention import (
TransformerMonotonicDecoder,
)
decoder = TransformerMonotonicDecoder(args, tgt_dict, embed_tokens)
if getattr(args, "load_pretrained_decoder_from", None):
decoder = checkpoint_utils.load_pretrained_component_from_model(
component=decoder, checkpoint=args.load_pretrained_decoder_from
)
return decoder
@register_model_architecture(
"convtransformer_simul_trans", "convtransformer_simul_trans_espnet"
)
def convtransformer_simul_trans_espnet(args):
convtransformer_espnet(args)
@register_model("convtransformer_augmented_memory")
@augmented_memory
class AugmentedMemoryConvTransformerModel(SimulConvTransformerModel):
@classmethod
def build_encoder(cls, args):
encoder = SequenceEncoder(args, AugmentedMemoryConvTransformerEncoder(args))
if getattr(args, "load_pretrained_encoder_from", None) is not None:
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from
)
return encoder
@register_model_architecture(
"convtransformer_augmented_memory", "convtransformer_augmented_memory"
)
def augmented_memory_convtransformer_espnet(args):
convtransformer_espnet(args)
# ============================================================================ #
# Convtransformer
# with monotonic attention decoder
# with emformer encoder
# ============================================================================ #
class ConvTransformerEmformerEncoder(ConvTransformerEncoder):
def __init__(self, args):
super().__init__(args)
stride = self.conv_layer_stride(args)
trf_left_context = args.segment_left_context // stride
trf_right_context = args.segment_right_context // stride
context_config = [trf_left_context, trf_right_context]
self.transformer_layers = nn.ModuleList(
[
NoSegAugmentedMemoryTransformerEncoderLayer(
input_dim=args.encoder_embed_dim,
num_heads=args.encoder_attention_heads,
ffn_dim=args.encoder_ffn_embed_dim,
num_layers=args.encoder_layers,
dropout_in_attn=args.dropout,
dropout_on_attn=args.dropout,
dropout_on_fc1=args.dropout,
dropout_on_fc2=args.dropout,
activation_fn=args.activation_fn,
context_config=context_config,
segment_size=args.segment_length,
max_memory_size=args.max_memory_size,
scaled_init=True, # TODO: use constant for now.
tanh_on_mem=args.amtrf_tanh_on_mem,
)
]
)
self.conv_transformer_encoder = ConvTransformerEncoder(args)
def forward(self, src_tokens, src_lengths):
encoder_out: Dict[str, List[Tensor]] = self.conv_transformer_encoder(src_tokens, src_lengths.to(src_tokens.device))
output = encoder_out["encoder_out"][0]
encoder_padding_masks = encoder_out["encoder_padding_mask"]
return {
"encoder_out": [output],
# In the original implementation the output does not treat the last segment
# as right context, so the padding mask is truncated to the output length.
"encoder_padding_mask": [encoder_padding_masks[0][:, : output.size(0)]] if len(encoder_padding_masks) > 0
else [],
"encoder_embedding": [],
"encoder_states": [],
"src_tokens": [],
"src_lengths": [],
}
@staticmethod
def conv_layer_stride(args):
# TODO: make it configurable from the args
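# Note: this assumes the default ConvTransformer front-end, whose two
# stride-2 convolutions downsample the time axis by 2 * 2 = 4.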
return 4
@register_model("convtransformer_emformer")
class ConvtransformerEmformer(SimulConvTransformerModel):
@staticmethod
def add_args(parser):
super(ConvtransformerEmformer, ConvtransformerEmformer).add_args(parser)
parser.add_argument(
"--segment-length",
type=int,
metavar="N",
help="length of each segment (not including left context / right context)",
)
parser.add_argument(
"--segment-left-context",
type=int,
help="length of left context in a segment",
)
parser.add_argument(
"--segment-right-context",
type=int,
help="length of right context in a segment",
)
parser.add_argument(
"--max-memory-size",
type=int,
default=-1,
help="Right context for the segment.",
)
parser.add_argument(
"--amtrf-tanh-on-mem",
default=False,
action="store_true",
help="whether to use tanh on memory vector",
)
@classmethod
def build_encoder(cls, args):
encoder = ConvTransformerEmformerEncoder(args)
if getattr(args, "load_pretrained_encoder_from", None):
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from
)
return encoder
@register_model_architecture(
"convtransformer_emformer",
"convtransformer_emformer",
)
def convtransformer_emformer_base(args):
convtransformer_espnet(args)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, NamedTuple, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from examples.simultaneous_translation.modules.monotonic_transformer_layer import (
TransformerMonotonicDecoderLayer,
TransformerMonotonicEncoderLayer,
)
from fairseq.models import (
register_model,
register_model_architecture,
)
from fairseq.models.transformer import (
TransformerModel,
TransformerEncoder,
TransformerDecoder,
base_architecture,
transformer_iwslt_de_en,
transformer_vaswani_wmt_en_de_big,
transformer_vaswani_wmt_en_fr_big,
)
from torch import Tensor
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
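# Rough semantics of the fields below, as used by extract_features():
#   action == 0 -> the model decided to READ more source before predicting,
#   action == 1 -> the model is ready to WRITE the next target token;
# attn_list / step_list carry per-layer attention dicts and monotonic head
# steps during incremental decoding.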
TransformerMonotonicDecoderOut = NamedTuple(
"TransformerMonotonicDecoderOut",
[
("action", int),
("attn_list", Optional[List[Optional[Dict[str, Tensor]]]]),
("step_list", Optional[List[Optional[Tensor]]]),
("encoder_out", Optional[Dict[str, List[Tensor]]]),
("encoder_padding_mask", Optional[Tensor]),
],
)
@register_model("transformer_unidirectional")
class TransformerUnidirectionalModel(TransformerModel):
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TransformerMonotonicEncoder(args, src_dict, embed_tokens)
@register_model("transformer_monotonic")
class TransformerModelSimulTrans(TransformerModel):
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TransformerMonotonicEncoder(args, src_dict, embed_tokens)
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
return TransformerMonotonicDecoder(args, tgt_dict, embed_tokens)
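# Helper for the simultaneous decoding agent (assumed SimulEval-style
# "states" dict): builds the source indices read so far and the target prefix
# (prepended with EOS) as tensors on the model's device.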
def _indices_from_states(self, states):
if type(states["indices"]["src"]) == list:
if next(self.parameters()).is_cuda:
tensor = torch.cuda.LongTensor
else:
tensor = torch.LongTensor
src_indices = tensor(
[states["indices"]["src"][: 1 + states["steps"]["src"]]]
)
tgt_indices = tensor(
[[self.decoder.dictionary.eos()] + states["indices"]["tgt"]]
)
else:
src_indices = states["indices"]["src"][: 1 + states["steps"]["src"]]
tgt_indices = states["indices"]["tgt"]
return src_indices, None, tgt_indices
class TransformerMonotonicEncoder(TransformerEncoder):
def __init__(self, args, dictionary, embed_tokens):
super().__init__(args, dictionary, embed_tokens)
self.dictionary = dictionary
self.layers = nn.ModuleList([])
self.layers.extend(
[TransformerMonotonicEncoderLayer(args) for i in range(args.encoder_layers)]
)
class TransformerMonotonicDecoder(TransformerDecoder):
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
super().__init__(args, dictionary, embed_tokens, no_encoder_attn=False)
self.dictionary = dictionary
self.layers = nn.ModuleList([])
self.layers.extend(
[
TransformerMonotonicDecoderLayer(args, no_encoder_attn)
for _ in range(args.decoder_layers)
]
)
def pre_attention(
self,
prev_output_tokens,
encoder_out_dict: Dict[str, List[Tensor]],
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
):
positions = (
self.embed_positions(
prev_output_tokens,
incremental_state=incremental_state,
)
if self.embed_positions is not None
else None
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if positions is not None:
x += positions
x = self.dropout_module(x)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
encoder_out = encoder_out_dict["encoder_out"][0]
encoder_padding_mask = (
encoder_out_dict["encoder_padding_mask"][0]
if encoder_out_dict["encoder_padding_mask"]
and len(encoder_out_dict["encoder_padding_mask"]) > 0
else None
)
return x, encoder_out, encoder_padding_mask
def post_attention(self, x):
if self.layer_norm is not None:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
return x
def clear_cache(
self,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
end_id: Optional[int] = None,
):
"""
Clear cache in the monotonic layers.
The cache is created by a decoding forward pass that produced no prediction.
end_id is the (exclusive) index up to which layer caches are pruned.
"""
if end_id is None:
end_id = len(self.layers)
for index, layer in enumerate(self.layers):
if index < end_id:
layer.prune_incremental_state(incremental_state)
def extract_features(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]],
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False, # unused
alignment_layer: Optional[int] = None, # unused
alignment_heads: Optional[int] = None,  # unused
):
"""
Similar to *forward* but only return features.
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
# incremental_state = None
assert encoder_out is not None
(x, encoder_outs, encoder_padding_mask) = self.pre_attention(
prev_output_tokens, encoder_out, incremental_state
)
attn = None
inner_states = [x]
attn_list: List[Optional[Dict[str, Tensor]]] = []
step_list: List[Optional[Tensor]] = []
for i, layer in enumerate(self.layers):
x, attn, _ = layer(
x=x,
encoder_out=encoder_outs,
encoder_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
self_attn_mask=self.buffered_future_mask(x)
if incremental_state is None
else None,
)
inner_states.append(x)
attn_list.append(attn)
if incremental_state is not None:
curr_steps = layer.get_head_steps(incremental_state)
step_list.append(curr_steps)
if_online = incremental_state["online"]["only"]
assert if_online is not None
if if_online.to(torch.bool):
# Online indicates that the encoder states are still changing
assert attn is not None
assert curr_steps is not None
p_choose = (
attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t())
)
new_steps = curr_steps + (p_choose < 0.5).t().type_as(curr_steps)
src = incremental_state["steps"]["src"]
assert src is not None
if (new_steps >= src).any():
# We need to prune the last self_attn saved_state
# if model decide not to read
# otherwise there will be duplicated saved_state
self.clear_cache(incremental_state, i + 1)
return x, TransformerMonotonicDecoderOut(
action=0,
attn_list=None,
step_list=None,
encoder_out=None,
encoder_padding_mask=None,
)
x = self.post_attention(x)
return x, TransformerMonotonicDecoderOut(
action=1,
attn_list=attn_list,
step_list=step_list,
encoder_out=encoder_out,
encoder_padding_mask=encoder_padding_mask,
)
def reorder_incremental_state(self, incremental_state, new_order):
super().reorder_incremental_state(incremental_state, new_order)
if "fastest_step" in incremental_state:
incremental_state["fastest_step"] = incremental_state[
"fastest_step"
].index_select(0, new_order)
@register_model_architecture("transformer_monotonic", "transformer_monotonic")
def base_monotonic_architecture(args):
base_architecture(args)
args.encoder_unidirectional = getattr(args, "encoder_unidirectional", False)
@register_model_architecture(
"transformer_monotonic", "transformer_monotonic_iwslt_de_en"
)
def transformer_monotonic_iwslt_de_en(args):
transformer_iwslt_de_en(args)
base_monotonic_architecture(args)
# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
@register_model_architecture(
"transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_de_big"
)
def transformer_monotonic_vaswani_wmt_en_de_big(args):
transformer_vaswani_wmt_en_de_big(args)
@register_model_architecture(
"transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_fr_big"
)
def transformer_monotonic_vaswani_wmt_en_fr_big(args):
transformer_vaswani_wmt_en_fr_big(args)
@register_model_architecture(
"transformer_unidirectional", "transformer_unidirectional_iwslt_de_en"
)
def transformer_unidirectional_iwslt_de_en(args):
transformer_iwslt_de_en(args)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
from fairseq import registry
(
build_monotonic_attention,
register_monotonic_attention,
MONOTONIC_ATTENTION_REGISTRY,
_,
) = registry.setup_registry("--simul-type")
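# Example of how this registry is used (illustrative): a class decorated with
#   @register_monotonic_attention("waitk")
# becomes selectable on the command line via "--simul-type waitk", and
# build_monotonic_attention(args) then instantiates the class registered
# under args.simul_type.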
for file in sorted(os.listdir(os.path.dirname(__file__))):
if file.endswith(".py") and not file.startswith("_"):
model_name = file[: file.find(".py")]
importlib.import_module(
"examples.simultaneous_translation.modules." + model_name
)
from functools import partial
import torch
from torch import Tensor
import math
import torch.nn.functional as F
from . import register_monotonic_attention
from .monotonic_multihead_attention import (
MonotonicMultiheadAttentionWaitK,
MonotonicMultiheadAttentionHardAligned,
MonotonicMultiheadAttentionInfiniteLookback,
)
from typing import Dict, Optional
from examples.simultaneous_translation.utils import p_choose_strategy
def fixed_pooling_monotonic_attention(monotonic_attention):
def create_model(monotonic_attention, klass):
class FixedStrideMonotonicAttention(monotonic_attention):
def __init__(self, args):
self.waitk_lagging = 0
self.num_heads = 0
self.noise_mean = 0.0
self.noise_var = 0.0
super().__init__(args)
self.pre_decision_type = args.fixed_pre_decision_type
self.pre_decision_ratio = args.fixed_pre_decision_ratio
self.pre_decision_pad_threshold = args.fixed_pre_decision_pad_threshold
if self.pre_decision_ratio == 1:
return
self.strategy = args.simul_type
if args.fixed_pre_decision_type == "average":
self.pooling_layer = torch.nn.AvgPool1d(
kernel_size=self.pre_decision_ratio,
stride=self.pre_decision_ratio,
ceil_mode=True,
)
elif args.fixed_pre_decision_type == "last":
def last(key):
if key.size(2) < self.pre_decision_ratio:
return key
else:
k = key[
:,
:,
self.pre_decision_ratio - 1 :: self.pre_decision_ratio,
].contiguous()
if key.size(-1) % self.pre_decision_ratio != 0:
k = torch.cat([k, key[:, :, -1:]], dim=-1).contiguous()
return k
self.pooling_layer = last
else:
raise NotImplementedError
@staticmethod
def add_args(parser):
super(
FixedStrideMonotonicAttention, FixedStrideMonotonicAttention
).add_args(parser)
parser.add_argument(
"--fixed-pre-decision-ratio",
type=int,
required=True,
help=(
"Ratio for the fixed pre-decision,"
"indicating how many encoder steps will start"
"simultaneous decision making process."
),
)
parser.add_argument(
"--fixed-pre-decision-type",
default="average",
choices=["average", "last"],
help="Pooling type",
)
parser.add_argument(
"--fixed-pre-decision-pad-threshold",
type=float,
default=0.3,
help="If a part of the sequence has pad"
",the threshold the pooled part is a pad.",
)
def insert_zeros(self, x):
bsz_num_heads, tgt_len, src_len = x.size()
stride = self.pre_decision_ratio
weight = F.pad(torch.ones(1, 1, 1).to(x), (stride - 1, 0))
x_upsample = F.conv_transpose1d(
x.view(-1, src_len).unsqueeze(1),
weight,
stride=stride,
padding=0,
)
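# Illustrative effect (not in the original code): with stride == 2 the
# left-padded unit kernel places each pooled probability at the end of its
# window and fills the rest with zeros, e.g. [p1, p2, p3] -> [0, p1, 0, p2, 0, p3].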
return x_upsample.squeeze(1).view(bsz_num_heads, tgt_len, -1)
def p_choose_waitk(
self, query, key, key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None
):
"""
query: bsz, tgt_len
key: bsz, src_len
key_padding_mask: bsz, src_len
"""
if incremental_state is not None:
# Retrieve target length from incremental states
# For inference the length of query is always 1
tgt = incremental_state["steps"]["tgt"]
assert tgt is not None
tgt_len = int(tgt)
else:
tgt_len, bsz, _ = query.size()
src_len, bsz, _ = key.size()
p_choose = torch.ones(bsz, tgt_len, src_len).to(query)
p_choose = torch.tril(p_choose, diagonal=self.waitk_lagging - 1)
p_choose = torch.triu(p_choose, diagonal=self.waitk_lagging - 1)
if incremental_state is not None:
p_choose = p_choose[:, -1:]
tgt_len = 1
# Extend to each head
p_choose = (
p_choose.contiguous()
.unsqueeze(1)
.expand(-1, self.num_heads, -1, -1)
.contiguous()
.view(-1, tgt_len, src_len)
)
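# Illustrative example: with waitk_lagging == 3, tgt_len == 2, src_len == 5,
# the tril/triu pair above keeps exactly the (k-1)-th diagonal:
#   [[0, 0, 1, 0, 0],
#    [0, 0, 0, 1, 0]]
# i.e. read k source positions before the first write, then one more per write.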
return p_choose
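# Overall flow of p_choose below: pool the key sequence by pre_decision_ratio,
# run the underlying policy (wait-k or hard-aligned / infinite-lookback) on the
# pooled keys, then upsample the decisions back to the original source
# resolution with interleaved zeros, padding or trimming to src_len.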
def p_choose(
self,
query: Optional[Tensor],
key: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
):
assert key is not None
assert query is not None
src_len = key.size(0)
tgt_len = query.size(0)
batch_size = query.size(1)
if self.pre_decision_ratio == 1:
if self.strategy == "waitk":
return p_choose_strategy.waitk(
query,
key,
self.waitk_lagging,
self.num_heads,
key_padding_mask,
incremental_state=incremental_state,
)
else: # hard_aligned or infinite_lookback
q_proj, k_proj, _ = self.input_projections(query, key, None, "monotonic")
attn_energy = self.attn_energy(q_proj, k_proj, key_padding_mask)
return p_choose_strategy.hard_aligned(
q_proj,
k_proj,
attn_energy,
self.noise_mean,
self.noise_var,
self.training
)
key_pool = self.pooling_layer(key.transpose(0, 2)).transpose(0, 2)
if key_padding_mask is not None:
key_padding_mask_pool = (
self.pooling_layer(key_padding_mask.unsqueeze(0).float())
.squeeze(0)
.gt(self.pre_decision_pad_threshold)
)
# Make sure at least one element is not pad
key_padding_mask_pool[:, 0] = 0
else:
key_padding_mask_pool = None
if incremental_state is not None:
# The floor instead of ceil is used for inference
# But make sure the length of key_pool is at least 1
if (
max(1, math.floor(key.size(0) / self.pre_decision_ratio))
) < key_pool.size(0):
key_pool = key_pool[:-1]
if key_padding_mask_pool is not None:
key_padding_mask_pool = key_padding_mask_pool[:-1]
p_choose_pooled = self.p_choose_waitk(
query,
key_pool,
key_padding_mask_pool,
incremental_state=incremental_state,
)
# Upsample, interpolate zeros
p_choose = self.insert_zeros(p_choose_pooled)
if p_choose.size(-1) < src_len:
# Append zeros if the upsampled p_choose is shorter than src_len
p_choose = torch.cat(
[
p_choose,
torch.zeros(
p_choose.size(0),
tgt_len,
src_len - p_choose.size(-1)
).to(p_choose)
],
dim=2
)
else:
# The upsampled p_choose can be longer than src_len because ceil_mode was used in pooling
p_choose = p_choose[:, :, :src_len]
p_choose[:, :, -1] = p_choose_pooled[:, :, -1]
assert list(p_choose.size()) == [
batch_size * self.num_heads,
tgt_len,
src_len,
]
return p_choose
FixedStrideMonotonicAttention.__name__ = klass.__name__
return FixedStrideMonotonicAttention
return partial(create_model, monotonic_attention)
@register_monotonic_attention("waitk_fixed_pre_decision")
@fixed_pooling_monotonic_attention(MonotonicMultiheadAttentionWaitK)
class MonotonicMultiheadAttentionWaitkFixedStride:
pass
@register_monotonic_attention("hard_aligned_fixed_pre_decision")
@fixed_pooling_monotonic_attention(MonotonicMultiheadAttentionHardAligned)
class MonotonicMultiheadAttentionHardFixedStride:
pass
@register_monotonic_attention("infinite_lookback_fixed_pre_decision")
@fixed_pooling_monotonic_attention(MonotonicMultiheadAttentionInfiniteLookback)
class MonotonicMultiheadAttentionInfiniteLookbackFixedStride:
pass
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from torch import Tensor
import torch.nn as nn
from examples.simultaneous_translation.utils.functions import (
exclusive_cumprod,
lengths_to_mask,
)
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules import MultiheadAttention
from . import register_monotonic_attention
from typing import Dict, Optional
from examples.simultaneous_translation.utils import p_choose_strategy
@with_incremental_state
class MonotonicAttention(nn.Module):
"""
Abstract class of monotonic attentions
"""
def __init__(self, args):
self.eps = args.attention_eps
self.mass_preservation = args.mass_preservation
self.noise_type = args.noise_type
self.noise_mean = args.noise_mean
self.noise_var = args.noise_var
self.energy_bias_init = args.energy_bias_init
self.energy_bias = (
nn.Parameter(self.energy_bias_init * torch.ones([1]))
if args.energy_bias is True
else 0
)
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--no-mass-preservation', action="store_false",
dest="mass_preservation",
help='Do not stay on the last token when decoding')
parser.add_argument('--mass-preservation', action="store_true",
dest="mass_preservation",
help='Stay on the last token when decoding')
parser.set_defaults(mass_preservation=True)
parser.add_argument('--noise-var', type=float, default=1.0,
help='Variance of discreteness noise')
parser.add_argument('--noise-mean', type=float, default=0.0,
help='Mean of discreteness noise')
parser.add_argument('--noise-type', type=str, default="flat",
help='Type of discreteness noise')
parser.add_argument('--energy-bias', action="store_true",
default=False,
help='Bias for energy')
parser.add_argument('--energy-bias-init', type=float, default=-2.0,
help='Initial value of the bias for energy')
parser.add_argument('--attention-eps', type=float, default=1e-6,
help='Epsilon when calculating expected attention')
def p_choose(self, *args):
raise NotImplementedError
def input_projections(self, *args):
raise NotImplementedError
def attn_energy(
self, q_proj, k_proj, key_padding_mask=None, attn_mask=None
):
"""
Calculating monotonic energies
============================================================
Expected input size
q_proj: bsz * num_heads, tgt_len, self.head_dim
k_proj: bsz * num_heads, src_len, self.head_dim
key_padding_mask: bsz, src_len
attn_mask: tgt_len, src_len
"""
bsz, tgt_len, embed_dim = q_proj.size()
bsz = bsz // self.num_heads
src_len = k_proj.size(1)
attn_energy = (
torch.bmm(q_proj, k_proj.transpose(1, 2)) + self.energy_bias
)
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
attn_energy += attn_mask
attn_energy = attn_energy.view(bsz, self.num_heads, tgt_len, src_len)
if key_padding_mask is not None:
attn_energy = attn_energy.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
return attn_energy
def expected_alignment_train(self, p_choose, key_padding_mask: Optional[Tensor]):
"""
Calculating expected alignment for MMA
Mask is not needed because p_choose will be 0 if masked.
q_{i,j} = (1 - p_{i,j-1}) * q_{i,j-1} + alpha_{i-1,j}
alpha_{i,j} = p_{i,j} * q_{i,j}
Parallel solution:
alpha_i = p_i * cumprod(1 - p_i) * cumsum(alpha_{i-1} / cumprod(1 - p_i))
============================================================
Expected input size
p_choose: bsz * num_heads, tgt_len, src_len
"""
# p_choose: bsz * num_heads, tgt_len, src_len
bsz_num_heads, tgt_len, src_len = p_choose.size()
# cumprod_1mp : bsz * num_heads, tgt_len, src_len
cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=self.eps)
cumprod_1mp_clamp = torch.clamp(cumprod_1mp, self.eps, 1.0)
init_attention = p_choose.new_zeros([bsz_num_heads, 1, src_len])
init_attention[:, :, 0] = 1.0
previous_attn = [init_attention]
for i in range(tgt_len):
# p_choose: bsz * num_heads, tgt_len, src_len
# cumprod_1mp_clamp : bsz * num_heads, tgt_len, src_len
# previous_attn[i]: bsz * num_heads, 1, src_len
# alpha_i: bsz * num_heads, src_len
alpha_i = (
p_choose[:, i]
* cumprod_1mp[:, i]
* torch.cumsum(previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1)
).clamp(0, 1.0)
previous_attn.append(alpha_i.unsqueeze(1))
# alpha: bsz * num_heads, tgt_len, src_len
alpha = torch.cat(previous_attn[1:], dim=1)
if self.mass_preservation:
# Last token has the residual probabilities
if key_padding_mask is not None and key_padding_mask[:, -1].any():
# right padding
batch_size = key_padding_mask.size(0)
residuals = 1 - alpha.sum(dim=-1, keepdim=True).clamp(0.0, 1.0)
src_lens = src_len - key_padding_mask.sum(dim=1, keepdim=True)
src_lens = src_lens.expand(
batch_size, self.num_heads
).contiguous().view(-1, 1)
src_lens = src_lens.expand(-1, tgt_len).contiguous()
# add back the last value
residuals += alpha.gather(2, src_lens.unsqueeze(-1) - 1)
alpha = alpha.scatter(2, src_lens.unsqueeze(-1) - 1, residuals)
else:
residuals = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0)
alpha[:, :, -1] = residuals
if torch.isnan(alpha).any():
# Something is wrong
raise RuntimeError("NaN in alpha.")
return alpha
def expected_alignment_infer(
self, p_choose, encoder_padding_mask: Optional[Tensor], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
):
# TODO modify this function
"""
Calculating monotonic alignment for MMA at inference time
============================================================
Expected input size
p_choose: bsz * num_heads, tgt_len, src_len
incremental_state: dict
encoder_padding_mask: bsz, src_len
"""
# p_choose: bsz * self.num_heads, src_len
bsz_num_heads, tgt_len, src_len = p_choose.size()
# One token at a time
assert tgt_len == 1
p_choose = p_choose[:, 0, :]
monotonic_cache = self._get_monotonic_buffer(incremental_state)
# prev_monotonic_step: bsz, num_heads
bsz = bsz_num_heads // self.num_heads
prev_monotonic_step = monotonic_cache.get(
"head_step",
p_choose.new_zeros([bsz, self.num_heads]).long()
)
assert prev_monotonic_step is not None
bsz, num_heads = prev_monotonic_step.size()
assert num_heads == self.num_heads
assert bsz * num_heads == bsz_num_heads
# p_choose: bsz, num_heads, src_len
p_choose = p_choose.view(bsz, num_heads, src_len)
if encoder_padding_mask is not None:
src_lengths = src_len - \
encoder_padding_mask.sum(dim=1, keepdim=True).long()
else:
src_lengths = prev_monotonic_step.new_ones(bsz, 1) * src_len
# src_lengths: bsz, num_heads
src_lengths = src_lengths.expand_as(prev_monotonic_step)
# new_monotonic_step: bsz, num_heads
new_monotonic_step = prev_monotonic_step
step_offset = 0
if encoder_padding_mask is not None:
if encoder_padding_mask[:, 0].any():
# left_pad_source = True:
step_offset = encoder_padding_mask.sum(dim=-1, keepdim=True)
max_steps = src_lengths - 1 if self.mass_preservation else src_lengths
# finish_read: bsz, num_heads
finish_read = new_monotonic_step.eq(max_steps)
p_choose_i = 1
while finish_read.sum().item() < bsz * self.num_heads:
# p_choose: bsz * self.num_heads, src_len
# only choose the p at monotonic steps
# p_choose_i: bsz , self.num_heads
p_choose_i = (
p_choose.gather(
2,
(step_offset + new_monotonic_step)
.unsqueeze(2)
.clamp(0, src_len - 1),
)
).squeeze(2)
action = (
(p_choose_i < 0.5)
.type_as(prev_monotonic_step)
.masked_fill(finish_read, 0)
)
# 1 x bsz
# sample actions on unfinished seq
# 1 means stay, finish reading
# 0 means leave, continue reading
# dist = torch.distributions.bernoulli.Bernoulli(p_choose)
# action = dist.sample().type_as(finish_read) * (1 - finish_read)
new_monotonic_step += action
finish_read = new_monotonic_step.eq(max_steps) | (action == 0)
monotonic_cache["head_step"] = new_monotonic_step
# Whether a head is looking for new input
monotonic_cache["head_read"] = (
new_monotonic_step.eq(max_steps) & (p_choose_i < 0.5)
)
# alpha: bsz * num_heads, 1, src_len
# new_monotonic_step: bsz, num_heads
alpha = (
p_choose
.new_zeros([bsz * self.num_heads, src_len])
.scatter(
1,
(step_offset + new_monotonic_step)
.view(bsz * self.num_heads, 1).clamp(0, src_len - 1),
1
)
)
if not self.mass_preservation:
alpha = alpha.masked_fill(
(new_monotonic_step == max_steps)
.view(bsz * self.num_heads, 1),
0
)
alpha = alpha.unsqueeze(1)
self._set_monotonic_buffer(incremental_state, monotonic_cache)
return alpha
def _get_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]):
return self.get_incremental_state(
incremental_state,
'monotonic',
) or {}
def _set_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], buffer: Dict[str, Optional[Tensor]]):
self.set_incremental_state(
incremental_state,
'monotonic',
buffer,
)
def v_proj_output(self, value):
raise NotImplementedError
def forward(
self, query, key, value,
key_padding_mask=None, attn_mask=None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights=True, static_kv=False
):
tgt_len, bsz, embed_dim = query.size()
src_len = value.size(0)
# stepwise prob
# p_choose: bsz * self.num_heads, tgt_len, src_len
p_choose = self.p_choose(
query, key, key_padding_mask, incremental_state,
)
# expected alignment alpha
# bsz * self.num_heads, tgt_len, src_len
if incremental_state is not None:
alpha = self.expected_alignment_infer(
p_choose, key_padding_mask, incremental_state)
else:
alpha = self.expected_alignment_train(
p_choose, key_padding_mask)
# expected attention beta
# bsz * self.num_heads, tgt_len, src_len
beta = self.expected_attention(
alpha, query, key, value,
key_padding_mask, attn_mask,
incremental_state
)
attn_weights = beta
v_proj = self.v_proj_output(value)
attn = torch.bmm(attn_weights.type_as(v_proj), v_proj)
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
beta = beta.view(bsz, self.num_heads, tgt_len, src_len)
alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len)
p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len)
return attn, {
"alpha": alpha,
"beta": beta,
"p_choose": p_choose,
}
@register_monotonic_attention("hard_aligned")
class MonotonicMultiheadAttentionHardAligned(
MonotonicAttention, MultiheadAttention
):
def __init__(self, args):
MultiheadAttention.__init__(
self,
embed_dim=args.decoder_embed_dim,
num_heads=args.decoder_attention_heads,
kdim=getattr(args, "encoder_embed_dim", None),
vdim=getattr(args, "encoder_embed_dim", None),
dropout=args.attention_dropout,
encoder_decoder_attention=True,
)
MonotonicAttention.__init__(self, args)
self.k_in_proj = {"monotonic": self.k_proj}
self.q_in_proj = {"monotonic": self.q_proj}
self.v_in_proj = {"output": self.v_proj}
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--no-mass-preservation', action="store_false",
dest="mass_preservation",
help='Do not stay on the last token when decoding')
parser.add_argument('--mass-preservation', action="store_true",
dest="mass_preservation",
help='Stay on the last token when decoding')
parser.set_defaults(mass_preservation=True)
parser.add_argument('--noise-var', type=float, default=1.0,
help='Variance of discreteness noise')
parser.add_argument('--noise-mean', type=float, default=0.0,
help='Mean of discreteness noise')
parser.add_argument('--noise-type', type=str, default="flat",
help='Type of discreteness noise')
parser.add_argument('--energy-bias', action="store_true",
default=False,
help='Bias for energy')
parser.add_argument('--energy-bias-init', type=float, default=-2.0,
help='Initial value of the bias for energy')
parser.add_argument('--attention-eps', type=float, default=1e-6,
help='Epsilon when calculating expected attention')
def attn_energy(
self, q_proj: Optional[Tensor], k_proj: Optional[Tensor], key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None
):
"""
Calculating monotonic energies
============================================================
Expected input size
q_proj: bsz * num_heads, tgt_len, self.head_dim
k_proj: bsz * num_heads, src_len, self.head_dim
key_padding_mask: bsz, src_len
attn_mask: tgt_len, src_len
"""
assert q_proj is not None # Optional[Tensor] annotations in the signature above are to make the JIT compiler happy
assert k_proj is not None
bsz, tgt_len, embed_dim = q_proj.size()
bsz = bsz // self.num_heads
src_len = k_proj.size(1)
attn_energy = (
torch.bmm(q_proj, k_proj.transpose(1, 2)) + self.energy_bias
)
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
attn_energy += attn_mask
attn_energy = attn_energy.view(bsz, self.num_heads, tgt_len, src_len)
if key_padding_mask is not None:
attn_energy = attn_energy.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
return attn_energy
def expected_alignment_train(self, p_choose, key_padding_mask: Optional[Tensor]):
"""
Calculating expected alignment for MMA
Mask is not needed because p_choose will be 0 if masked.
q_{i,j} = (1 - p_{i,j-1}) * q_{i,j-1} + alpha_{i-1,j}
alpha_{i,j} = p_{i,j} * q_{i,j}
Parallel solution:
alpha_i = p_i * cumprod(1 - p_i) * cumsum(alpha_{i-1} / cumprod(1 - p_i))
============================================================
Expected input size
p_choose: bsz * num_heads, tgt_len, src_len
"""
# p_choose: bsz * num_heads, tgt_len, src_len
bsz_num_heads, tgt_len, src_len = p_choose.size()
# cumprod_1mp : bsz * num_heads, tgt_len, src_len
cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=self.eps)
cumprod_1mp_clamp = torch.clamp(cumprod_1mp, self.eps, 1.0)
init_attention = p_choose.new_zeros([bsz_num_heads, 1, src_len])
init_attention[:, :, 0] = 1.0
previous_attn = [init_attention]
for i in range(tgt_len):
# p_choose: bsz * num_heads, tgt_len, src_len
# cumprod_1mp_clamp : bsz * num_heads, tgt_len, src_len
# previous_attn[i]: bsz * num_heads, 1, src_len
# alpha_i: bsz * num_heads, src_len
alpha_i = (
p_choose[:, i]
* cumprod_1mp[:, i]
* torch.cumsum(previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1)
).clamp(0, 1.0)
previous_attn.append(alpha_i.unsqueeze(1))
# alpha: bsz * num_heads, tgt_len, src_len
alpha = torch.cat(previous_attn[1:], dim=1)
if self.mass_preservation:
# Last token has the residual probabilities
if key_padding_mask is not None and key_padding_mask[:, -1].any():
# right padding
batch_size = key_padding_mask.size(0)
residuals = 1 - alpha.sum(dim=-1, keepdim=True).clamp(0.0, 1.0)
src_lens = src_len - key_padding_mask.sum(dim=1, keepdim=True)
src_lens = src_lens.expand(
batch_size, self.num_heads
).contiguous().view(-1, 1)
src_lens = src_lens.expand(-1, tgt_len).contiguous()
# add back the last value
residuals += alpha.gather(2, src_lens.unsqueeze(-1) - 1)
alpha = alpha.scatter(2, src_lens.unsqueeze(-1) - 1, residuals)
else:
residuals = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0)
alpha[:, :, -1] = residuals
if torch.isnan(alpha).any():
# Something is wrong
raise RuntimeError("NaN in alpha.")
return alpha
def expected_alignment_infer(
self, p_choose, encoder_padding_mask: Optional[Tensor], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
):
# TODO modify this function
"""
Calculating monotonic alignment for MMA at inference time
============================================================
Expected input size
p_choose: bsz * num_heads, tgt_len, src_len
incremental_state: dict
encoder_padding_mask: bsz, src_len
"""
# p_choose: bsz * self.num_heads, src_len
bsz_num_heads, tgt_len, src_len = p_choose.size()
# One token at a time
assert tgt_len == 1
p_choose = p_choose[:, 0, :]
monotonic_cache = self._get_monotonic_buffer(incremental_state)
# prev_monotonic_step: bsz, num_heads
bsz = bsz_num_heads // self.num_heads
prev_monotonic_step = monotonic_cache.get(
"head_step",
p_choose.new_zeros([bsz, self.num_heads]).long()
)
assert prev_monotonic_step is not None
bsz, num_heads = prev_monotonic_step.size()
assert num_heads == self.num_heads
assert bsz * num_heads == bsz_num_heads
# p_choose: bsz, num_heads, src_len
p_choose = p_choose.view(bsz, num_heads, src_len)
if encoder_padding_mask is not None:
src_lengths = src_len - \
encoder_padding_mask.sum(dim=1, keepdim=True).long()
else:
src_lengths = torch.ones(bsz, 1).to(prev_monotonic_step) * src_len
# src_lengths: bsz, num_heads
src_lengths = src_lengths.expand_as(prev_monotonic_step)
# new_monotonic_step: bsz, num_heads
new_monotonic_step = prev_monotonic_step
step_offset = torch.tensor(0)
if encoder_padding_mask is not None:
if encoder_padding_mask[:, 0].any():
# left_pad_source = True:
step_offset = encoder_padding_mask.sum(dim=-1, keepdim=True)
max_steps = src_lengths - 1 if self.mass_preservation else src_lengths
# finish_read: bsz, num_heads
finish_read = new_monotonic_step.eq(max_steps)
p_choose_i = torch.tensor(1)
while finish_read.sum().item() < bsz * self.num_heads:
# p_choose: bsz * self.num_heads, src_len
# only choose the p at monotonic steps
# p_choose_i: bsz , self.num_heads
p_choose_i = (
p_choose.gather(
2,
(step_offset + new_monotonic_step)
.unsqueeze(2)
.clamp(0, src_len - 1),
)
).squeeze(2)
action = (
(p_choose_i < 0.5)
.type_as(prev_monotonic_step)
.masked_fill(finish_read, 0)
)
# 1 x bsz
# sample actions on unfinished seq
# 1 means stay, finish reading
# 0 means leave, continue reading
# dist = torch.distributions.bernoulli.Bernoulli(p_choose)
# action = dist.sample().type_as(finish_read) * (1 - finish_read)
new_monotonic_step += action
finish_read = new_monotonic_step.eq(max_steps) | (action == 0)
monotonic_cache["head_step"] = new_monotonic_step
# Whether a head is looking for new input
monotonic_cache["head_read"] = (
new_monotonic_step.eq(max_steps) & (p_choose_i < 0.5)
)
# alpha: bsz * num_heads, 1, src_len
# new_monotonic_step: bsz, num_heads
alpha = (
p_choose
.new_zeros([bsz * self.num_heads, src_len])
.scatter(
1,
(step_offset + new_monotonic_step)
.view(bsz * self.num_heads, 1).clamp(0, src_len - 1),
1
)
)
if not self.mass_preservation:
alpha = alpha.masked_fill(
(new_monotonic_step == max_steps)
.view(bsz * self.num_heads, 1),
0
)
alpha = alpha.unsqueeze(1)
self._set_monotonic_buffer(incremental_state, monotonic_cache)
return alpha
def _get_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]):
maybe_incremental_state = self.get_incremental_state(
incremental_state,
'monotonic',
)
if maybe_incremental_state is None:
typed_empty_dict: Dict[str, Optional[Tensor]] = {}
return typed_empty_dict
else:
return maybe_incremental_state
def _set_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], buffer: Dict[str, Optional[Tensor]]):
self.set_incremental_state(
incremental_state,
'monotonic',
buffer,
)
def forward(
self, query: Optional[Tensor], key: Optional[Tensor], value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = True, static_kv: bool = False, need_head_weights: bool = False,
):
assert query is not None
assert value is not None
tgt_len, bsz, embed_dim = query.size()
src_len = value.size(0)
# stepwise prob
# p_choose: bsz * self.num_heads, tgt_len, src_len
p_choose = self.p_choose(
query, key, key_padding_mask, incremental_state,
)
# expected alignment alpha
# bsz * self.num_heads, tgt_len, src_len
if incremental_state is not None:
alpha = self.expected_alignment_infer(
p_choose, key_padding_mask, incremental_state)
else:
alpha = self.expected_alignment_train(
p_choose, key_padding_mask)
# expected attention beta
# bsz * self.num_heads, tgt_len, src_len
beta = self.expected_attention(
alpha, query, key, value,
key_padding_mask, attn_mask,
incremental_state
)
attn_weights = beta
v_proj = self.v_proj_output(value)
assert v_proj is not None
attn = torch.bmm(attn_weights.type_as(v_proj), v_proj)
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
beta = beta.view(bsz, self.num_heads, tgt_len, src_len)
alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len)
p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len)
return attn, {
"alpha": alpha,
"beta": beta,
"p_choose": p_choose,
}
def input_projections(self, query: Optional[Tensor], key: Optional[Tensor], value: Optional[Tensor], name: str):
"""
Prepare inputs for multihead attention
============================================================
Expected input size
query: tgt_len, bsz, embed_dim
key: src_len, bsz, embed_dim
value: src_len, bsz, embed_dim
name: monotonic or soft
"""
if query is not None:
bsz = query.size(1)
q = self.q_proj(query)
q *= self.scaling
q = q.contiguous().view(
-1, bsz * self.num_heads, self.head_dim
).transpose(0, 1)
else:
q = None
if key is not None:
bsz = key.size(1)
k = self.k_proj(key)
k = k.contiguous().view(
-1, bsz * self.num_heads, self.head_dim
).transpose(0, 1)
else:
k = None
if value is not None:
bsz = value.size(1)
v = self.v_proj(value)
v = v.contiguous().view(
-1, bsz * self.num_heads, self.head_dim
).transpose(0, 1)
else:
v = None
return q, k, v
def p_choose(
self, query: Optional[Tensor], key: Optional[Tensor], key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None
):
"""
Calculating step wise prob for reading and writing
1 to read, 0 to write
============================================================
Expected input size
query: bsz, tgt_len, embed_dim
key: bsz, src_len, embed_dim
value: bsz, src_len, embed_dim
key_padding_mask: bsz, src_len
attn_mask: bsz, src_len
query: bsz, tgt_len, embed_dim
"""
# prepare inputs
q_proj, k_proj, _ = self.input_projections(
query, key, None, "monotonic"
)
# attention energy
attn_energy = self.attn_energy(q_proj, k_proj, key_padding_mask)
return p_choose_strategy.hard_aligned(q_proj, k_proj, attn_energy, self.noise_mean, self.noise_var, self.training)
def expected_attention(self, alpha, *args):
"""
For MMA-H, beta = alpha
"""
return alpha
def v_proj_output(self, value):
_, _, v_proj = self.input_projections(None, None, value, "output")
return v_proj
@register_monotonic_attention("infinite_lookback")
class MonotonicMultiheadAttentionInfiniteLookback(
MonotonicMultiheadAttentionHardAligned
):
def __init__(self, args):
super().__init__(args)
self.init_soft_attention()
def init_soft_attention(self):
self.k_proj_soft = nn.Linear(self.kdim, self.embed_dim, bias=True)
self.q_proj_soft = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
self.k_in_proj["soft"] = self.k_proj_soft
self.q_in_proj["soft"] = self.q_proj_soft
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.init.xavier_uniform_(
self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2)
)
nn.init.xavier_uniform_(
self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2)
)
else:
nn.init.xavier_uniform_(self.k_in_proj["soft"].weight)
nn.init.xavier_uniform_(self.q_in_proj["soft"].weight)
def expected_attention(
self, alpha, query: Optional[Tensor], key: Optional[Tensor], value: Optional[Tensor],
key_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
):
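# Sketch of the intended formula (MILk, infinite lookback): with soft energies
# u_ij and monotonic alignment alpha,
#   beta_ij = sum_{k >= j} alpha_ik * exp(u_ij) / sum_{l <= k} exp(u_il)
# which the cumsum / flipped-cumsum below computes in the training branch.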
# monotonic attention, we will calculate milk here
bsz_x_num_heads, tgt_len, src_len = alpha.size()
bsz = int(bsz_x_num_heads / self.num_heads)
q, k, _ = self.input_projections(query, key, None, "soft")
soft_energy = self.attn_energy(q, k, key_padding_mask, attn_mask)
assert list(soft_energy.size()) == \
[bsz, self.num_heads, tgt_len, src_len]
soft_energy = soft_energy.view(bsz * self.num_heads, tgt_len, src_len)
if incremental_state is not None:
monotonic_cache = self._get_monotonic_buffer(incremental_state)
head_step = monotonic_cache["head_step"]
assert head_step is not None
monotonic_length = head_step + 1
step_offset = 0
if key_padding_mask is not None:
if key_padding_mask[:, 0].any():
# left_pad_source = True:
step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
monotonic_length += step_offset
mask = lengths_to_mask(
monotonic_length.view(-1),
soft_energy.size(2), 1
).unsqueeze(1)
soft_energy = soft_energy.masked_fill(~mask.to(torch.bool), float("-inf"))
soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
exp_soft_energy = torch.exp(soft_energy)
exp_soft_energy_sum = exp_soft_energy.sum(dim=2)
beta = exp_soft_energy / exp_soft_energy_sum.unsqueeze(2)
else:
soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
exp_soft_energy = torch.exp(soft_energy) + self.eps
inner_items = alpha / (torch.cumsum(exp_soft_energy, dim=2))
beta = (
exp_soft_energy
* torch.cumsum(inner_items.flip(dims=[2]), dim=2)
.flip(dims=[2])
)
beta = beta.view(bsz, self.num_heads, tgt_len, src_len)
if key_padding_mask is not None:
beta = beta.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), 0)
beta = beta / beta.sum(dim=3, keepdim=True)
beta = beta.view(bsz * self.num_heads, tgt_len, src_len)
beta = self.dropout_module(beta)
if torch.isnan(beta).any():
# Something is wrong
raise RuntimeError("NaN in beta.")
return beta
@register_monotonic_attention("waitk")
class MonotonicMultiheadAttentionWaitK(
MonotonicMultiheadAttentionInfiniteLookback
):
def __init__(self, args):
super().__init__(args)
self.q_in_proj["soft"] = self.q_in_proj["monotonic"]
self.k_in_proj["soft"] = self.k_in_proj["monotonic"]
self.waitk_lagging = args.waitk_lagging
assert self.waitk_lagging > 0, (
f"Lagging has to been larger than 0, get {self.waitk_lagging}."
)
@staticmethod
def add_args(parser):
super(
MonotonicMultiheadAttentionWaitK,
MonotonicMultiheadAttentionWaitK,
).add_args(parser)
parser.add_argument(
"--waitk-lagging", type=int, required=True, help="Wait K lagging"
)
def p_choose(
self, query: Optional[Tensor], key: Optional[Tensor], key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
):
"""
query: bsz, tgt_len
key: bsz, src_len
key_padding_mask: bsz, src_len
"""
return p_choose_strategy.waitk(query, key, self.waitk_lagging, self.num_heads, key_padding_mask, incremental_state)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.modules import LayerNorm, TransformerDecoderLayer, TransformerEncoderLayer
from . import build_monotonic_attention
from typing import Dict, List, Optional
import torch
from torch import Tensor
class TransformerMonotonicEncoderLayer(TransformerEncoderLayer):
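# The only change over TransformerEncoderLayer is the self-attention mask: an
# upper-triangular -inf mask restricts each position to itself and earlier
# positions, making the encoder unidirectional for streaming input.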
def forward(self, x, encoder_padding_mask):
seq_len, _, _ = x.size()
attn_mask = x.new_ones([seq_len, seq_len]).triu(1)
attn_mask = attn_mask.masked_fill(attn_mask.bool(), float("-inf"))
return super().forward(x, encoder_padding_mask, attn_mask)
class TransformerMonotonicDecoderLayer(TransformerDecoderLayer):
def __init__(
self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
):
super().__init__(
args,
no_encoder_attn=True,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
)
assert args.simul_type is not None, "A --simul-type is needed."
self.encoder_attn = build_monotonic_attention(args)
self.encoder_attn_layer_norm = LayerNorm(
self.embed_dim, export=getattr(args, "char_inputs", False)
)
def get_head_steps(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]):
return self.encoder_attn._get_monotonic_buffer(incremental_state).get(
"head_step"
)
def prune_incremental_state(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]):
input_buffer = self.self_attn._get_input_buffer(incremental_state)
for key in ["prev_key", "prev_value"]:
input_buffer_key = input_buffer[key]
assert input_buffer_key is not None
if input_buffer_key.size(2) > 1:
input_buffer[key] = input_buffer_key[:, :, :-1, :]
else:
typed_empty_dict: Dict[str, Optional[Tensor]] = {}
input_buffer = typed_empty_dict
break
assert incremental_state is not None
self.self_attn._set_input_buffer(incremental_state, input_buffer)
def get_steps(self, incremental_state):
return self.encoder_attn._get_monotonic_buffer(incremental_state).get("step", 0)
def forward(
self,
x,
encoder_out: Optional[torch.Tensor] = None,
encoder_padding_mask: Optional[torch.Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
prev_self_attn_state: Optional[List[torch.Tensor]] = None,
prev_attn_state: Optional[List[torch.Tensor]] = None,
self_attn_mask: Optional[torch.Tensor] = None,
self_attn_padding_mask: Optional[torch.Tensor] = None,
need_attn: bool = False,
need_head_weights: bool = False,
):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor, optional): binary
ByteTensor of shape `(batch, src_len)` where padding
elements are indicated by ``1``.
need_attn (bool, optional): return attention weights
need_head_weights (bool, optional): return attention weights
for each head (default: return average over heads).
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
if need_head_weights:
need_attn = True
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
if prev_self_attn_state is not None:
prev_key, prev_value = prev_self_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_self_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
assert incremental_state is not None
self.self_attn._set_input_buffer(incremental_state, saved_state)
_self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
if self.cross_self_attention and not (
incremental_state is not None
and _self_attn_input_buffer is not None
and "prev_key" in _self_attn_input_buffer
):
if self_attn_mask is not None:
assert encoder_out is not None
self_attn_mask = torch.cat(
(x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
)
if self_attn_padding_mask is not None:
if encoder_padding_mask is None:
assert encoder_out is not None
encoder_padding_mask = self_attn_padding_mask.new_zeros(
encoder_out.size(1), encoder_out.size(0)
)
self_attn_padding_mask = torch.cat(
(encoder_padding_mask, self_attn_padding_mask), dim=1
)
assert encoder_out is not None
y = torch.cat((encoder_out, x), dim=0)
else:
y = x
x, attn = self.self_attn(
query=x,
key=y,
value=y,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
attn_mask=self_attn_mask,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
assert self.encoder_attn is not None
residual = x
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
if prev_attn_state is not None:
prev_key, prev_value = prev_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
assert incremental_state is not None
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
if self.onnx_trace and incremental_state is not None:
saved_state = self.self_attn._get_input_buffer(incremental_state)
assert saved_state is not None
if self_attn_padding_mask is not None:
self_attn_state = [
saved_state["prev_key"],
saved_state["prev_value"],
saved_state["prev_key_padding_mask"],
]
else:
self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
return x, attn, self_attn_state
return x, attn, None
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
# automatically import any Python files in the utils/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
if file.endswith(".py") and not file.startswith("_"):
module = file[: file.find(".py")]
importlib.import_module("examples.simultaneous_translation.utils." + module)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
def calc_mean_invstddev(feature):
if len(feature.size()) != 2:
raise ValueError("We expect the input feature to be 2-D tensor")
mean = feature.mean(0)
var = feature.var(0)
# avoid division by ~zero
eps = 1e-8
if (var < eps).any():
return mean, 1.0 / (torch.sqrt(var) + eps)
return mean, 1.0 / torch.sqrt(var)
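# Illustrative usage of apply_mv_norm below: for a (T, F) feature tensor,
#   feats = apply_mv_norm(feats)
# standardizes each feature dimension to zero mean / unit variance over time.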
def apply_mv_norm(features):
# If there are fewer than 2 frames, the variance cannot be computed (it is NaN)
# and normalization is not possible, so return the input as it is
if features.size(0) < 2:
return features
mean, invstddev = calc_mean_invstddev(features)
res = (features - mean) * invstddev
return res
def lengths_to_encoder_padding_mask(lengths, batch_first: bool = False):
"""
convert lengths (a 1-D Long/Int tensor) to 2-D binary tensor
Args:
lengths: a (B, )-shaped tensor
Return:
max_length: maximum length of B sequences
encoder_padding_mask: a (max_length, B) binary mask, where
[t, b] = 0 for t < lengths[b] and 1 otherwise
TODO:
kernelize this function if benchmarking shows this function is slow
"""
max_lengths = torch.max(lengths).item()
bsz = lengths.size(0)
encoder_padding_mask = (
    torch.arange(max_lengths)  # a (T,) tensor with [0, ..., T-1]
    .to(lengths.device)  # move to the right device
    .view(1, max_lengths)  # reshape to a (1, T) tensor
    .expand(bsz, -1)  # expand to a (B, T) tensor
    >= lengths.view(bsz, 1).expand(-1, max_lengths)
)
if not batch_first:
return encoder_padding_mask.t(), max_lengths
else:
return encoder_padding_mask, max_lengths
def encoder_padding_mask_to_lengths(
encoder_padding_mask, max_lengths, batch_size, device
):
"""
convert encoder_padding_mask (2-D binary tensor) to a 1-D tensor
Conventionally, the encoder output contains an encoder_padding_mask, which is
a 2-D mask of shape (T, B) whose (t, b) element indicates whether
encoder_out[t, b] is a valid output (=0) or not (=1). Occasionally, we
need to convert this mask tensor to a 1-D tensor of shape (B, ), where
[b] denotes the valid length of the b-th sequence
Args:
encoder_padding_mask: a (T, B)-shaped binary tensor or None; if None,
indicating all are valid
Return:
seq_lengths: a (B,)-shaped tensor, where its (b, )-th element is the
number of valid elements of b-th sequence
max_lengths: maximum length over all sequences; if encoder_padding_mask is
not None, max_lengths must equal encoder_padding_mask.size(0)
batch_size: batch size; if encoder_padding_mask is
not None, batch_size must equal encoder_padding_mask.size(1)
device: which device to put the result on
"""
if encoder_padding_mask is None:
return torch.Tensor([max_lengths] * batch_size).to(torch.int32).to(device)
assert encoder_padding_mask.size(0) == max_lengths, "max_lengths does not match"
assert encoder_padding_mask.size(1) == batch_size, "batch_size does not match"
return max_lengths - torch.sum(encoder_padding_mask, dim=0)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
def exclusive_cumprod(tensor, dim: int, eps: float = 1e-10):
"""
Implementing exclusive cumprod.
There is cumprod in pytorch, however there is no exclusive mode.
cumprod(x) = [x1, x1x2, x1x2x3, ..., prod_{i=1}^n x_i]
exclusive means cumprod(x) = [1, x1, x1x2, x1x2x3, ..., prod_{i=1}^{n-1} x_i]
"""
tensor_size = list(tensor.size())
tensor_size[dim] = 1
return_tensor = safe_cumprod(
torch.cat([torch.ones(tensor_size).type_as(tensor), tensor], dim=dim),
dim=dim,
eps=eps,
)
if dim == 0:
return return_tensor[:-1]
elif dim == 1:
return return_tensor[:, :-1]
elif dim == 2:
return return_tensor[:, :, :-1]
else:
raise RuntimeError("Cumprod on dimension 3 and more is not implemented")
def safe_cumprod(tensor, dim: int, eps: float = 1e-10):
"""
An implementation of cumprod to prevent precision issue.
cumprod(x)
= [x1, x1x2, x1x2x3, ....]
= [exp(log(x1)), exp(log(x1) + log(x2)), exp(log(x1) + log(x2) + log(x3)), ...]
= exp(cumsum(log(x)))
"""
if (tensor + eps < 0).any().item():
raise RuntimeError(
"Safe cumprod can only take non-negative tensors as input."
"Consider use torch.cumprod if you want to calculate negative values."
)
log_tensor = torch.log(tensor + eps)
cumsum_log_tensor = torch.cumsum(log_tensor, dim)
exp_cumsum_log_tensor = torch.exp(cumsum_log_tensor)
return exp_cumsum_log_tensor
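# Illustrative sketch (added; the helper name is ours): for non-negative inputs
# the exp(cumsum(log(x))) formulation agrees with torch.cumprod up to the eps
# smoothing.
def _example_safe_cumprod():
    x = torch.tensor([0.9, 0.8, 0.1, 0.7])
    assert torch.allclose(safe_cumprod(x, dim=0), torch.cumprod(x, dim=0), atol=1e-6)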
def lengths_to_mask(lengths, max_len: int, dim: int = 0, negative_mask: bool = False):
"""
Convert a tensor of lengths to mask
For example, lengths = [[2, 3, 4]], max_len = 5
mask =
[[1, 1, 1],
[1, 1, 1],
[0, 1, 1],
[0, 0, 1],
[0, 0, 0]]
"""
assert len(lengths.size()) <= 2
if len(lengths.size()) == 2:
if dim == 1:
lengths = lengths.t()
else:
lengths = lengths.unsqueeze(1)
# lengths : batch_size, 1
lengths = lengths.view(-1, 1)
batch_size = lengths.size(0)
# batch_size, max_len
mask = torch.arange(max_len).expand(batch_size, max_len).type_as(lengths) < lengths
if negative_mask:
mask = ~mask
if dim == 0:
# max_len, batch_size
mask = mask.t()
return mask
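# Illustrative sketch (added; the helper name is ours): reproduces the example
# from the docstring above for lengths [2, 3, 4] and max_len 5.
def _example_lengths_to_mask():
    mask = lengths_to_mask(torch.LongTensor([2, 3, 4]), max_len=5, dim=0)
    assert mask.long().tolist() == [
        [1, 1, 1],
        [1, 1, 1],
        [0, 1, 1],
        [0, 0, 1],
        [0, 0, 0],
    ]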
def moving_sum(x, start_idx: int, end_idx: int):
"""
From MONOTONIC CHUNKWISE ATTENTION
https://arxiv.org/pdf/1712.05382.pdf
Equation (18)
x = [x_1, x_2, ..., x_N]
MovingSum(x, start_idx, end_idx)_n = Sigma_{m=n−(start_idx−1)}^{n+end_idx-1} x_m
for n in {1, 2, 3, ..., N}
x : src_len, batch_size
start_idx : start idx
end_idx : end idx
Example
src_len = 5
batch_size = 3
x =
[[ 0, 5, 10],
[ 1, 6, 11],
[ 2, 7, 12],
[ 3, 8, 13],
[ 4, 9, 14]]
MovingSum(x, 3, 1) =
[[ 0, 5, 10],
[ 1, 11, 21],
[ 3, 18, 33],
[ 6, 21, 36],
[ 9, 24, 39]]
MovingSum(x, 1, 3) =
[[ 3, 18, 33],
[ 6, 21, 36],
[ 9, 24, 39],
[ 7, 17, 27],
[ 4, 9, 14]]
"""
assert start_idx > 0 and end_idx > 0
assert len(x.size()) == 2
src_len, batch_size = x.size()
# batch_size, 1, src_len
x = x.t().unsqueeze(1)
# batch_size, 1, src_len
moving_sum_weight = x.new_ones([1, 1, end_idx + start_idx - 1])
moving_sum = (
torch.nn.functional.conv1d(
x, moving_sum_weight, padding=start_idx + end_idx - 1
)
.squeeze(1)
.t()
)
moving_sum = moving_sum[end_idx:-start_idx]
assert src_len == moving_sum.size(0)
assert batch_size == moving_sum.size(1)
return moving_sum
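# Illustrative sketch (added; the helper name is ours): reproduces the first
# column of the MovingSum(x, 3, 1) example from the docstring above.
def _example_moving_sum():
    x = torch.arange(15, dtype=torch.float).view(3, 5).t()  # src_len=5, batch=3
    out = moving_sum(x, start_idx=3, end_idx=1)
    assert out[:, 0].long().tolist() == [0, 1, 3, 6, 9]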
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
class LatencyMetric(object):
@staticmethod
def length_from_padding_mask(padding_mask, batch_first: bool = False):
dim = 1 if batch_first else 0
return padding_mask.size(dim) - padding_mask.sum(dim=dim, keepdim=True)
def prepare_latency_metric(
self,
delays,
src_lens,
target_padding_mask=None,
batch_first: bool = False,
start_from_zero: bool = True,
):
assert len(delays.size()) == 2
assert len(src_lens.size()) == 2
if start_from_zero:
delays = delays + 1
if batch_first:
# convert to batch_last
delays = delays.t()
src_lens = src_lens.t()
tgt_len, bsz = delays.size()
_, bsz_1 = src_lens.size()
if target_padding_mask is not None:
target_padding_mask = target_padding_mask.t()
tgt_len_1, bsz_2 = target_padding_mask.size()
assert tgt_len == tgt_len_1
assert bsz == bsz_2
assert bsz == bsz_1
if target_padding_mask is None:
tgt_lens = tgt_len * delays.new_ones([1, bsz]).float()
else:
# 1, batch_size
tgt_lens = self.length_from_padding_mask(target_padding_mask, False).float()
delays = delays.masked_fill(target_padding_mask, 0)
return delays, src_lens, tgt_lens, target_padding_mask
def __call__(
self,
delays,
src_lens,
target_padding_mask=None,
batch_first: bool = False,
start_from_zero: bool = True,
):
delays, src_lens, tgt_lens, target_padding_mask = self.prepare_latency_metric(
delays, src_lens, target_padding_mask, batch_first, start_from_zero
)
return self.cal_metric(delays, src_lens, tgt_lens, target_padding_mask)
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
"""
Expected sizes:
delays: tgt_len, batch_size
src_lens: 1, batch_size
target_padding_mask: tgt_len, batch_size
"""
raise NotImplementedError
class AverageProportion(LatencyMetric):
"""
Function to calculate Average Proportion from
Can neural machine translation do simultaneous translation?
(https://arxiv.org/abs/1606.02012)
Delays are monotonic steps, ranging from 1 to src_len.
Given src x and tgt y, AP is calculated as:
AP = 1 / (|x||y|) * sum_i^|y| delays_i
"""
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
if target_padding_mask is not None:
AP = torch.sum(
delays.masked_fill(target_padding_mask, 0), dim=0, keepdim=True
)
else:
AP = torch.sum(delays, dim=0, keepdim=True)
AP = AP / (src_lens * tgt_lens)
return AP
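# Illustrative sketch (added; the helper name and numbers are ours): an ideal
# wait-3 policy on a length-5 source/target pair reads 3+4+5+5+5 = 22 source
# steps in total, so AP = 22 / (5 * 5) = 0.88.
def _example_average_proportion():
    delays = torch.tensor([[3.0], [4.0], [5.0], [5.0], [5.0]])  # tgt_len x bsz
    src_lens = torch.tensor([[5.0]])
    ap = AverageProportion()(delays, src_lens, start_from_zero=False)
    assert abs(ap.item() - 22.0 / 25.0) < 1e-6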
class AverageLagging(LatencyMetric):
"""
Function to calculate Average Lagging from
STACL: Simultaneous Translation with Implicit Anticipation
and Controllable Latency using Prefix-to-Prefix Framework
(https://arxiv.org/abs/1810.08398)
Delays are monotonic steps, ranging from 1 to src_len.
Given src x and tgt y, AL is calculated as:
AL = 1 / tau * sum_i^tau (delays_i - (i - 1) / gamma)
Where
gamma = |y| / |x|
tau = argmin_i(delays_i = |x|)
"""
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
# tau = argmin_i(delays_i = |x|)
tgt_len, bsz = delays.size()
lagging_padding_mask = delays >= src_lens
lagging_padding_mask = torch.nn.functional.pad(
lagging_padding_mask.t(), (1, 0)
).t()[:-1, :]
gamma = tgt_lens / src_lens
lagging = (
delays
- torch.arange(delays.size(0))
.unsqueeze(1)
.type_as(delays)
.expand_as(delays)
/ gamma
)
lagging.masked_fill_(lagging_padding_mask, 0)
tau = (1 - lagging_padding_mask.type_as(lagging)).sum(dim=0, keepdim=True)
AL = lagging.sum(dim=0, keepdim=True) / tau
return AL
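# Illustrative sketch (added; the helper name and numbers are ours): for the
# same ideal wait-3 policy as in the AP example above, AL recovers the
# lagging k = 3.
def _example_average_lagging():
    delays = torch.tensor([[3.0], [4.0], [5.0], [5.0], [5.0]])  # tgt_len x bsz
    src_lens = torch.tensor([[5.0]])
    al = AverageLagging()(delays, src_lens, start_from_zero=False)
    assert abs(al.item() - 3.0) < 1e-6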
class DifferentiableAverageLagging(LatencyMetric):
"""
Function to calculate Differentiable Average Lagging from
Monotonic Infinite Lookback Attention for Simultaneous Machine Translation
(https://arxiv.org/abs/1906.05218)
Delays are monotonic steps, ranging from 0 to src_len - 1.
(In the original paper they are from 1 to src_len)
Given src x and tgt y, DAL is calculated as:
DAL = 1 / |y| * sum_i^|y| (delays'_i - (i - 1) / gamma)
Where
delays'_i =
1. delays_i if i == 1
2. max(delays_i, delays'_{i-1} + 1 / gamma) otherwise
"""
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
tgt_len, bsz = delays.size()
gamma = tgt_lens / src_lens
new_delays = torch.zeros_like(delays)
for i in range(delays.size(0)):
if i == 0:
new_delays[i] = delays[i]
else:
new_delays[i] = torch.cat(
[
new_delays[i - 1].unsqueeze(0) + 1 / gamma,
delays[i].unsqueeze(0),
],
dim=0,
).max(dim=0)[0]
DAL = (
new_delays
- torch.arange(delays.size(0))
.unsqueeze(1)
.type_as(delays)
.expand_as(delays)
/ gamma
)
if target_padding_mask is not None:
DAL = DAL.masked_fill(target_padding_mask, 0)
DAL = DAL.sum(dim=0, keepdim=True) / tgt_lens
return DAL
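# Illustrative sketch (added; the helper name and numbers are ours): with the
# wait-3 delays used above, delays' becomes [3, 4, 5, 6, 7] and DAL is also 3.
def _example_differentiable_average_lagging():
    delays = torch.tensor([[3.0], [4.0], [5.0], [5.0], [5.0]])  # tgt_len x bsz
    src_lens = torch.tensor([[5.0]])
    dal = DifferentiableAverageLagging()(delays, src_lens, start_from_zero=False)
    assert abs(dal.item() - 3.0) < 1e-6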
class LatencyMetricVariance(LatencyMetric):
def prepare_latency_metric(
self,
delays,
src_lens,
target_padding_mask=None,
batch_first: bool = True,
start_from_zero: bool = True,
):
assert batch_first
assert len(delays.size()) == 3
assert len(src_lens.size()) == 2
if start_from_zero:
delays = delays + 1
# convert to batch_last
bsz, num_heads_x_layers, tgt_len = delays.size()
bsz_1, _ = src_lens.size()
assert bsz == bsz_1
if target_padding_mask is not None:
bsz_2, tgt_len_1 = target_padding_mask.size()
assert tgt_len == tgt_len_1
assert bsz == bsz_2
if target_padding_mask is None:
tgt_lens = tgt_len * delays.new_ones([bsz, tgt_len]).float()
else:
# batch_size, 1
tgt_lens = self.length_from_padding_mask(target_padding_mask, True).float()
delays = delays.masked_fill(target_padding_mask.unsqueeze(1), 0)
return delays, src_lens, tgt_lens, target_padding_mask
class VarianceDelay(LatencyMetricVariance):
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
"""
delays : bsz, num_heads_x_layers, tgt_len
src_lens : bsz, 1
target_lens : bsz, 1
target_padding_mask: bsz, tgt_len or None
"""
if delays.size(1) == 1:
return delays.new_zeros([1])
variance_delays = delays.var(dim=1)
if target_padding_mask is not None:
variance_delays.masked_fill_(target_padding_mask, 0)
return variance_delays.sum(dim=1, keepdim=True) / tgt_lens
class LatencyInference(object):
def __init__(self, start_from_zero=True):
self.metric_calculator = {
"differentiable_average_lagging": DifferentiableAverageLagging(),
"average_lagging": AverageLagging(),
"average_proportion": AverageProportion(),
}
self.start_from_zero = start_from_zero
def __call__(self, monotonic_step, src_lens):
"""
monotonic_step ranges from 0 to src_len, where src_len means eos
delays: bsz, tgt_len
src_lens: bsz, 1
"""
if not self.start_from_zero:
monotonic_step -= 1
src_lens = src_lens
delays = monotonic_step.view(
monotonic_step.size(0), -1, monotonic_step.size(-1)
).max(dim=1)[0]
delays = delays.masked_fill(delays >= src_lens, 0) + (src_lens - 1).expand_as(
delays
).masked_fill(delays < src_lens, 0)
return_dict = {}
for key, func in self.metric_calculator.items():
return_dict[key] = func(
delays.float(),
src_lens.float(),
target_padding_mask=None,
batch_first=True,
start_from_zero=True,
).t()
return return_dict
class LatencyTraining(object):
def __init__(
self,
avg_weight,
var_weight,
avg_type,
var_type,
stay_on_last_token,
average_method,
):
self.avg_weight = avg_weight
self.var_weight = var_weight
self.avg_type = avg_type
self.var_type = var_type
self.stay_on_last_token = stay_on_last_token
self.average_method = average_method
self.metric_calculator = {
"differentiable_average_lagging": DifferentiableAverageLagging(),
"average_lagging": AverageLagging(),
"average_proportion": AverageProportion(),
}
self.variance_calculator = {
"variance_delay": VarianceDelay(),
}
def expected_delays_from_attention(
self, attention, source_padding_mask=None, target_padding_mask=None
):
if type(attention) == list:
# bsz, num_heads, tgt_len, src_len
bsz, num_heads, tgt_len, src_len = attention[0].size()
attention = torch.cat(attention, dim=1)
bsz, num_heads_x_layers, tgt_len, src_len = attention.size()
# bsz * num_heads * num_layers, tgt_len, src_len
attention = attention.view(-1, tgt_len, src_len)
else:
# bsz * num_heads * num_layers, tgt_len, src_len
bsz, tgt_len, src_len = attention.size()
num_heads_x_layers = 1
attention = attention.view(-1, tgt_len, src_len)
if not self.stay_on_last_token:
residual_attention = 1 - attention[:, :, :-1].sum(dim=2, keepdim=True)
attention = torch.cat([attention[:, :, :-1], residual_attention], dim=2)
# bsz * num_heads_x_num_layers, tgt_len, src_len for MMA
steps = (
torch.arange(1, 1 + src_len)
.unsqueeze(0)
.unsqueeze(1)
.expand_as(attention)
.type_as(attention)
)
if source_padding_mask is not None:
src_offset = (
source_padding_mask.type_as(attention)
.sum(dim=1, keepdim=True)
.expand(bsz, num_heads_x_layers)
.contiguous()
.view(-1, 1)
)
src_lens = src_len - src_offset
if source_padding_mask[:, 0].any():
# Pad left
src_offset = src_offset.view(-1, 1, 1)
steps = steps - src_offset
steps = steps.masked_fill(steps <= 0, 0)
else:
src_lens = attention.new_ones([bsz, num_heads_x_layers]) * src_len
src_lens = src_lens.view(-1, 1)
# bsz * num_heads_num_layers, tgt_len, src_len
expected_delays = (
(steps * attention).sum(dim=2).view(bsz, num_heads_x_layers, tgt_len)
)
if target_padding_mask is not None:
expected_delays.masked_fill_(target_padding_mask.unsqueeze(1), 0)
return expected_delays, src_lens
def avg_loss(self, expected_delays, src_lens, target_padding_mask):
bsz, num_heads_x_layers, tgt_len = expected_delays.size()
target_padding_mask = (
target_padding_mask.unsqueeze(1)
.expand_as(expected_delays)
.contiguous()
.view(-1, tgt_len)
)
if self.average_method == "average":
# bsz * tgt_len
expected_delays = expected_delays.mean(dim=1)
elif self.average_method == "weighted_average":
weights = torch.nn.functional.softmax(expected_delays, dim=1)
expected_delays = torch.sum(expected_delays * weights, dim=1)
elif self.average_method == "max":
# bsz * num_heads_x_num_layers, tgt_len
expected_delays = expected_delays.max(dim=1)[0]
else:
raise RuntimeError(f"{self.average_method} is not supported")
src_lens = src_lens.view(bsz, -1)[:, :1]
target_padding_mask = target_padding_mask.view(bsz, -1, tgt_len)[:, 0]
if self.avg_weight > 0.0:
if self.avg_type in self.metric_calculator:
average_delays = self.metric_calculator[self.avg_type](
expected_delays,
src_lens,
target_padding_mask,
batch_first=True,
start_from_zero=False,
)
else:
raise RuntimeError(f"{self.avg_type} is not supported.")
# bsz * num_heads_x_num_layers, 1
return self.avg_weight * average_delays.sum()
else:
return 0.0
def var_loss(self, expected_delays, src_lens, target_padding_mask):
src_lens = src_lens.view(expected_delays.size(0), expected_delays.size(1))[
:, :1
]
if self.var_weight > 0.0:
if self.var_type in self.variance_calculator:
variance_delays = self.variance_calculator[self.var_type](
expected_delays,
src_lens,
target_padding_mask,
batch_first=True,
start_from_zero=False,
)
else:
raise RuntimeError(f"{self.var_type} is not supported.")
return self.var_weight * variance_delays.sum()
else:
return 0.0
def loss(self, attention, source_padding_mask=None, target_padding_mask=None):
expected_delays, src_lens = self.expected_delays_from_attention(
attention, source_padding_mask, target_padding_mask
)
latency_loss = 0
latency_loss += self.avg_loss(expected_delays, src_lens, target_padding_mask)
latency_loss += self.var_loss(expected_delays, src_lens, target_padding_mask)
return latency_loss
from typing import Optional, Dict
from torch import Tensor
import torch
def waitk(
query, key, waitk_lagging: int, num_heads: int, key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None
):
if incremental_state is not None:
# Retrieve target length from incremental states
# For inference the length of query is always 1
tgt_len = incremental_state["steps"]["tgt"]
assert tgt_len is not None
tgt_len = int(tgt_len)
else:
tgt_len, bsz, _ = query.size()
max_src_len, bsz, _ = key.size()
if max_src_len < waitk_lagging:
if incremental_state is not None:
tgt_len = 1
return query.new_zeros(
bsz * num_heads, tgt_len, max_src_len
)
# Assuming the p_choose looks like this for wait k=3
# src_len = 6, tgt_len = 5
# [0, 0, 1, 0, 0, 0]
# [0, 0, 0, 1, 0, 0]
# [0, 0, 0, 0, 1, 0]
# [0, 0, 0, 0, 0, 1]
# [0, 0, 0, 0, 0, 1]
# linearize the p_choose matrix:
# [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0...]
# The indices of linearized matrix that equals 1 is
# 2 + 6 * 0
# 3 + 6 * 1
# ...
# n + src_len * n + k - 1 = n * (src_len + 1) + k - 1
# n from 0 to tgt_len - 1
#
# First, generate the indices (activate_indices_offset: bsz, tgt_len)
# Second, scatter a zeros tensor (bsz, tgt_len * src_len)
# with activate_indices_offset
# Third, resize the tensor to (bsz, tgt_len, src_len)
activate_indices_offset = (
(
torch.arange(tgt_len) * (max_src_len + 1)
+ waitk_lagging - 1
)
.unsqueeze(0)
.expand(bsz, tgt_len)
.to(query)
.long()
)
if key_padding_mask is not None:
if key_padding_mask[:, 0].any():
# Left padding
activate_indices_offset += (
key_padding_mask.sum(dim=1, keepdim=True)
)
# Need to clamp the indices that are too large
activate_indices_offset = (
activate_indices_offset
.clamp(
0,
min(
[
tgt_len,
max_src_len - waitk_lagging + 1
]
) * max_src_len - 1
)
)
p_choose = torch.zeros(bsz, tgt_len * max_src_len).to(query)
p_choose = p_choose.scatter(
1,
activate_indices_offset,
1.0
).view(bsz, tgt_len, max_src_len)
if incremental_state is not None:
p_choose = p_choose[:, -1:]
tgt_len = 1
# Extend to each head
p_choose = (
p_choose.contiguous()
.unsqueeze(1)
.expand(-1, num_heads, -1, -1)
.contiguous()
.view(-1, tgt_len, max_src_len)
)
return p_choose
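# Illustrative sketch (added; the helper name and shapes are ours): for k=3,
# tgt_len=3 and src_len=6, each target step attends to exactly one new source
# step, starting from the k-th one.
def _example_waitk():
    query, key = torch.rand(3, 1, 8), torch.rand(6, 1, 8)  # (len, bsz, dim)
    p_choose = waitk(query, key, waitk_lagging=3, num_heads=1)
    assert p_choose.shape == (1, 3, 6)
    assert p_choose[0].long().tolist() == [
        [0, 0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 1, 0],
    ]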
def hard_aligned(q_proj: Optional[Tensor], k_proj: Optional[Tensor], attn_energy, noise_mean: float = 0.0, noise_var: float = 0.0, training: bool = True):
"""
Calculate the step-wise probability of reading versus writing:
1 means read, 0 means write
"""
noise = 0
if training:
# add noise here to encourage discreteness
noise = (
torch.normal(noise_mean, noise_var, attn_energy.size())
.type_as(attn_energy)
.to(attn_energy.device)
)
p_choose = torch.sigmoid(attn_energy + noise)
_, _, tgt_len, src_len = p_choose.size()
# p_choose: bsz * self.num_heads, tgt_len, src_len
return p_choose.view(-1, tgt_len, src_len)
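# Illustrative sketch (added; the helper name is ours): with zero attention
# energies and no training noise, every read probability is sigmoid(0) = 0.5.
def _example_hard_aligned():
    attn_energy = torch.zeros(2, 4, 3, 5)  # bsz, num_heads, tgt_len, src_len
    p_choose = hard_aligned(None, None, attn_energy, training=False)
    assert p_choose.shape == (8, 3, 5)
    assert torch.allclose(p_choose, torch.full_like(p_choose, 0.5))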
### 2021 Update: We are merging this example into the [S2T framework](../speech_to_text), which supports more generic speech-to-text tasks (e.g. speech translation) and more flexible data processing pipelines. Please stay tuned.
# Speech Recognition
`examples/speech_recognition` implements the ASR task in fairseq, along with the features, datasets, models and loss functions needed to train and run inference for the model described in [Transformers with convolutional context for ASR (Abdelrahman Mohamed et al., 2019)](https://arxiv.org/abs/1904.11660).
## Additional dependencies
On top of the main fairseq dependencies there are a couple of additional requirements.
1) Please follow the instructions to install [torchaudio](https://github.com/pytorch/audio). This is required to compute audio fbank features.
2) [Sclite](http://www1.icsi.berkeley.edu/Speech/docs/sctk-1.2/sclite.htm#sclite_name_0) is used to measure WER. Sclite can be downloaded and installed from source as part of the sctk package [here](http://www.openslr.org/4/). Training and inference do not require the Sclite dependency.
3) [sentencepiece](https://github.com/google/sentencepiece) is required in order to create datasets with word-piece targets.
## Preparing librispeech data
```
./examples/speech_recognition/datasets/prepare-librispeech.sh $DIR_TO_SAVE_RAW_DATA $DIR_FOR_PREPROCESSED_DATA
```
## Training librispeech data
```
python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 80 --task speech_recognition --arch vggtransformer_2 --optimizer adadelta --lr 1.0 --adadelta-eps 1e-8 --adadelta-rho 0.95 --clip-norm 10.0 --max-tokens 5000 --log-format json --log-interval 1 --criterion cross_entropy_acc --user-dir examples/speech_recognition/
```
## Inference for librispeech
`$SET` can be `test_clean` or `test_other`
Any checkpoint in `$MODEL_PATH` can be selected. In this example we are working with `checkpoint_last.pt`
```
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --max-tokens 25000 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --beam 20 --results-path $RES_DIR --batch-size 40 --gen-subset $SET --user-dir examples/speech_recognition/
```
## Scoring with Sclite
```
sclite -r ${RES_DIR}/ref.word-checkpoint_last.pt-${SET}.txt -h ${RES_DIR}/hypo.word-checkpoint_last.pt-${SET}.txt -i rm -o all stdout > $RES_REPORT
```
The `Sum/Avg` row in the first table of the report contains the WER.
## Using flashlight (previously called [wav2letter](https://github.com/facebookresearch/wav2letter)) components
[flashlight](https://github.com/facebookresearch/flashlight) now has integration with fairseq. Currently this includes:
* AutoSegmentationCriterion (ASG)
* flashlight-style Conv/GLU model
* flashlight's beam search decoder
To use these, follow the instructions on [this page](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) to install python bindings.
## Training librispeech data (flashlight style, Conv/GLU + ASG loss)
Training command:
```
python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 100 --task speech_recognition --arch w2l_conv_glu_enc --batch-size 4 --optimizer sgd --lr 0.3,0.8 --momentum 0.8 --clip-norm 0.2 --max-tokens 50000 --log-format json --log-interval 100 --num-workers 0 --sentence-avg --criterion asg_loss --asg-transitions-init 5 --max-replabel 2 --linseg-updates 8789 --user-dir examples/speech_recognition
```
Note that ASG loss currently doesn't do well with word-pieces. You should prepare a dataset with character targets by setting `nbpe=31` in `prepare-librispeech.sh`.
## Inference for librispeech (flashlight decoder, n-gram LM)
Inference command:
```
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder kenlm --kenlm-model $KENLM_MODEL_PATH --lexicon $LEXICON_PATH --beam 200 --beam-threshold 15 --lm-weight 1.5 --word-score 1.5 --sil-weight -0.3 --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
```
`$KENLM_MODEL_PATH` should be a standard n-gram language model file. `$LEXICON_PATH` should be a flashlight-style lexicon (list of known words and their spellings). For ASG inference, a lexicon line should look like this (note the repetition labels):
```
doorbell D O 1 R B E L 1 ▁
```
For CTC inference with word-pieces, repetition labels are not used and the lexicon should have the most common spellings for each word (one can use sentencepiece's `NBestEncodeAsPieces` for this; a sketch follows the example below):
```
doorbell ▁DOOR BE LL
doorbell ▁DOOR B E LL
doorbell ▁DO OR BE LL
doorbell ▁DOOR B EL L
doorbell ▁DOOR BE L L
doorbell ▁DO OR B E LL
doorbell ▁DOOR B E L L
doorbell ▁DO OR B EL L
doorbell ▁DO O R BE LL
doorbell ▁DO OR BE L L
```
Lowercase vs. uppercase matters: the *word* should match the case of the n-gram language model (i.e. `$KENLM_MODEL_PATH`), while the *spelling* should match the case of the token dictionary (i.e. `$DIR_FOR_PREPROCESSED_DATA/dict.txt`).
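For reference, here is a minimal sketch of how such a word-piece lexicon could be generated with sentencepiece's Python API (the script name, argument order, and tab-separated output below are illustrative assumptions, not part of this example):
```
# generate_wp_lexicon.py -- illustrative sketch only.
# Reads one word per line on stdin and prints lexicon lines containing the
# n-best sentencepiece segmentations of each word.
import sys

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load(sys.argv[1])      # path to the trained sentencepiece model
nbest = int(sys.argv[2])  # number of spellings to keep per word, e.g. 10

for line in sys.stdin:
    word = line.strip()
    if not word:
        continue
    for pieces in sp.NBestEncodeAsPieces(word, nbest):
        print("{}\t{}".format(word, " ".join(pieces)))
```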
## Inference for librispeech (flashlight decoder, viterbi only)
Inference command:
```
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder viterbi --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
```
from . import criterions, models, tasks # noqa
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from examples.speech_recognition.data.replabels import pack_replabels
from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion
@register_criterion("asg_loss")
class ASGCriterion(FairseqCriterion):
@staticmethod
def add_args(parser):
group = parser.add_argument_group("ASG Loss")
group.add_argument(
"--asg-transitions-init",
help="initial diagonal value of transition matrix",
type=float,
default=0.0,
)
group.add_argument(
"--max-replabel", help="maximum # of replabels", type=int, default=2
)
group.add_argument(
"--linseg-updates",
help="# of training updates to use LinSeg initialization",
type=int,
default=0,
)
group.add_argument(
"--hide-linseg-messages",
help="hide messages about LinSeg initialization",
action="store_true",
)
def __init__(
self,
task,
silence_token,
asg_transitions_init,
max_replabel,
linseg_updates,
hide_linseg_messages,
):
from flashlight.lib.sequence.criterion import ASGLoss, CriterionScaleMode
super().__init__(task)
self.tgt_dict = task.target_dictionary
self.eos = self.tgt_dict.eos()
self.silence = (
self.tgt_dict.index(silence_token)
if silence_token in self.tgt_dict
else None
)
self.max_replabel = max_replabel
num_labels = len(self.tgt_dict)
self.asg = ASGLoss(num_labels, scale_mode=CriterionScaleMode.TARGET_SZ_SQRT)
self.asg.trans = torch.nn.Parameter(
asg_transitions_init * torch.eye(num_labels), requires_grad=True
)
self.linseg_progress = torch.nn.Parameter(
torch.tensor([0], dtype=torch.int), requires_grad=False
)
self.linseg_maximum = linseg_updates
self.linseg_message_state = "none" if hide_linseg_messages else "start"
@classmethod
def build_criterion(cls, args, task):
return cls(
task,
args.silence_token,
args.asg_transitions_init,
args.max_replabel,
args.linseg_updates,
args.hide_linseg_messages,
)
def linseg_step(self):
if not self.training:
return False
if self.linseg_progress.item() < self.linseg_maximum:
if self.linseg_message_state == "start":
print("| using LinSeg to initialize ASG")
self.linseg_message_state = "finish"
self.linseg_progress.add_(1)
return True
elif self.linseg_message_state == "finish":
print("| finished LinSeg initialization")
self.linseg_message_state = "none"
return False
def replace_eos_with_silence(self, tgt):
if tgt[-1] != self.eos:
return tgt
elif self.silence is None or (len(tgt) > 1 and tgt[-2] == self.silence):
return tgt[:-1]
else:
return tgt[:-1] + [self.silence]
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
emissions = net_output["encoder_out"].transpose(0, 1).contiguous()
B = emissions.size(0)
T = emissions.size(1)
device = emissions.device
target = torch.IntTensor(B, T)
target_size = torch.IntTensor(B)
using_linseg = self.linseg_step()
for b in range(B):
initial_target_size = sample["target_lengths"][b].item()
if initial_target_size == 0:
raise ValueError("target size cannot be zero")
tgt = sample["target"][b, :initial_target_size].tolist()
tgt = self.replace_eos_with_silence(tgt)
tgt = pack_replabels(tgt, self.tgt_dict, self.max_replabel)
tgt = tgt[:T]
if using_linseg:
tgt = [tgt[t * len(tgt) // T] for t in range(T)]
target[b][: len(tgt)] = torch.IntTensor(tgt)
target_size[b] = len(tgt)
loss = self.asg.forward(emissions, target.to(device), target_size.to(device))
if reduce:
loss = torch.sum(loss)
sample_size = (
sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": utils.item(loss.data) if reduce else loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
agg_output = {
"loss": loss_sum / nsentences,
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
return agg_output
import importlib
import os
# ASG loss requires flashlight bindings
files_to_skip = set()
try:
import flashlight.lib.sequence.criterion
except ImportError:
files_to_skip.add("ASG_loss.py")
for file in sorted(os.listdir(os.path.dirname(__file__))):
if file.endswith(".py") and not file.startswith("_") and file not in files_to_skip:
criterion_name = file[: file.find(".py")]
importlib.import_module(
"examples.speech_recognition.criterions." + criterion_name
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import math
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion
@register_criterion("cross_entropy_acc")
class CrossEntropyWithAccCriterion(FairseqCriterion):
def __init__(self, task, sentence_avg):
super().__init__(task)
self.sentence_avg = sentence_avg
def compute_loss(self, model, net_output, target, reduction, log_probs):
# N, T -> N * T
target = target.view(-1)
lprobs = model.get_normalized_probs(net_output, log_probs=log_probs)
if not hasattr(lprobs, "batch_first"):
logging.warning(
"ERROR: we need to know whether "
"batch first for the net output; "
"you need to set batch_first attribute for the return value of "
"model.get_normalized_probs. Now, we assume this is true, but "
"in the future, we will raise exception instead. "
)
batch_first = getattr(lprobs, "batch_first", True)
if not batch_first:
lprobs = lprobs.transpose(0, 1)
# N, T, D -> N * T, D
lprobs = lprobs.view(-1, lprobs.size(-1))
loss = F.nll_loss(
lprobs, target, ignore_index=self.padding_idx, reduction=reduction
)
return lprobs, loss
def get_logging_output(self, sample, target, lprobs, loss):
target = target.view(-1)
mask = target != self.padding_idx
correct = torch.sum(
lprobs.argmax(1).masked_select(mask) == target.masked_select(mask)
)
total = torch.sum(mask)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": utils.item(loss.data), # * sample['ntokens'],
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
"correct": utils.item(correct.data),
"total": utils.item(total.data),
"nframes": torch.sum(sample["net_input"]["src_lengths"]).item(),
}
return sample_size, logging_output
def forward(self, model, sample, reduction="sum", log_probs=True):
"""Computes the cross entropy with accuracy metric for the given sample.
This is similar to CrossEntropyCriterion in fairseq, but also
computes accuracy metrics as part of logging
Args:
logprobs (Torch.tensor) of shape N, T, D i.e.
batchsize, timesteps, dimensions
targets (Torch.tensor) of shape N, T i.e batchsize, timesteps
Returns:
tuple: With three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
TODO:
* Currently this Criterion will only work with LSTMEncoderModels or
FairseqModels which have a decoder, or models which return a torch Tensor
as net_output.
We need to make a change to support all FairseqEncoder models.
"""
net_output = model(**sample["net_input"])
target = model.get_targets(sample, net_output)
lprobs, loss = self.compute_loss(
model, net_output, target, reduction, log_probs
)
sample_size, logging_output = self.get_logging_output(
sample, target, lprobs, loss
)
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
correct_sum = sum(log.get("correct", 0) for log in logging_outputs)
total_sum = sum(log.get("total", 0) for log in logging_outputs)
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
nframes = sum(log.get("nframes", 0) for log in logging_outputs)
agg_output = {
"loss": loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
# if args.sentence_avg, then sample_size is nsentences, then loss
# is per-sentence loss; else sample_size is ntokens, the loss
# becomes per-output token loss
"ntokens": ntokens,
"nsentences": nsentences,
"nframes": nframes,
"sample_size": sample_size,
"acc": correct_sum * 100.0 / total_sum if total_sum > 0 else 0.0,
"correct": correct_sum,
"total": total_sum,
# total is the number of valid tokens
}
if sample_size != ntokens:
agg_output["nll_loss"] = loss_sum / ntokens / math.log(2)
# loss: per-sentence loss (when sample_size is nsentences)
# nll_loss: per output token loss
return agg_output
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .asr_dataset import AsrDataset
__all__ = [
"AsrDataset",
]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
from fairseq.data import FairseqDataset
from . import data_utils
from .collaters import Seq2SeqCollater
class AsrDataset(FairseqDataset):
"""
A dataset representing speech and corresponding transcription.
Args:
aud_paths (List[str]): A list of paths to audio files.
aud_durations_ms (List[int]): A list of int containing the durations of
audio files.
tgt (List[torch.LongTensor]): A list of LongTensors containing the indices
of target transcriptions.
tgt_dict (~fairseq.data.Dictionary): target vocabulary.
ids (List[str]): A list of utterance IDs.
speakers (List[str]): A list of speakers corresponding to utterances.
num_mel_bins (int): Number of triangular mel-frequency bins (default: 80)
frame_length (float): Frame length in milliseconds (default: 25.0)
frame_shift (float): Frame shift in milliseconds (default: 10.0)
"""
def __init__(
self,
aud_paths,
aud_durations_ms,
tgt,
tgt_dict,
ids,
speakers,
num_mel_bins=80,
frame_length=25.0,
frame_shift=10.0,
):
assert frame_length > 0
assert frame_shift > 0
assert all(x > frame_length for x in aud_durations_ms)
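# Number of fbank frames per utterance, e.g. a 1000 ms utterance with 25 ms
# frames and a 10 ms shift yields int(1 + (1000 - 25) / 10) = 98 frames.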
self.frame_sizes = [
int(1 + (d - frame_length) / frame_shift) for d in aud_durations_ms
]
assert len(aud_paths) > 0
assert len(aud_paths) == len(aud_durations_ms)
assert len(aud_paths) == len(tgt)
assert len(aud_paths) == len(ids)
assert len(aud_paths) == len(speakers)
self.aud_paths = aud_paths
self.tgt_dict = tgt_dict
self.tgt = tgt
self.ids = ids
self.speakers = speakers
self.num_mel_bins = num_mel_bins
self.frame_length = frame_length
self.frame_shift = frame_shift
self.s2s_collater = Seq2SeqCollater(
0,
1,
pad_index=self.tgt_dict.pad(),
eos_index=self.tgt_dict.eos(),
move_eos_to_beginning=True,
)
def __getitem__(self, index):
import torchaudio
import torchaudio.compliance.kaldi as kaldi
tgt_item = self.tgt[index] if self.tgt is not None else None
path = self.aud_paths[index]
if not os.path.exists(path):
raise FileNotFoundError("Audio file not found: {}".format(path))
sound, sample_rate = torchaudio.load_wav(path)
output = kaldi.fbank(
sound,
num_mel_bins=self.num_mel_bins,
frame_length=self.frame_length,
frame_shift=self.frame_shift,
)
output_cmvn = data_utils.apply_mv_norm(output)
return {"id": index, "data": [output_cmvn.detach(), tgt_item]}
def __len__(self):
return len(self.aud_paths)
def collater(self, samples):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[int]): sample indices to collate
Returns:
dict: a mini-batch suitable for forwarding with a Model
"""
return self.s2s_collater.collate(samples)
def num_tokens(self, index):
return self.frame_sizes[index]
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
return (
self.frame_sizes[index],
len(self.tgt[index]) if self.tgt is not None else 0,
)
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
return np.arange(len(self))
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
This module contains a collection of classes which implement
collate functionality for various tasks.
Collaters should know what data to expect for each sample
and they should pack / collate them into batches
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
import torch
from fairseq.data import data_utils as fairseq_data_utils
class Seq2SeqCollater(object):
"""
Implements a collate function mainly for seq2seq tasks
This expects each sample to contain features (src_tokens) and
targets.
This collater is also used for the aligned training task.
"""
def __init__(
self,
feature_index=0,
label_index=1,
pad_index=1,
eos_index=2,
move_eos_to_beginning=True,
):
self.feature_index = feature_index
self.label_index = label_index
self.pad_index = pad_index
self.eos_index = eos_index
self.move_eos_to_beginning = move_eos_to_beginning
def _collate_frames(self, frames):
"""Convert a list of 2d frames into a padded 3d tensor
Args:
frames (list): list of 2d frames of size L[i] x f_dim, where L[i] is the
length of the i-th frame and f_dim is the fixed feature dimension
Returns:
a 3d tensor of size len(frames) x len_max x f_dim, where len_max is the max of L[i]
"""
len_max = max(frame.size(0) for frame in frames)
f_dim = frames[0].size(1)
res = frames[0].new(len(frames), len_max, f_dim).fill_(0.0)
for i, v in enumerate(frames):
res[i, : v.size(0)] = v
return res
def collate(self, samples):
"""
Utility function to collate samples into a batch for speech recognition.
"""
if len(samples) == 0:
return {}
# parse samples into torch tensors
parsed_samples = []
for s in samples:
# skip invalid samples
if s["data"][self.feature_index] is None:
continue
source = s["data"][self.feature_index]
if isinstance(source, (np.ndarray, np.generic)):
source = torch.from_numpy(source)
target = s["data"][self.label_index]
if isinstance(target, (np.ndarray, np.generic)):
target = torch.from_numpy(target).long()
elif isinstance(target, list):
target = torch.LongTensor(target)
parsed_sample = {"id": s["id"], "source": source, "target": target}
parsed_samples.append(parsed_sample)
samples = parsed_samples
id = torch.LongTensor([s["id"] for s in samples])
frames = self._collate_frames([s["source"] for s in samples])
# sort samples by descending number of frames
frames_lengths = torch.LongTensor([s["source"].size(0) for s in samples])
frames_lengths, sort_order = frames_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
frames = frames.index_select(0, sort_order)
target = None
target_lengths = None
prev_output_tokens = None
if samples[0].get("target", None) is not None:
ntokens = sum(len(s["target"]) for s in samples)
target = fairseq_data_utils.collate_tokens(
[s["target"] for s in samples],
self.pad_index,
self.eos_index,
left_pad=False,
move_eos_to_beginning=False,
)
target = target.index_select(0, sort_order)
target_lengths = torch.LongTensor(
[s["target"].size(0) for s in samples]
).index_select(0, sort_order)
prev_output_tokens = fairseq_data_utils.collate_tokens(
[s["target"] for s in samples],
self.pad_index,
self.eos_index,
left_pad=False,
move_eos_to_beginning=self.move_eos_to_beginning,
)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
else:
ntokens = sum(len(s["source"]) for s in samples)
batch = {
"id": id,
"ntokens": ntokens,
"net_input": {"src_tokens": frames, "src_lengths": frames_lengths},
"target": target,
"target_lengths": target_lengths,
"nsentences": len(samples),
}
if prev_output_tokens is not None:
batch["net_input"]["prev_output_tokens"] = prev_output_tokens
return batch
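# Illustrative usage sketch (added; the sample values and helper name are
# ours): collate two synthetic utterances with 80-dim features and
# eos-terminated targets (pad=1, eos=2).
def _example_seq2seq_collater():
    collater = Seq2SeqCollater(feature_index=0, label_index=1, pad_index=1, eos_index=2)
    samples = [
        {"id": 0, "data": [torch.rand(7, 80), torch.LongTensor([4, 5, 2])]},
        {"id": 1, "data": [torch.rand(5, 80), torch.LongTensor([6, 2])]},
    ]
    batch = collater.collate(samples)
    assert batch["net_input"]["src_tokens"].shape == (2, 7, 80)
    assert batch["target"].tolist() == [[4, 5, 2], [6, 2, 1]]
    assert batch["net_input"]["prev_output_tokens"].tolist() == [[2, 4, 5], [2, 6, 1]]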
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
def calc_mean_invstddev(feature):
if len(feature.size()) != 2:
raise ValueError("We expect the input feature to be 2-D tensor")
mean = feature.mean(0)
var = feature.var(0)
# avoid division by ~zero
eps = 1e-8
if (var < eps).any():
return mean, 1.0 / (torch.sqrt(var) + eps)
return mean, 1.0 / torch.sqrt(var)
def apply_mv_norm(features):
# If there are fewer than 2 frames, the variance cannot be computed (it is NaN)
# and normalization is not possible, so return the input as it is
if features.size(0) < 2:
return features
mean, invstddev = calc_mean_invstddev(features)
res = (features - mean) * invstddev
return res
def lengths_to_encoder_padding_mask(lengths, batch_first=False):
"""
convert lengths (a 1-D Long/Int tensor) to 2-D binary tensor
Args:
lengths: a (B, )-shaped tensor
Return:
encoder_padding_mask: a (max_length, B) binary mask, where
[t, b] = 0 for t < lengths[b] and 1 otherwise
max_length: maximum length of the B sequences
TODO:
kernelize this function if benchmarking shows this function is slow
"""
max_lengths = torch.max(lengths).item()
bsz = lengths.size(0)
encoder_padding_mask = torch.arange(
max_lengths
).to( # a (T, ) tensor with [0, ..., T-1]
lengths.device
).view( # move to the right device
1, max_lengths
).expand( # reshape to (1, T)-shaped tensor
bsz, -1
) >= lengths.view( # expand to (B, T)-shaped tensor
bsz, 1
).expand(
-1, max_lengths
)
if not batch_first:
return encoder_padding_mask.t(), max_lengths
else:
return encoder_padding_mask, max_lengths
def encoder_padding_mask_to_lengths(
encoder_padding_mask, max_lengths, batch_size, device
):
"""
convert encoder_padding_mask (2-D binary tensor) to a 1-D tensor
Conventionally, encoder output contains an encoder_padding_mask, which is
a 2-D mask of shape (T, B), whose (t, b) element indicates whether
encoder_out[t, b] is a valid output (=0) or not (=1). Occasionally, we
need to convert this mask tensor to a 1-D tensor of shape (B, ), where
[b] denotes the valid length of the b-th sequence
Args:
encoder_padding_mask: a (T, B)-shaped binary tensor or None; if None,
all positions are treated as valid
max_lengths: maximum length over all sequences; if encoder_padding_mask
is not None, max_lengths must equal encoder_padding_mask.size(0)
batch_size: batch size; if encoder_padding_mask is not None,
batch_size must equal encoder_padding_mask.size(1)
device: which device to put the result on
Return:
seq_lengths: a (B,)-shaped tensor, where its (b, )-th element is the
number of valid elements of the b-th sequence
"""
if encoder_padding_mask is None:
return torch.Tensor([max_lengths] * batch_size).to(torch.int32).to(device)
assert encoder_padding_mask.size(0) == max_lengths, "max_lengths does not match"
assert encoder_padding_mask.size(1) == batch_size, "batch_size does not match"
return max_lengths - torch.sum(encoder_padding_mask, dim=0)