Commit 12c90639 authored by “change”'s avatar “change”
Browse files

init

parent 417b607b
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
from typing import Dict, List, Optional
import torch
from torch import Tensor
from fairseq.modules.transformer_layer import TransformerDecoderLayerBase as FairseqTransformerDecoderLayerBase
from fairseq.modules import LayerNorm
from speech2c.models.modules.multihead_attention import MultiheadAttention
class TransformerDecoderLayerBase(FairseqTransformerDecoderLayerBase):
"""Decoder layer block.
In the original paper each operation (multi-head attention, encoder
attention or FFN) is postprocessed with: `dropout -> add residual ->
layernorm`. In the tensor2tensor code they suggest that learning is more
robust when preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*cfg.decoder.normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(
self, cfg, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, has_relative_attention_bias=False
):
super().__init__(
cfg,
no_encoder_attn,
add_bias_kv,
add_zero_attn,
)
if has_relative_attention_bias:
self.norm_k = LayerNorm(self.embed_dim // cfg.decoder.attention_heads)
def build_self_attention(
self, embed_dim, cfg, add_bias_kv=False, add_zero_attn=False
):
return MultiheadAttention(
embed_dim,
cfg.decoder.attention_heads,
dropout=cfg.attention_dropout,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
self_attention=not cfg.cross_self_attention,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
)
def forward(
self,
x,
encoder_out: Optional[torch.Tensor] = None,
encoder_padding_mask: Optional[torch.Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
prev_self_attn_state: Optional[List[torch.Tensor]] = None,
prev_attn_state: Optional[List[torch.Tensor]] = None,
self_attn_mask: Optional[torch.Tensor] = None,
self_attn_padding_mask: Optional[torch.Tensor] = None,
need_attn: bool = False,
need_head_weights: bool = False,
pos_bias=None,
):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor, optional): binary
ByteTensor of shape `(batch, src_len)` where padding
elements are indicated by ``1``.
need_attn (bool, optional): return attention weights
need_head_weights (bool, optional): return attention weights
for each head (default: return average over heads).
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
if need_head_weights:
need_attn = True
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
if pos_bias is not None:
pos_bias = self.norm_k(pos_bias)
if prev_self_attn_state is not None:
prev_key, prev_value = prev_self_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_self_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
assert incremental_state is not None
self.self_attn._set_input_buffer(incremental_state, saved_state)
_self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
if self.cross_self_attention and not (
incremental_state is not None
and _self_attn_input_buffer is not None
and "prev_key" in _self_attn_input_buffer
):
if self_attn_mask is not None:
assert encoder_out is not None
self_attn_mask = torch.cat(
(x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
)
if self_attn_padding_mask is not None:
if encoder_padding_mask is None:
assert encoder_out is not None
encoder_padding_mask = self_attn_padding_mask.new_zeros(
encoder_out.size(1), encoder_out.size(0)
)
self_attn_padding_mask = torch.cat(
(encoder_padding_mask, self_attn_padding_mask), dim=1
)
assert encoder_out is not None
y = torch.cat((encoder_out, x), dim=0)
else:
y = x
x, attn = self.self_attn(
query=x,
key=y,
value=y,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
attn_mask=self_attn_mask,
position_bias=pos_bias,
)
if self.c_attn is not None:
tgt_len, bsz = x.size(0), x.size(1)
x = x.view(tgt_len, bsz, self.nh, self.head_dim)
x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)
x = x.reshape(tgt_len, bsz, self.embed_dim)
if self.attn_ln is not None:
x = self.attn_ln(x)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
if self.encoder_attn is not None and encoder_out is not None:
residual = x
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
if prev_attn_state is not None:
prev_key, prev_value = prev_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
assert incremental_state is not None
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
if self.ffn_layernorm is not None:
x = self.ffn_layernorm(x)
x = self.fc2(x)
x = self.dropout_module(x)
if self.w_resid is not None:
residual = torch.mul(self.w_resid, residual)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
if self.onnx_trace and incremental_state is not None:
saved_state = self.self_attn._get_input_buffer(incremental_state)
assert saved_state is not None
if self_attn_padding_mask is not None:
self_attn_state = [
saved_state["prev_key"],
saved_state["prev_value"],
saved_state["prev_key_padding_mask"],
]
else:
self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
return x, attn, self_attn_state
return x, attn, None
def make_generation_fast_(self, need_attn: bool = False, **kwargs):
self.need_attn = need_attn
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.dataclass import ChoiceEnum
from fairseq.modules import (
LayerNorm,
MultiheadAttention,
SamePad,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.transformer_sentence_encoder import init_bert_params
from fairseq.utils import index_put
from fairseq.distributed import fsdp_wrap
from fairseq.models.wav2vec.utils import pad_to_multiple
from fairseq.models.wav2vec.wav2vec2 import TransformerEncoder as W2vTransformerEncoder
from speech2c.models.modules.relative_pos_enc import RelativePositionalEncoding
from speech2c.models.modules.multihead_attention import MultiheadAttention
EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])
class TransformerEncoder(W2vTransformerEncoder):
def __init__(self, args):
super().__init__(args)
self.dropout = args.dropout
self.embedding_dim = args.encoder_embed_dim
self.required_seq_len_multiple = args.required_seq_len_multiple
self.use_rel_pos_enc = getattr(args, "use_rel_pos_enc", False)
self.pos_conv = nn.Conv1d(
self.embedding_dim,
self.embedding_dim,
kernel_size=args.conv_pos,
padding=args.conv_pos // 2,
groups=args.conv_pos_groups,
)
dropout = 0
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
nn.init.constant_(self.pos_conv.bias, 0)
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
layers = []
for _ in range(args.encoder_layers):
layer = TransformerSentenceEncoderLayer(
embedding_dim=self.embedding_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=self.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
activation_fn=args.activation_fn,
layer_norm_first=args.layer_norm_first,
has_relative_attention_bias=self.use_rel_pos_enc,
)
if args.checkpoint_activations:
layer = fsdp_wrap(layer)
layer = checkpoint_wrapper(layer)
layers.append(layer)
self.layers = nn.ModuleList(layers)
self.layer_norm_first = args.layer_norm_first
self.layer_norm = LayerNorm(self.embedding_dim)
self.layerdrop = args.encoder_layerdrop
if self.use_rel_pos_enc:
self.pos_emb = RelativePositionalEncoding(args.encoder_embed_dim // args.encoder_attention_heads, 160)
self.apply(init_bert_params)
def forward(self, x, padding_mask=None, layer=None):
x, layer_results = self.extract_features(x, padding_mask, layer)
if self.layer_norm_first and layer is None:
x = self.layer_norm(x)
return x, layer_results
def extract_features(self, x, padding_mask=None, tgt_layer=None):
if padding_mask is not None:
x = index_put(x, padding_mask, 0)
x_conv = self.pos_conv(x.transpose(1, 2))
x_conv = x_conv.transpose(1, 2)
x = x + x_conv
if not self.layer_norm_first:
x = self.layer_norm(x)
# pad to the sequence length dimension
x, pad_length = pad_to_multiple(
x, self.required_seq_len_multiple, dim=-2, value=0
)
if pad_length > 0 and padding_mask is None:
padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
padding_mask[:, -pad_length:] = True
else:
padding_mask, _ = pad_to_multiple(
padding_mask, self.required_seq_len_multiple, dim=-1, value=True
)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
if self.use_rel_pos_enc:
x_len = x.shape[0]
pos_seq = torch.arange(0, x_len).long().to(x.device)
pos_seq = pos_seq[:, None] - pos_seq[None, :]
pos_k, pos_v = self.pos_emb(pos_seq)
else:
pos_k = None
layer_results = []
r = None
for i, layer in enumerate(self.layers):
dropout_probability = np.random.random()
if not self.training or (dropout_probability > self.layerdrop):
x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_k)
if tgt_layer is not None:
# unpad if needed
if pad_length > 0:
layer_results.append(
(
x[:-pad_length],
z[:, :-pad_length, :-pad_length]
if z is not None
else z,
)
)
else:
layer_results.append((x, z))
if i == tgt_layer:
r = x
break
if r is not None:
x = r
# T x B x C -> B x T x C
x = x.transpose(0, 1)
# undo paddding
if pad_length > 0:
x = x[:, :-pad_length]
return x, layer_results
class TransformerSentenceEncoderLayer(nn.Module):
"""
Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
models.
"""
def __init__(
self,
embedding_dim: float = 768,
ffn_embedding_dim: float = 3072,
num_attention_heads: float = 8,
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
activation_fn: str = "relu",
layer_norm_first: bool = False,
has_relative_attention_bias: bool = False,
) -> None:
super().__init__()
# Initialize parameters
self.embedding_dim = embedding_dim
self.dropout = dropout
self.activation_dropout = activation_dropout
# Initialize blocks
self.activation_fn = utils.get_activation_fn(activation_fn)
self.self_attn = MultiheadAttention(
self.embedding_dim,
num_attention_heads,
dropout=attention_dropout,
self_attention=True,
)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(self.activation_dropout)
self.dropout3 = nn.Dropout(dropout)
self.layer_norm_first = layer_norm_first
# layer norm associated with the self attention layer
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
# layer norm associated with the position wise feed-forward NN
self.final_layer_norm = LayerNorm(self.embedding_dim)
if has_relative_attention_bias:
self.norm_k = LayerNorm(self.embedding_dim//num_attention_heads)
def forward(
self,
x: torch.Tensor,
self_attn_mask: torch.Tensor = None,
self_attn_padding_mask: torch.Tensor = None,
need_weights: bool = False,
att_args=None,
pos_bias=None,
):
"""
LayerNorm is applied either before or after the self-attention/ffn
modules similar to the original Transformer imlementation.
"""
residual = x
if self.layer_norm_first:
x = self.self_attn_layer_norm(x)
if pos_bias is not None:
pos_bias = self.norm_k(pos_bias)
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
attn_mask=self_attn_mask,
position_bias=pos_bias,
)
x = self.dropout1(x)
x = residual + x
residual = x
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual + x
else:
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
position_bias=pos_bias,
)
x = self.dropout1(x)
x = residual + x
x = self.self_attn_layer_norm(x)
residual = x
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual + x
x = self.final_layer_norm(x)
return x, attn
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import logging
import copy
import contextlib
from typing import Dict, List, Optional, Tuple
import torch
from dataclasses import dataclass, field
from fairseq.data.dictionary import Dictionary
from fairseq.models import register_model
from fairseq.models.hubert import HubertConfig, HubertModel
from fairseq.models.transformer import Embedding
from torch import Tensor
from speech2c.tasks.speech2c_pretraining import (
Speech2cPretrainingConfig,
Speech2cPretrainingTask,
)
from speech2c.models.modules.transformer_decoder import TransformerDecoderScriptable
from speech2c.models.modules.transformer_encoder import TransformerEncoder
logger = logging.getLogger(__name__)
@dataclass
class Speech2cConfig(HubertConfig):
use_rel_pos_enc: bool = field(
default=False,
metadata={"help": "whether to use relative positional encoding"},
)
# decoder
decoder_layers: int = field(
default=6, metadata={"help": "num decoder layers in the transformer"}
)
decoder_embed_dim: int = field(
default=768, metadata={"help": "decoder embedding dimension"}
)
decoder_ffn_embed_dim: int = field(
default=3072, metadata={"help": "decoder embedding dimension for FFN"}
)
decoder_attention_heads: int = field(
default=12, metadata={"help": "num decoder attention heads"}
)
decoder_normalize_before: bool = field(
default=False,
metadata={"help": "apply layernorm before each decoder block"},
)
decoder_layerdrop: float = field(
default=0.0,
metadata={"help": "probability of dropping a tarnsformer layer"},
)
share_decoder_input_output_embed: bool = field(
default=False,
metadata={"help": "share decoder input and output embeddings"},
)
decoder_output_dim: int = field(
default=768, metadata={"help": "decoder output dimension"}
)
max_target_positions: int = field(
default=3000, metadata={"help": "max target position"}
)
no_scale_embedding: bool = field(
default=False,
metadata={"help": "not scale embedding"},
)
adaptive_input: bool = field(
default=False,
metadata={"help": "adaptive input"},
)
quant_noise_pq: int = field(
default=0, metadata={"help": "quant noise pq"}
)
decoder_learned_pos: bool = field(
default=False,
metadata={"help": "decoder learnable positional embedding"},
)
no_token_positional_embeddings: bool = field(
default=False,
metadata={"help": "no token positional embeddings"},
)
decoder_dict_size: int = field(
default=-1,
metadata={"help": "decoder dictionary dimension, only used for fine-tuning"},
)
# FP16 optimization
required_seq_len_multiple: int = field(
default=1,
metadata={
"help": "pad the input to encoder such that the sequence length is divisible by multiple"
},
)
crop_seq_to_multiple: int = field(
default=1,
metadata={
"help": "crop convolutional feature extractor output such that the sequence length is divisible by multiple"
},
)
@register_model("speech2c", dataclass=Speech2cConfig)
class Speech2cModel(HubertModel):
def __init__(
self,
cfg: Speech2cConfig,
task_cfg: Speech2cPretrainingConfig,
dictionaries: List[Dictionary],
) -> None:
super().__init__(cfg, task_cfg, dictionaries)
logger.info(f"Speech2cModel Config: {cfg}")
self.encoder = TransformerEncoder(cfg)
self.add_decoder = task_cfg.add_decoder
if task_cfg.add_decoder:
def build_embedding(dictionary, embed_dim):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
return Embedding(num_embeddings, embed_dim, padding_idx)
# To make sure that the decoder dict size is the same as the fine-tuning tgt_dict size
cut_dictionary = copy.deepcopy(dictionaries[0])
if cfg.decoder_dict_size != -1:
cut_dictionary.symbols = cut_dictionary.symbols[:cfg.decoder_dict_size]
decoder_embed_tokens = build_embedding(
cut_dictionary, cfg.decoder_embed_dim
)
self.decoder = TransformerDecoderScriptable(cfg, cut_dictionary, decoder_embed_tokens)
@classmethod
def build_model(cls, cfg: Speech2cConfig, task: Speech2cPretrainingTask):
"""Build a new model instance."""
model = Speech2cModel(cfg, task.cfg, task.dictionaries)
return model
def get_normalized_probs(
self,
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
log_probs: bool,
sample: Optional[Dict[str, Tensor]] = None,
):
# net_output['encoder_out'] is a (B, T, D) tensor
lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
lprobs.batch_first = True
return lprobs
def forward(
self,
source: torch.Tensor,
target_list: Optional[List[torch.Tensor]] = None,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = True,
features_only: bool = False,
output_layer: Optional[int] = None,
prev_output_tokens: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
"""output layer is 1-based"""
features = self.forward_features(source)
if target_list is not None:
features, target_list = self.forward_targets(features, target_list)
features_pen = features.float().pow(2).mean()
features = features.transpose(1, 2)
features = self.layer_norm(features)
unmasked_features = features.clone()
if padding_mask is not None:
padding_mask = self.forward_padding_mask(features, padding_mask)
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
features = self.dropout_input(features)
unmasked_features = self.dropout_features(unmasked_features)
if mask:
x, mask_indices = self.apply_mask(features, padding_mask, target_list)
else:
x = features
mask_indices = None
# feature: (B, T, D), float
# target: (B, T), long
# x: (B, T, D), float
# padding_mask: (B, T), bool
# mask_indices: (B, T), bool
x, _ = self.encoder(
x,
padding_mask=padding_mask,
layer=None if output_layer is None else output_layer - 1,
)
if features_only:
return {"x": x, "padding_mask": padding_mask, "features": features}
def compute_pred(proj_x, target, label_embs):
# compute logits for the i-th label set
y = torch.index_select(label_embs, 0, target.long())
negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
if self.target_glu:
y = self.target_glu(y)
negs = self.target_glu(negs)
# proj_x: (S, D)
# y: (S, D)
# negs: (Neg, S, D)
return self.compute_nce(proj_x, y, negs)
label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
if not self.skip_masked:
masked_indices = torch.logical_and(~padding_mask, mask_indices)
proj_x_m = self.final_proj(x[masked_indices])
if self.untie_final_proj:
proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
else:
proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
logit_m_list = [
compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list))
]
else:
logit_m_list = [None for _ in target_list]
if not self.skip_nomask:
nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
proj_x_u = self.final_proj(x[nomask_indices])
if self.untie_final_proj:
proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1)
else:
proj_x_u_list = [proj_x_u for _ in range(len(target_list))]
logit_u_list = [
compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list))
]
else:
logit_u_list = [None for _ in target_list]
result = {
"logit_m_list": logit_m_list,
"logit_u_list": logit_u_list,
"padding_mask": padding_mask,
"features_pen": features_pen,
}
if self.add_decoder:
encoder_out = {
"encoder_out": [x.transpose(0, 1)], # T x B x C
"encoder_padding_mask": [padding_mask], # B x T
}
assert prev_output_tokens is not None
decoder_out = self.decoder(
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
)
result['decoder_out'] = decoder_out
return result
def forward_torchscript(self, net_input: Dict[str, Tensor]):
"""A TorchScript-compatible version of forward.
Encoders which use additional arguments may want to override
this method for TorchScript compatibility.
"""
res = self.forward(
net_input["source"],
padding_mask=net_input["padding_mask"],
mask=False,
features_only=True
)
encoder_out = {
"encoder_out": [res["x"].transpose(0, 1)], # T x B x C
"encoder_padding_mask": [res["padding_mask"]], # B x T
}
return encoder_out
def extract_features(
self,
source: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = False,
ret_conv: bool = False,
output_layer: Optional[int] = None,
prev_output_tokens: Optional[torch.Tensor] = None,
ft: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
with torch.no_grad() if not ft else contextlib.ExitStack():
res = self.forward(
source,
padding_mask=padding_mask,
mask=mask,
features_only=True,
output_layer=output_layer,
)
feature = res["features"] if ret_conv else res["x"]
if self.add_decoder:
encoder_out = {
"encoder_out": [feature.transpose(0, 1)], # T x B x C
"encoder_padding_mask": [res["padding_mask"]], # B x T
}
assert prev_output_tokens is not None
decoder_out = self.decoder(
prev_output_tokens=prev_output_tokens,
encoder_out=encoder_out,
)
else:
decoder_out = None
return feature, res["padding_mask"], decoder_out
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
from argparse import Namespace
from omegaconf import II
import torch.nn as nn
from dataclasses import dataclass, field
from fairseq import checkpoint_utils, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models import BaseFairseqModel, FairseqEncoder, register_model
from fairseq.models.hubert.hubert_asr import HubertAsrConfig, Linear
from fairseq.tasks import FairseqTask
@dataclass
class Speech2cAsrConfig(HubertAsrConfig):
# for decoder
decoder_layerdrop: float = field(
default=0.0,
metadata={"help": "probability of dropping a decoder layer in hubert"},
)
add_decoder: bool = II("task.add_decoder")
@dataclass
class Speech2cCtcConfig(Speech2cAsrConfig):
pass
@register_model("speech2c_ctc", dataclass=Speech2cCtcConfig)
class Speech2cCtc(BaseFairseqModel):
def __init__(self, cfg: Speech2cCtcConfig, w2v_encoder: BaseFairseqModel):
super().__init__()
self.cfg = cfg
self.w2v_encoder = w2v_encoder
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
return state_dict
@classmethod
def build_model(cls, cfg: Speech2cCtcConfig, task: FairseqTask):
"""Build a new model instance."""
w2v_encoder = Speech2cEncoder(cfg, task.target_dictionary)
return cls(cfg, w2v_encoder)
def get_normalized_probs(self, net_output, log_probs, sample=None):
"""Get normalized probabilities (or log probs) from a net's output."""
if "encoder_out" not in net_output:
return self.w2v_encoder.get_normalized_probs_decoder(net_output, log_probs, sample)
if "encoder_out_for_ctc" in net_output:
logits = net_output["encoder_out_for_ctc"]
else:
logits = net_output["encoder_out"]
if isinstance(logits, list):
logits = logits[0]
if log_probs:
return utils.log_softmax(logits.float(), dim=-1)
else:
return utils.softmax(logits.float(), dim=-1)
def get_logits(self, net_output):
logits = net_output["encoder_out"]
padding = net_output["encoder_padding_mask"]
if padding is not None and padding.any():
padding = padding.T
logits[padding][..., 0] = 0
logits[padding][..., 1:] = float("-inf")
return logits
def forward(self, **kwargs):
x = self.w2v_encoder(**kwargs)
return x
@property
def encoder(self):
return self.w2v_encoder
def reorder_encoder_out(self, encoder_out, new_order):
return self.encoder.reorder_encoder_out(encoder_out, new_order)
@property
def decoder(self):
return self.w2v_encoder.w2v_model.decoder
class Speech2cEncoder(FairseqEncoder):
def __init__(self, cfg: Speech2cAsrConfig, tgt_dict=None):
self.apply_mask = cfg.apply_mask
arg_overrides = {
"dropout": cfg.dropout,
"activation_dropout": cfg.activation_dropout,
"dropout_input": cfg.dropout_input,
"attention_dropout": cfg.attention_dropout,
"mask_length": cfg.mask_length,
"mask_prob": cfg.mask_prob,
"mask_selection": cfg.mask_selection,
"mask_other": cfg.mask_other,
"no_mask_overlap": cfg.no_mask_overlap,
"mask_channel_length": cfg.mask_channel_length,
"mask_channel_prob": cfg.mask_channel_prob,
"mask_channel_selection": cfg.mask_channel_selection,
"mask_channel_other": cfg.mask_channel_other,
"no_mask_channel_overlap": cfg.no_mask_channel_overlap,
"encoder_layerdrop": cfg.layerdrop,
"decoder_layerdrop": cfg.decoder_layerdrop,
"feature_grad_mult": cfg.feature_grad_mult,
"decoder_dict_size": len(tgt_dict) if cfg.add_decoder else -1,
}
if cfg.w2v_args is None:
state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
w2v_args = state.get("cfg", None)
if w2v_args is None:
w2v_args = convert_namespace_to_omegaconf(state["args"])
cfg.w2v_args = w2v_args
else:
state = None
w2v_args = cfg.w2v_args
if isinstance(w2v_args, Namespace):
cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)
assert cfg.normalize == w2v_args.task.normalize, (
"Fine-tuning works best when data normalization is the same. "
"Please check that --normalize is set or unset for "
"both pre-training and here"
)
w2v_args.task.data = cfg.data
w2v_args.task.add_decoder = cfg.add_decoder
task = tasks.setup_task(w2v_args.task)
if state is not None and "task_state" in state:
# This will load the stored "dictionaries" object
task.load_state_dict(state["task_state"])
model = task.build_model(w2v_args.model)
if state is not None and not cfg.no_pretrained_weights:
if "decoder.embed_tokens.weight" in state["model"]:
del state["model"]["decoder.embed_tokens.weight"]
if "decoder.output_projection.weight" in state["model"]:
del state["model"]["decoder.output_projection.weight"]
# set strict=False because we omit some modules
model.load_state_dict(state["model"], strict=False)
model.remove_pretraining_modules()
super().__init__(task.source_dictionary)
d = model.mask_emb.size(0)
self.w2v_model = model
self.final_dropout = nn.Dropout(cfg.final_dropout)
self.freeze_finetune_updates = cfg.freeze_finetune_updates
self.num_updates = 0
if tgt_dict is not None:
self.proj = Linear(d, len(tgt_dict))
elif getattr(cfg, "decoder_embed_dim", d) != d:
self.proj = Linear(d, cfg.decoder_embed_dim)
else:
self.proj = None
def set_num_updates(self, num_updates):
"""Set the number of parameters updates."""
super().set_num_updates(num_updates)
self.num_updates = num_updates
def forward(self, source, padding_mask, prev_output_tokens=None, tbc=True, **kwargs):
ft = self.freeze_finetune_updates <= self.num_updates
w2v_args = {
"source": source,
"padding_mask": padding_mask,
"mask": self.apply_mask and self.training,
"prev_output_tokens": prev_output_tokens,
"ft": ft,
}
x, padding_mask, decoder_out = self.w2v_model.extract_features(**w2v_args)
if tbc:
# B x T x C -> T x B x C
x = x.transpose(0, 1)
x = self.final_dropout(x)
if self.proj:
x = self.proj(x)
return {
"encoder_out": x, # T x B x C
"encoder_padding_mask": padding_mask, # B x T
"padding_mask": padding_mask,
"decoder_out": decoder_out,
}
def get_normalized_probs_decoder(self, net_output, log_probs, sample=None):
# net_output['encoder_out'] is a (B, T, D) tensor
return self.w2v_model.get_normalized_probs(net_output, log_probs, sample)
def reorder_encoder_out(self, encoder_out, new_order):
if encoder_out["encoder_out"] is not None:
if isinstance(encoder_out["encoder_out"], list):
encoder_out["encoder_out"] = (
[] if len(encoder_out["encoder_out"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
)
else:
encoder_out["encoder_out"] = encoder_out[
"encoder_out"
].index_select(1, new_order)
if encoder_out["encoder_padding_mask"] is not None:
if isinstance(encoder_out["encoder_padding_mask"], list):
encoder_out["encoder_padding_mask"] = (
[] if len(encoder_out["encoder_padding_mask"]) == 0
else [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]]
)
else:
encoder_out["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask"
].index_select(0, new_order)
if "decoder_out" in encoder_out and encoder_out["decoder_out"] is not None:
if isinstance(encoder_out["decoder_out"], list):
encoder_out["decoder_out"] = (
[] if len(encoder_out["decoder_out"]) == 0
else [x.index_select(0, new_order) for x in encoder_out["decoder_out"]]
)
else:
encoder_out["decoder_out"] = encoder_out[
"decoder_out"
].index_select(0, new_order)
if "encoder_out_for_ctc" in encoder_out and encoder_out["encoder_out_for_ctc"] is not None:
if isinstance(encoder_out["encoder_out_for_ctc"], list):
encoder_out["encoder_out_for_ctc"] = (
[] if len(encoder_out["encoder_out_for_ctc"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["encoder_out_for_ctc"]]
)
else:
encoder_out["encoder_out_for_ctc"] = encoder_out[
"encoder_out_for_ctc"
].index_select(1, new_order)
return encoder_out
def forward_torchscript(self, net_input):
"""A TorchScript-compatible version of forward.
Encoders which use additional arguments may want to override
this method for TorchScript compatibility.
"""
encoder_out = self.w2v_model.forward_torchscript(net_input)
assert self.proj is not None
encoder_out['encoder_out_for_ctc'] = [self.proj(encoder_out['encoder_out'][0])]
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
return None
def upgrade_state_dict_named(self, state_dict, name):
return state_dict
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
from fairseq.models import (
register_model_architecture,
)
from fairseq.models.transformer_lm import base_lm_architecture
@register_model_architecture(model_name="transformer_lm", arch_name="transformer_lm_t5")
def transformer_lm_t5(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6144)
args.decoder_layers = getattr(args, "decoder_layers", 20)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
args.activation_fn = getattr(args, "activation_fn", "gelu")
base_lm_architecture(args)
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import math
from typing import Dict, List, Optional
import sys
import torch
import torch.nn as nn
from fairseq import search, utils
from fairseq.data import data_utils
from fairseq.models import FairseqIncrementalDecoder
from torch import Tensor
from fairseq.ngram_repeat_block import NGramRepeatBlock
from speech2c.models.modules.ctc_prefix_score import CTCPrefixScore
import numpy
CTC_SCORING_RATIO = 7.0
class SequenceGenerator(nn.Module):
def __init__(
self,
models,
tgt_dict,
beam_size=1,
max_len_a=0,
max_len_b=200,
max_len=0,
min_len=1,
normalize_scores=True,
len_penalty=1.0,
unk_penalty=0.0,
temperature=1.0,
match_source_len=False,
no_repeat_ngram_size=0,
search_strategy=None,
eos=None,
symbols_to_strip_from_output=None,
lm_model=None,
lm_weight=1.0,
ctc_weight=0.0,
):
"""Generates translations of a given source sentence.
Args:
models (List[~fairseq.models.FairseqModel]): ensemble of models,
currently support fairseq.models.TransformerModel for scripting
beam_size (int, optional): beam width (default: 1)
max_len_a/b (int, optional): generate sequences of maximum length
ax + b, where x is the source length
max_len (int, optional): the maximum length of the generated output
(not including end-of-sentence)
min_len (int, optional): the minimum length of the generated output
(not including end-of-sentence)
normalize_scores (bool, optional): normalize scores by the length
of the output (default: True)
len_penalty (float, optional): length penalty, where <1.0 favors
shorter, >1.0 favors longer sentences (default: 1.0)
unk_penalty (float, optional): unknown word penalty, where <0
produces more unks, >0 produces fewer (default: 0.0)
temperature (float, optional): temperature, where values
>1.0 produce more uniform samples and values <1.0 produce
sharper samples (default: 1.0)
match_source_len (bool, optional): outputs should match the source
length (default: False)
"""
super().__init__()
if isinstance(models, EnsembleModel):
self.model = models
else:
self.model = EnsembleModel(models)
self.tgt_dict = tgt_dict
self.pad = tgt_dict.pad()
self.unk = tgt_dict.unk()
self.eos = tgt_dict.eos() if eos is None else eos
self.blank = self.tgt_dict.index("<s>")
self.symbols_to_strip_from_output = (
symbols_to_strip_from_output.union({self.eos})
if symbols_to_strip_from_output is not None
else {self.eos}
)
self.vocab_size = len(tgt_dict)
self.beam_size = beam_size
# the max beam size is the dictionary size - 1, since we never select pad
self.beam_size = min(beam_size, self.vocab_size - 1)
self.max_len_a = max_len_a
self.max_len_b = max_len_b
self.min_len = min_len
self.max_len = max_len or self.model.max_decoder_positions()
self.normalize_scores = normalize_scores
self.len_penalty = len_penalty
self.unk_penalty = unk_penalty
self.temperature = temperature
self.match_source_len = match_source_len
if no_repeat_ngram_size > 0:
self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)
else:
self.repeat_ngram_blocker = None
assert temperature > 0, "--temperature must be greater than 0"
self.search = (
search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy
)
# We only need to set src_lengths in LengthConstrainedBeamSearch.
# As a module attribute, setting it would break in multithread
# settings when the model is shared.
self.should_set_src_lengths = (
hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths
)
self.model.eval()
self.lm_model = lm_model
self.lm_weight = lm_weight
self.ctc_weight = ctc_weight
if self.lm_model is not None:
self.lm_model.eval()
def cuda(self):
self.model.cuda()
return self
@torch.no_grad()
def forward(
self,
sample: Dict[str, Dict[str, Tensor]],
prefix_tokens: Optional[Tensor] = None,
bos_token: Optional[int] = None,
):
"""Generate a batch of translations.
Args:
sample (dict): batch
prefix_tokens (torch.LongTensor, optional): force decoder to begin
with these tokens
bos_token (int, optional): beginning of sentence token
(default: self.eos)
"""
return self._generate(sample, prefix_tokens, bos_token=bos_token)
# TODO(myleott): unused, deprecate after pytorch-translate migration
def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None):
"""Iterate over a batched dataset and yield individual translations.
Args:
cuda (bool, optional): use GPU for generation
timer (StopwatchMeter, optional): time generations
"""
for sample in data_itr:
s = utils.move_to_cuda(sample) if cuda else sample
if "net_input" not in s:
continue
input = s["net_input"]
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in input.items() if k != "prev_output_tokens"
}
if timer is not None:
timer.start()
with torch.no_grad():
hypos = self.generate(encoder_input)
if timer is not None:
timer.stop(sum(len(h[0]["tokens"]) for h in hypos))
for i, id in enumerate(s["id"].data):
# remove padding
src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad)
ref = (
utils.strip_pad(s["target"].data[i, :], self.pad)
if s["target"] is not None
else None
)
yield id, src, ref, hypos[i]
@torch.no_grad()
def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs) -> List[List[Dict[str, Tensor]]]:
"""Generate translations. Match the api of other fairseq generators.
Args:
models (List[~fairseq.models.FairseqModel]): ensemble of models
sample (dict): batch
prefix_tokens (torch.LongTensor, optional): force decoder to begin
with these tokens
constraints (torch.LongTensor, optional): force decoder to include
the list of constraints
bos_token (int, optional): beginning of sentence token
(default: self.eos)
"""
return self._generate(sample, **kwargs)
def _generate(
self,
sample: Dict[str, Dict[str, Tensor]],
prefix_tokens: Optional[Tensor] = None,
constraints: Optional[Tensor] = None,
bos_token: Optional[int] = None,
):
incremental_states = torch.jit.annotate(
List[Dict[str, Dict[str, Optional[Tensor]]]],
[
torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
for i in range(self.model.models_size)
],
)
net_input = sample["net_input"]
if "src_tokens" in net_input:
src_tokens = net_input["src_tokens"]
# length of the source text being the character length except EndOfSentence and pad
src_lengths = (
(src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
)
elif "source" in net_input:
src_tokens = net_input["source"]
src_lengths = (
net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
if net_input["padding_mask"] is not None
else torch.tensor(src_tokens.size(-1)).to(src_tokens)
)
elif "features" in net_input:
src_tokens = net_input["features"]
src_lengths = (
net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
if net_input["padding_mask"] is not None
else torch.tensor(src_tokens.size(-1)).to(src_tokens)
)
else:
raise Exception("expected src_tokens or source in net input. input keys: " + str(net_input.keys()))
# bsz: total number of sentences in beam
# Note that src_tokens may have more than 2 dimensions (i.e. audio features)
bsz, src_len = src_tokens.size()[:2]
beam_size = self.beam_size
if constraints is not None and not self.search.supports_constraints:
raise NotImplementedError(
"Target-side constraints were provided, but search method doesn't support them"
)
# Initialize constraints, when active
self.search.init_constraints(constraints, beam_size)
max_len: int = -1
if self.match_source_len:
max_len = src_lengths.max().item()
else:
max_len = min(
int(self.max_len_a * src_len + self.max_len_b),
self.max_len - 1,
)
assert (
self.min_len <= max_len
), "min_len cannot be larger than max_len, please adjust these!"
# compute the encoder output for each beam
with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"):
encoder_outs = self.model.forward_encoder(net_input)
# Get CTC lprobs and prep ctc_scorer
if self.ctc_weight > 0:
ctc_lprobs = self.model.models[0].get_normalized_probs(
encoder_outs[0], log_probs=True
).contiguous().transpose(0, 1) # (B, T, C) from the encoder
hyp = {}
ctc_prefix_score = CTCPrefixScore(ctc_lprobs[0].detach().cpu().numpy(), self.blank, self.eos, numpy)
hyp["ctc_state_prev"] = ctc_prefix_score.initial_state()
hyp["ctc_score_prev"] = 0.0
ctc_beam = min(ctc_lprobs.shape[-1], int(beam_size * CTC_SCORING_RATIO))
ctc_hyps = {str(self.eos): hyp}
# placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
new_order = new_order.to(src_tokens.device).long()
encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
# ensure encoder_outs is a List.
assert encoder_outs is not None
# initialize buffers
scores = (
torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
) # +1 for eos; pad is never chosen for scoring
tokens = (
torch.zeros(bsz * beam_size, max_len + 2)
.to(src_tokens)
.long()
.fill_(self.pad)
) # +2 for eos and pad
tokens[:, 0] = self.eos if bos_token is None else bos_token
attn: Optional[Tensor] = None
# A list that indicates candidates that should be ignored.
# For example, suppose we're sampling and have already finalized 2/5
# samples. Then cands_to_ignore would mark 2 positions as being ignored,
# so that we only finalize the remaining 3 samples.
cands_to_ignore = (
torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
) # forward and backward-compatible False mask
# list of completed sentences
finalized = torch.jit.annotate(
List[List[Dict[str, Tensor]]],
[torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step
# a boolean array indicating if the sentence at the index is finished or not
finished = [False for i in range(bsz)]
num_remaining_sent = bsz # number of sentences remaining
# number of candidate hypos per step
cand_size = 2 * beam_size # 2 x beam size in case half are EOS
# offset arrays for converting between different indexing schemes
bbsz_offsets = (
(torch.arange(0, bsz) * beam_size)
.unsqueeze(1)
.type_as(tokens)
.to(src_tokens.device)
)
cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device)
reorder_state: Optional[Tensor] = None
batch_idxs: Optional[Tensor] = None
original_batch_idxs: Optional[Tensor] = None
if "id" in sample and isinstance(sample["id"], Tensor):
original_batch_idxs = sample["id"]
else:
original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
for step in range(max_len + 1): # one extra step for EOS marker
# reorder decoder internal states based on the prev choice of beams
if reorder_state is not None:
if batch_idxs is not None:
# update beam indices to take into account removed sentences
corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
batch_idxs
)
reorder_state.view(-1, beam_size).add_(
corr.unsqueeze(-1) * beam_size
)
original_batch_idxs = original_batch_idxs[batch_idxs]
self.model.reorder_incremental_state(incremental_states, reorder_state)
encoder_outs = self.model.reorder_encoder_out(
encoder_outs, reorder_state
)
with torch.autograd.profiler.record_function("EnsembleModel: forward_decoder"):
lprobs, avg_attn_scores = self.model.forward_decoder(
tokens[:, : step + 1],
encoder_outs,
incremental_states,
self.temperature,
)
if self.ctc_weight > 0 and step != 0:
# lprobs[:, self.blank] = -math.inf # never select blank
ctc_lprobs = lprobs.clone()
ctc_lprobs[:, self.blank] = -math.inf # never select blank
_, local_best_ids = torch.topk(ctc_lprobs, ctc_beam, dim=-1)
for b in range(tokens.size(0)):
hyp_key = " ".join(str(x) for x in tokens[b, : step + 1].tolist())
ctc_scores, ctc_states = ctc_prefix_score(
tokens[b, : step + 1].cpu(), local_best_ids[b].cpu(), ctc_hyps[hyp_key]["ctc_state_prev"]
)
lprobs[b] = lprobs[b]
lprobs[b, local_best_ids[b]] = (1 - self.ctc_weight) * (lprobs[b, local_best_ids[b]]) + self.ctc_weight * torch.from_numpy(
ctc_scores - ctc_hyps[hyp_key]["ctc_score_prev"]
).to(device="cuda")
for j in range(len(local_best_ids[b])):
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())] = {}
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_score_prev"] = ctc_scores[j]
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_state_prev"] = ctc_states[j]
elif self.ctc_weight > 0 and step == 0:
ctc_lprobs = lprobs.clone()
ctc_lprobs[:, self.blank] = -math.inf # never select blank
_, local_best_ids = torch.topk(ctc_lprobs, ctc_beam, dim=-1)
for b in range(tokens.size(0)):
hyp_key = " ".join(str(x) for x in tokens[b, : step + 1].tolist())
ctc_scores, ctc_states = ctc_prefix_score(
tokens[b, : step + 1].cpu(), local_best_ids[b].cpu(), ctc_hyps[hyp_key]["ctc_state_prev"]
)
lprobs[b] = lprobs[b]
lprobs[b, local_best_ids[b]] = (1 - self.ctc_weight) * (lprobs[b, local_best_ids[b]]) + self.ctc_weight * torch.from_numpy(
ctc_scores - ctc_hyps[hyp_key]["ctc_score_prev"]
).to(device="cuda")
for j in range(len(local_best_ids[b])):
if b == 0:
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())] = {}
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_score_prev"] = ctc_scores[j]
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_state_prev"] = ctc_states[j]
if self.lm_model is not None:
lm_out = self.lm_model(tokens[:, : step + 1])
probs = self.lm_model.get_normalized_probs(
lm_out, log_probs=True, sample=None
)
probs = probs[:, -1, :] * self.lm_weight
lprobs += probs
# handle prefix tokens (possibly with different lengths)
if (
prefix_tokens is not None
and step < prefix_tokens.size(1)
and step < max_len
):
lprobs, tokens, scores = self._prefix_tokens(
step, lprobs, scores, tokens, prefix_tokens, beam_size
)
elif step < self.min_len:
# minimum length constraint (does not apply if using prefix_tokens)
lprobs[:, self.eos] = -math.inf
lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
lprobs[:, self.pad] = -math.inf # never select pad
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
lprobs[:, self.blank] = -math.inf # never select blank
# handle max length constraint
if step >= max_len:
lprobs[:, : self.eos] = -math.inf
lprobs[:, self.eos + 1 :] = -math.inf
# Record attention scores, only support avg_attn_scores is a Tensor
if avg_attn_scores is not None:
if attn is None:
attn = torch.empty(
bsz * beam_size, avg_attn_scores.size(1), max_len + 2
).to(scores)
attn[:, :, step + 1].copy_(avg_attn_scores)
scores = scores.type_as(lprobs)
eos_bbsz_idx = torch.empty(0).to(
tokens
) # indices of hypothesis ending with eos (finished sentences)
eos_scores = torch.empty(0).to(
scores
) # scores of hypothesis ending with eos (finished sentences)
if self.should_set_src_lengths:
self.search.set_src_lengths(src_lengths)
if self.repeat_ngram_blocker is not None:
lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)
# Shape: (batch, cand_size)
cand_scores, cand_indices, cand_beams = self.search.step(
step,
lprobs.view(bsz, -1, self.vocab_size),
scores.view(bsz, beam_size, -1)[:, :, :step],
tokens[:, : step + 1],
original_batch_idxs,
)
# cand_bbsz_idx contains beam indices for the top candidate
# hypotheses, with a range of values: [0, bsz*beam_size),
# and dimensions: [bsz, cand_size]
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
# finalize hypotheses that end in eos
# Shape of eos_mask: (batch size, beam size)
eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)
# only consider eos when it's among the top beam_size indices
# Now we know what beam item(s) to finish
# Shape: 1d list of absolute-numbered
eos_bbsz_idx = torch.masked_select(
cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
)
finalized_sents: List[int] = []
if eos_bbsz_idx.numel() > 0:
eos_scores = torch.masked_select(
cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
)
finalized_sents = self.finalize_hypos(
step,
eos_bbsz_idx,
eos_scores,
tokens,
scores,
finalized,
finished,
beam_size,
attn,
src_lengths,
max_len,
)
num_remaining_sent -= len(finalized_sents)
assert num_remaining_sent >= 0
if num_remaining_sent == 0:
break
if self.search.stop_on_max_len and step >= max_len:
break
assert step < max_len, f"{step} < {max_len}"
# Remove finalized sentences (ones for which {beam_size}
# finished hypotheses have been generated) from the batch.
if len(finalized_sents) > 0:
new_bsz = bsz - len(finalized_sents)
# construct batch_idxs which holds indices of batches to keep for the next pass
batch_mask = torch.ones(
bsz, dtype=torch.bool, device=cand_indices.device
)
batch_mask[finalized_sents] = False
# TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
batch_idxs = torch.arange(
bsz, device=cand_indices.device
).masked_select(batch_mask)
# Choose the subset of the hypothesized constraints that will continue
self.search.prune_sentences(batch_idxs)
eos_mask = eos_mask[batch_idxs]
cand_beams = cand_beams[batch_idxs]
bbsz_offsets.resize_(new_bsz, 1)
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
cand_scores = cand_scores[batch_idxs]
cand_indices = cand_indices[batch_idxs]
if prefix_tokens is not None:
prefix_tokens = prefix_tokens[batch_idxs]
src_lengths = src_lengths[batch_idxs]
cands_to_ignore = cands_to_ignore[batch_idxs]
scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
if attn is not None:
attn = attn.view(bsz, -1)[batch_idxs].view(
new_bsz * beam_size, attn.size(1), -1
)
bsz = new_bsz
else:
batch_idxs = None
# Set active_mask so that values > cand_size indicate eos hypos
# and values < cand_size indicate candidate active hypos.
# After, the min values per row are the top candidate active hypos
# Rewrite the operator since the element wise or is not supported in torchscript.
eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
active_mask = torch.add(
eos_mask.type_as(cand_offsets) * cand_size,
cand_offsets[: eos_mask.size(1)],
)
# get the top beam_size active hypotheses, which are just
# the hypos with the smallest values in active_mask.
# {active_hypos} indicates which {beam_size} hypotheses
# from the list of {2 * beam_size} candidates were
# selected. Shapes: (batch size, beam size)
new_cands_to_ignore, active_hypos = torch.topk(
active_mask, k=beam_size, dim=1, largest=False
)
# update cands_to_ignore to ignore any finalized hypos.
cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
# Make sure there is at least one active item for each sentence in the batch.
assert (~cands_to_ignore).any(dim=1).all()
# update cands_to_ignore to ignore any finalized hypos
# {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
# can be selected more than once).
active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
active_bbsz_idx = active_bbsz_idx.view(-1)
active_scores = active_scores.view(-1)
# copy tokens and scores for active hypotheses
# Set the tokens for each beam (can select the same row more than once)
tokens[:, : step + 1] = torch.index_select(
tokens[:, : step + 1], dim=0, index=active_bbsz_idx
)
# Select the next token for each of them
tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
cand_indices, dim=1, index=active_hypos
)
if step > 0:
scores[:, :step] = torch.index_select(
scores[:, :step], dim=0, index=active_bbsz_idx
)
scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
cand_scores, dim=1, index=active_hypos
)
# Update constraints based on which candidates were selected for the next beam
self.search.update_constraints(active_hypos)
# copy attention for active hypotheses
if attn is not None:
attn[:, :, : step + 2] = torch.index_select(
attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
)
# reorder incremental state in decoder
reorder_state = active_bbsz_idx
# sort by score descending
for sent in range(len(finalized)):
scores = torch.tensor(
[float(elem["score"].item()) for elem in finalized[sent]]
)
_, sorted_scores_indices = torch.sort(scores, descending=True)
finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
finalized[sent] = torch.jit.annotate(
List[Dict[str, Tensor]], finalized[sent]
)
return finalized
def _prefix_tokens(
self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int
):
"""Handle prefix tokens"""
prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
prefix_mask = prefix_toks.ne(self.pad)
lprobs[prefix_mask] = torch.min(prefix_lprobs) - 1
lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
-1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
)
# if prefix includes eos, then we should make sure tokens and
# scores are the same across all beams
eos_mask = prefix_toks.eq(self.eos)
if eos_mask.any():
# validate that the first beam matches the prefix
first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[
:, 0, 1 : step + 1
]
eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
assert (first_beam == target_prefix).all()
# copy tokens, scores and lprobs from the first beam to all beams
tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size)
scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size)
lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size)
return lprobs, tokens, scores
def replicate_first_beam(self, tensor, mask, beam_size: int):
tensor = tensor.view(-1, beam_size, tensor.size(-1))
tensor[mask] = tensor[mask][:, :1, :]
return tensor.view(-1, tensor.size(-1))
def finalize_hypos(
self,
step: int,
bbsz_idx,
eos_scores,
tokens,
scores,
finalized: List[List[Dict[str, Tensor]]],
finished: List[bool],
beam_size: int,
attn: Optional[Tensor],
src_lengths,
max_len: int,
):
"""Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly.
A sentence is finalized when {beam_size} finished items have been collected for it.
Returns number of sentences (not beam items) being finalized.
These will be removed from the batch and not processed further.
Args:
bbsz_idx (Tensor):
"""
assert bbsz_idx.numel() == eos_scores.numel()
# clone relevant token and attention tensors.
# tokens is (batch * beam, max_len). So the index_select
# gets the newly EOS rows, then selects cols 1..{step + 2}
tokens_clone = tokens.index_select(0, bbsz_idx)[
:, 1 : step + 2
] # skip the first index, which is EOS
tokens_clone[:, step] = self.eos
attn_clone = (
attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
if attn is not None
else None
)
# compute scores per token position
pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
pos_scores[:, step] = eos_scores
# convert from cumulative to per-position scores
pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
# normalize sentence-level scores
if self.normalize_scores:
eos_scores /= (step + 1) ** self.len_penalty
# cum_unfin records which sentences in the batch are finished.
# It helps match indexing between (a) the original sentences
# in the batch and (b) the current, possibly-reduced set of
# sentences.
cum_unfin: List[int] = []
prev = 0
for f in finished:
if f:
prev += 1
else:
cum_unfin.append(prev)
cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx)
unfin_idx = bbsz_idx // beam_size
sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx)
# Create a set of "{sent}{unfin_idx}", where
# "unfin_idx" is the index in the current (possibly reduced)
# list of sentences, and "sent" is the index in the original,
# unreduced batch
# For every finished beam item
# sentence index in the current (possibly reduced) batch
seen = (sent << 32) + unfin_idx
unique_seen: List[int] = torch.unique(seen).tolist()
if self.match_source_len:
condition = step > torch.index_select(src_lengths, 0, unfin_idx)
eos_scores = torch.where(condition, torch.tensor(-math.inf), eos_scores)
sent_list: List[int] = sent.tolist()
for i in range(bbsz_idx.size()[0]):
# An input sentence (among those in a batch) is finished when
# beam_size hypotheses have been collected for it
if len(finalized[sent_list[i]]) < beam_size:
if attn_clone is not None:
# remove padding tokens from attn scores
hypo_attn = attn_clone[i]
else:
hypo_attn = torch.empty(0)
finalized[sent_list[i]].append(
{
"tokens": tokens_clone[i],
"score": eos_scores[i],
"attention": hypo_attn, # src_len x tgt_len
"alignment": torch.empty(0),
"positional_scores": pos_scores[i],
}
)
newly_finished: List[int] = []
for unique_s in unique_seen:
# check termination conditions for this sentence
unique_sent: int = unique_s >> 32
unique_unfin_idx: int = unique_s - (unique_sent << 32)
if not finished[unique_sent] and self.is_finished(
step, unique_unfin_idx, max_len, len(finalized[unique_sent]), beam_size
):
finished[unique_sent] = True
newly_finished.append(unique_unfin_idx)
return newly_finished
def is_finished(
self,
step: int,
unfin_idx: int,
max_len: int,
finalized_sent_len: int,
beam_size: int,
):
"""
Check whether decoding for a sentence is finished, which
occurs when the list of finalized sentences has reached the
beam size, or when we reach the maximum length.
"""
assert finalized_sent_len <= beam_size
if finalized_sent_len == beam_size or step == max_len:
return True
return False
class EnsembleModel(nn.Module):
"""A wrapper around an ensemble of models."""
def __init__(self, models):
super().__init__()
self.models_size = len(models)
# method '__len__' is not supported in ModuleList for torch script
self.single_model = models[0]
self.models = nn.ModuleList(models)
self.has_incremental: bool = False
if all(
hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder)
for m in models
):
self.has_incremental = True
def forward(self):
pass
def has_encoder(self):
return hasattr(self.single_model, "encoder")
def has_incremental_states(self):
return self.has_incremental
def max_decoder_positions(self):
return min([m.max_decoder_positions() for m in self.models if hasattr(m, "max_decoder_positions")] + [sys.maxsize])
@torch.jit.export
def forward_encoder(self, net_input: Dict[str, Tensor]):
if not self.has_encoder():
return None
return [model.encoder.forward_torchscript(net_input) for model in self.models]
@torch.jit.export
def forward_decoder(
self,
tokens,
encoder_outs: List[Dict[str, List[Tensor]]],
incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
temperature: float = 1.0,
):
log_probs = []
avg_attn: Optional[Tensor] = None
encoder_out: Optional[Dict[str, List[Tensor]]] = None
for i, model in enumerate(self.models):
if self.has_encoder():
encoder_out = encoder_outs[i]
# decode each model
if self.has_incremental_states():
decoder_out = model.decoder.forward(
tokens,
encoder_out=encoder_out,
incremental_state=incremental_states[i],
)
else:
if hasattr(model, "decoder"):
decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out)
else:
decoder_out = model.forward(tokens)
attn: Optional[Tensor] = None
decoder_len = len(decoder_out)
if decoder_len > 1 and decoder_out[1] is not None:
if isinstance(decoder_out[1], Tensor):
attn = decoder_out[1]
else:
attn_holder = decoder_out[1]["attn"]
if isinstance(attn_holder, Tensor):
attn = attn_holder
elif attn_holder is not None:
attn = attn_holder[0]
if attn is not None:
attn = attn[:, -1, :]
decoder_out_tuple = (
decoder_out[0][:, -1:, :].div_(temperature),
None if decoder_len <= 1 else decoder_out[1],
)
probs = model.get_normalized_probs(
decoder_out_tuple, log_probs=True, sample=None
)
probs = probs[:, -1, :]
if self.models_size == 1:
return probs, attn
log_probs.append(probs)
if attn is not None:
if avg_attn is None:
avg_attn = attn
else:
avg_attn.add_(attn)
avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
self.models_size
)
if avg_attn is not None:
avg_attn.div_(self.models_size)
return avg_probs, avg_attn
@torch.jit.export
def reorder_encoder_out(
self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order
):
"""
Reorder encoder output according to *new_order*.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
*encoder_out* rearranged according to *new_order*
"""
new_outs: List[Dict[str, List[Tensor]]] = []
if not self.has_encoder():
return new_outs
for i, model in enumerate(self.models):
assert encoder_outs is not None
new_outs.append(
model.encoder.reorder_encoder_out(encoder_outs[i], new_order)
)
return new_outs
@torch.jit.export
def reorder_incremental_state(
self,
incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
new_order,
):
if not self.has_incremental_states():
return
for i, model in enumerate(self.models):
model.decoder.reorder_incremental_state_scripting(
incremental_states[i], new_order
)
class SequenceGeneratorWithAlignment(SequenceGenerator):
def __init__(
self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs
):
"""Generates translations of a given source sentence.
Produces alignments following "Jointly Learning to Align and
Translate with Transformer Models" (Garg et al., EMNLP 2019).
Args:
left_pad_target (bool, optional): Whether or not the
hypothesis should be left padded or not when they are
teacher forced for generating alignments.
"""
super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs)
self.left_pad_target = left_pad_target
if print_alignment == "hard":
self.extract_alignment = utils.extract_hard_alignment
elif print_alignment == "soft":
self.extract_alignment = utils.extract_soft_alignment
@torch.no_grad()
def generate(self, models, sample, **kwargs):
finalized = super()._generate(sample, **kwargs)
src_tokens = sample["net_input"]["src_tokens"]
bsz = src_tokens.shape[0]
beam_size = self.beam_size
(
src_tokens,
src_lengths,
prev_output_tokens,
tgt_tokens,
) = self._prepare_batch_for_alignment(sample, finalized)
if any(getattr(m, "full_context_alignment", False) for m in self.model.models):
attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens)
else:
attn = [
finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0)
for i in range(bsz * beam_size)
]
if src_tokens.device != "cpu":
src_tokens = src_tokens.to("cpu")
tgt_tokens = tgt_tokens.to("cpu")
attn = [i.to("cpu") for i in attn]
# Process the attn matrix to extract hard alignments.
for i in range(bsz * beam_size):
alignment = self.extract_alignment(
attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos
)
finalized[i // beam_size][i % beam_size]["alignment"] = alignment
return finalized
def _prepare_batch_for_alignment(self, sample, hypothesis):
src_tokens = sample["net_input"]["src_tokens"]
bsz = src_tokens.shape[0]
src_tokens = (
src_tokens[:, None, :]
.expand(-1, self.beam_size, -1)
.contiguous()
.view(bsz * self.beam_size, -1)
)
src_lengths = sample["net_input"]["src_lengths"]
src_lengths = (
src_lengths[:, None]
.expand(-1, self.beam_size)
.contiguous()
.view(bsz * self.beam_size)
)
prev_output_tokens = data_utils.collate_tokens(
[beam["tokens"] for example in hypothesis for beam in example],
self.pad,
self.eos,
self.left_pad_target,
move_eos_to_beginning=True,
)
tgt_tokens = data_utils.collate_tokens(
[beam["tokens"] for example in hypothesis for beam in example],
self.pad,
self.eos,
self.left_pad_target,
move_eos_to_beginning=False,
)
return src_tokens, src_lengths, prev_output_tokens, tgt_tokens
class EnsembleModelWithAlignment(EnsembleModel):
"""A wrapper around an ensemble of models."""
def __init__(self, models):
super().__init__(models)
def forward_align(self, src_tokens, src_lengths, prev_output_tokens):
avg_attn = None
for model in self.models:
decoder_out = model(src_tokens, src_lengths, prev_output_tokens)
attn = decoder_out[1]["attn"][0]
if avg_attn is None:
avg_attn = attn
else:
avg_attn.add_(attn)
if len(self.models) > 1:
avg_attn.div_(len(self.models))
return avg_attn
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import logging
from dataclasses import dataclass, field
from fairseq.data import Dictionary
from fairseq.tasks import register_task
from fairseq.tasks.hubert_pretraining import HubertPretrainingConfig, HubertPretrainingTask, LabelEncoder
from speech2c.data.speech2c_dataset import Speech2cDataset
logger = logging.getLogger(__name__)
@dataclass
class Speech2cPretrainingConfig(HubertPretrainingConfig):
add_decoder: bool = field(
default=False,
metadata={"help": "whether to add decoder for CE Loss on code"},
)
# For inference
ctc_weight: float = field(
default=0.0,
metadata={"help": "ctc weight during inference"},
)
@register_task("speech2c_pretraining", dataclass=Speech2cPretrainingConfig)
class Speech2cPretrainingTask(HubertPretrainingTask):
cfg: Speech2cPretrainingConfig
def load_dictionaries(self):
label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
dictionaries = [Dictionary.load(f"{label_dir}/dict.{label}.txt") for label in self.cfg.labels]
return dictionaries[0] if self.cfg.fine_tuning else dictionaries
def load_dataset(self, split: str, **kwargs) -> None:
manifest = f"{self.cfg.data}/{split}.tsv"
dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries
pad_list = [dict.pad() for dict in dicts]
eos_list = [dict.eos() for dict in dicts]
procs = [LabelEncoder(dict) for dict in dicts]
paths = [
f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels
]
# hubert v1: pad_audio=True, random_crop=False;
self.datasets[split] = Speech2cDataset(
manifest,
sample_rate=self.cfg.sample_rate,
label_paths=paths,
label_rates=self.cfg.label_rate,
pad_list=pad_list,
eos_list=eos_list,
label_processors=procs,
max_keep_sample_size=self.cfg.max_keep_size,
min_keep_sample_size=self.cfg.min_sample_size,
max_sample_size=self.cfg.max_sample_size,
pad_audio=self.cfg.pad_audio,
normalize=self.cfg.normalize,
store_labels=False,
random_crop=self.cfg.random_crop,
single_target=self.cfg.single_target,
tgt_dict=dicts[0],
add_decoder=self.cfg.add_decoder,
fine_tuning=self.cfg.fine_tuning,
)
def build_generator(
self,
models,
args,
seq_gen_cls=None,
extra_gen_cls_kwargs=None,
):
from speech2c.squence_generator import SequenceGenerator
extra_gen_cls_kwargs = {
"ctc_weight": self.cfg.ctc_weight,
**extra_gen_cls_kwargs
}
return super().build_generator(
models, args, seq_gen_cls=SequenceGenerator, extra_gen_cls_kwargs=extra_gen_cls_kwargs
)
# Speech2S
<!--**Pre-trained models for speech related tasks**-->
[**Joint Pre-Training with Speech and Bilingual Text for Direct Speech to Speech Translation**](https://arxiv.org/abs/2210.17027)
- (Updating) Nov. 2022: release the code and models
- Nov. 2022: release preprint in [arXiv](https://arxiv.org/abs/2210.17027)
## Pre-Trained and Fine-tuned Models
| Model | Pre-training Dataset | Fine-tuning Dataset | Model |
| :------: | :----------------------------------------------: | :-----------------: | :-----: |
| Speech2S_enes | Voxpopuli_en_v2 | - | [Google Drive](https://drive.google.com/file/d/1TYypFiEKoCixUro8FTTG23bRZYwAxhkX/view?usp=share_link) |
| Speech2S_enes | Voxpopuli_en_v2 | Voxpopuli_s2s | [Google Drive](https://drive.google.com/file/d/11RxeKznSrHcoP_KK9A1VgwRt3fNh_U_C/view?usp=share_link) |
| Speech2S_esen | Voxpopuli_es_v2 | - | [Google Drive](https://drive.google.com/file/d/1NoC7W-UtQZ-ugIptF1ex0ZlGJncsT1S4/view?usp=share_link) |
| Speech2S_esen | Voxpopuli_es_v2 | Voxpopuli_s2s | [Google Drive](https://drive.google.com/file/d/1eNcKw4ZWGmcABWXJxlf6MKocmiPrKSkH/view?usp=share_link) |
## Setup
```
cd Speech2S/speech2s
pip install --editable fairseq/
```
## Data Preparation
Please follow the steps of data preparation for S2ST in [here](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/enhanced_direct_s2st_discrete_units.md).
## Pre-Training
```
cd speech2s/stpretrain_scripts
base_sc2c_enes.sh
```
## Finetune
```
cd speech2s/stpretrain_scripts
finetune_enes.sh
```
## Inference
```
cd speech2s/stpretrain_scripts
inference_ed.sh
```
## Results on Voxpopuli and Covst
## License
This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq).
[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
## Reference
If you find our work is useful in your research, please cite the following paper:
```bibtex
@article{wei2022joint,
title={Joint Pre-Training with Speech and Bilingual Text for Direct Speech to Speech Translation},
author={Wei, Kun and Zhou, Long and Zhang, Ziqiang and Chen, Liping and Liu, Shujie and He, Lei and Li, Jinyu and Wei, Furu},
journal={arXiv preprint arXiv:2210.17027},
year={2022}
}
```
from . import data, tasks, criterions, models
# @package _group_
common:
fp16: true
log_format: json
log_interval: 100
tensorboard_logdir: tblog
seed: 1337
checkpoint:
save_interval: 1
keep_last_epochs: 1
keep_best_checkpoints: 5
best_checkpoint_metric: dec_accuracy
maximize_best_checkpoint_metric: true
restore_file: checkpoint_last.pt
distributed_training:
ddp_backend: legacy_ddp
find_unused_parameters: true
distributed_world_size: 1
distributed_port: -1
nprocs_per_node: 8
task:
_name: joint_sc2t_pretraining
data: ???
fine_tuning: true
label_dir: ???
normalize: false # must be consistent with pre-training
labels: ["ltr"]
store_labels: true
single_target: true
add_decoder_target: true
pad_audio: false
random_crop: true
hubert_tokenizer: "none"
sp_path: None
dataset:
num_workers: 0
max_tokens: 1300000
skip_invalid_size_inputs_valid_test: true
train_subset: train_100
valid_subset: dev_other
required_batch_size_multiple: 1
criterion:
_name: ctc_ce
zero_infinity: true
optimization:
max_update: 40000
lr: [0.00001]
sentence_avg: true
update_freq: [2]
optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-08
weight_decay: 0.0
lr_scheduler:
_name: tri_stage
phase_ratio: [0.1, 0.4, 0.5]
final_lr_scale: 0.05
model:
_name: speechut_asr
w2v_path: ???
apply_mask: true
mask_prob: 0.65
mask_channel_prob: 0.5
mask_channel_length: 64
layerdrop: 0.1
activation_dropout: 0.1
feature_grad_mult: 0.0
freeze_finetune_updates: 0
add_decoder: true
hydra:
job:
config:
override_dirname:
kv_sep: '-'
item_sep: '__'
exclude_keys:
- run
- task.data
- task.label_dir
- model.w2v_path
- dataset.train_subset
- dataset.valid_subset
- criterion.wer_kenlm_model
- criterion.wer_lexicon
run:
dir: ???
sweep:
dir: ???
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
# @package _group_
common:
fp16: true
log_format: json
log_interval: 100
tensorboard_logdir: tblog
seed: 1337
checkpoint:
save_interval: 1
keep_last_epochs: 5
keep_best_checkpoints: 5
best_checkpoint_metric: dec_accuracy
maximize_best_checkpoint_metric: true
restore_file: checkpoint_last.pt
distributed_training:
ddp_backend: legacy_ddp
find_unused_parameters: true
distributed_world_size: 16
distributed_port: -1
nprocs_per_node: 8
task:
_name: joint_sc2t_pretraining
data: ???
fine_tuning: true
label_dir: ???
normalize: true # must be consistent with pre-training
labels: ["ltr"]
store_labels: true
single_target: true
add_decoder_target: true
pad_audio: false
random_crop: true
hubert_tokenizer: "none"
sp_path: None
dataset:
num_workers: 0
max_tokens: 1300000
skip_invalid_size_inputs_valid_test: true
train_subset: train_100
valid_subset: dev_other
required_batch_size_multiple: 1
criterion:
_name: ctc_ce
zero_infinity: true
optimization:
max_update: 40000
lr: [0.00001]
sentence_avg: true
update_freq: [2]
optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-08
weight_decay: 0.0
lr_scheduler:
_name: tri_stage
phase_ratio: [0.1, 0.4, 0.5]
final_lr_scale: 0.05
model:
_name: speechut_asr
w2v_path: ???
apply_mask: true
mask_prob: 0.5
mask_channel_prob: 0.5
mask_channel_length: 64
layerdrop: 0.0
activation_dropout: 0.1
attention_dropout: 0.1
feature_grad_mult: 0.0
freeze_finetune_updates: 0
add_decoder: true
hydra:
job:
config:
override_dirname:
kv_sep: '-'
item_sep: '__'
exclude_keys:
- run
- task.data
- task.label_dir
- model.w2v_path
- dataset.train_subset
- dataset.valid_subset
- criterion.wer_kenlm_model
- criterion.wer_lexicon
run:
dir: ???
sweep:
dir: ???
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
# @package _group_
common:
fp16: true
log_format: json
log_interval: 100
tensorboard_logdir: tblog
checkpoint:
save_interval: 1
keep_last_epochs: 5
keep_best_checkpoints: 5
best_checkpoint_metric: dec_accuracy
maximize_best_checkpoint_metric: true
restore_file: checkpoint_last.pt
distributed_training:
ddp_backend: legacy_ddp
find_unused_parameters: true
distributed_world_size: 24
distributed_port: -1
nprocs_per_node: 8
task:
_name: joint_sc2t_pretraining
data: ???
fine_tuning: true
label_dir: ???
normalize: true # must be consistent with pre-training
labels: ["ltr"]
store_labels: true
single_target: true
add_decoder_target: true
pad_audio: false
random_crop: true
hubert_tokenizer: "none"
sp_path: None
dataset:
num_workers: 0
max_tokens: 1300000
skip_invalid_size_inputs_valid_test: true
train_subset: train_960
valid_subset: dev_other
required_batch_size_multiple: 1
criterion:
_name: ctc_ce
zero_infinity: true
optimization:
max_update: 40000
lr: [0.00001]
sentence_avg: true
update_freq: [2]
optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-08
weight_decay: 0.0
lr_scheduler:
_name: tri_stage
phase_ratio: [0.1, 0.4, 0.5]
final_lr_scale: 0.05
model:
_name: speechut_asr
w2v_path: ???
apply_mask: true
mask_prob: 0.5
mask_channel_prob: 0.25
mask_channel_length: 64
layerdrop: 0.0
activation_dropout: 0.1
feature_grad_mult: 0.0
freeze_finetune_updates: 0
add_decoder: true
hydra:
job:
config:
override_dirname:
kv_sep: '-'
item_sep: '__'
exclude_keys:
- run
- task.data
- task.label_dir
- model.w2v_path
- dataset.train_subset
- dataset.valid_subset
- criterion.wer_kenlm_model
- criterion.wer_lexicon
run:
dir: ???
sweep:
dir: ???
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
seed: 1337
tensorboard_logdir: tblog
checkpoint:
save_dir: ???
save_interval: 4
keep_last_epochs: 4
save_interval_updates: 50000
keep_interval_updates: -1
keep_interval_updates_pattern: 50000
# no_epoch_checkpoints: true
distributed_training:
ddp_backend: no_c10d
distributed_backend: 'nccl'
distributed_port: -1
distributed_world_size: 32
nprocs_per_node: 8
find_unused_parameters: true
task:
_name: joint_sc2t_pretraining
data: ???
label_dir: ???
labels: ???
label_rate: ${model.label_rate}
store_labels: true
sample_rate: 16000
max_sample_size: 250000
min_sample_size: 32000
pad_audio: false
random_crop: true
normalize: false # must be consistent with extractor
add_decoder_target: true
text_cfg:
seed: ${common.seed}
text_data: ???
data_config: config.yaml
sample_break_mode: eos
tokens_per_sample: 1024
shorten_method: "random_crop"
text_maxtokens_ratio: 1.5
dataset:
num_workers: 6
max_tokens: 1400000
skip_invalid_size_inputs_valid_test: true
validate_interval: ${checkpoint.save_interval}
validate_interval_updates: ${checkpoint.save_interval_updates}
required_batch_size_multiple: 1
criterion:
_name: speechut_criterion
pred_masked_weight: 1.0
pred_nomask_weight: 0.0
loss_weights: [10,]
label_smoothing: 0.1
u2t_ed_weight: 0.1
u2t_ctc_weight: 0.1
text_mum_weight: 0.5
optimization:
max_update: 400000
lr: [0.0005]
clip_norm: 10.0
optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-06
weight_decay: 0.01
lr_scheduler:
_name: polynomial_decay
warmup_updates: 32000
model:
_name: speechut
label_rate: ???
skip_masked: false
skip_nomask: false
mask_prob: 0.80
extractor_mode: default
conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
final_dim: 256
activation_fn: "gelu"
encoder_layers: 6
encoder_attention_heads: 8
encoder_layerdrop: 0.0
dropout_input: 0.1
dropout_features: 0.1
dropout: 0.1
attention_dropout: 0.1
feature_grad_mult: 0.1
untie_final_proj: true
activation_dropout: 0.0
use_rel_pos_enc: true
add_unit_encoder: true
add_text_ctc: true
mask_u2t: false
mix_with_unit: true
add_decoder: true
reset_decoder_embedding_config: true
text_transformer:
activation_fn: ${model.activation_fn}
dropout: ${model.dropout}
attention_dropout: ${model.attention_dropout}
activation_dropout: ${model.activation_dropout}
max_source_positions: 3000
max_target_positions: 3000
no_scale_embedding: true
layernorm_embedding: true
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
encoder:
embed_dim: 768
ffn_embed_dim: 3072
layers: 6
attention_heads: 8
normalize_before: false
learned_pos: true
layerdrop: ${model.encoder_layerdrop}
decoder:
layerdrop: 0.1
embed_dim: 768
ffn_embed_dim: 3072
layers: 6
attention_heads: 12
normalize_before: false
learned_pos: false
output_dim: 768
hydra:
job:
config:
override_dirname:
kv_sep: '-'
item_sep: '__'
exclude_keys:
- run
- task.data
- task.label_dir
run:
dir: ???
sweep:
dir: ???
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
# @package _group_
common:
fp16: true
fp16_scale_tolerance: 0.1 # alleviate fp16 overflow issue
log_format: json
log_interval: 200
seed: 1234
tensorboard_logdir: tblog
checkpoint:
save_dir: ???
save_interval: 1
keep_last_epochs: 4
save_interval_updates: 10000
keep_interval_updates: -1
keep_interval_updates_pattern: 10000
# no_epoch_checkpoints: true
distributed_training:
ddp_backend: no_c10d
distributed_backend: 'nccl'
distributed_port: -1
distributed_world_size: 128
nprocs_per_node: 8
find_unused_parameters: true
task:
_name: joint_sc2t_pretraining
data: ???
label_dir: ???
labels: ???
label_rate: ${model.label_rate}
store_labels: true
sample_rate: 16000
max_sample_size: 250000
min_sample_size: 32000
pad_audio: false
random_crop: true
normalize: true # must be consistent with extractor
add_decoder_target: true
text_cfg:
seed: ${common.seed}
text_data: ???
data_config: config.yaml
sample_break_mode: eos
tokens_per_sample: 1024
shorten_method: "random_crop"
text_maxtokens_ratio: 1.4
dataset:
num_workers: 6
max_tokens: 900000
skip_invalid_size_inputs_valid_test: true
validate_interval: ${checkpoint.save_interval}
validate_interval_updates: ${checkpoint.save_interval_updates}
required_batch_size_multiple: 2
criterion:
_name: speechut_criterion
pred_masked_weight: 1.0
pred_nomask_weight: 0.0
loss_weights: [10,]
label_smoothing: 0.1
u2t_ed_weight: 0.1
u2t_ctc_weight: 0.1
text_mum_weight: 0.5
optimization:
max_update: 400000
lr: [0.0005]
clip_norm: 1.0
optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-06
weight_decay: 0.01
lr_scheduler:
_name: polynomial_decay
warmup_updates: 32000
end_learning_rate: 0.00015 # for future longger pre-training, e.g. 600K step
model:
_name: speechut
label_rate: ???
encoder_embed_dim: 1024
encoder_ffn_embed_dim: 4096
skip_masked: false
skip_nomask: false
mask_prob: 0.80
extractor_mode: layer_norm
conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
final_dim: 768
activation_fn: "gelu"
encoder_layers: 12
encoder_attention_heads: 16
encoder_layerdrop: 0.0
dropout_input: 0.0
dropout_features: 0.0
dropout: 0.0
attention_dropout: 0.0
layer_norm_first: true
feature_grad_mult: 1.0
untie_final_proj: true
activation_dropout: 0.0
use_rel_pos_enc: true
add_unit_encoder: true
add_text_ctc: true
mask_u2t: false
mix_with_unit: true
add_decoder: true
reset_decoder_embedding_config: true
scaling_for_att: 32 # alleviate fp16 overflow issue
text_transformer:
activation_fn: ${model.activation_fn}
dropout: ${model.dropout}
attention_dropout: ${model.attention_dropout}
activation_dropout: ${model.activation_dropout}
max_source_positions: 3000
max_target_positions: 3000
no_scale_embedding: true
layernorm_embedding: true
no_token_positional_embeddings: true
share_decoder_input_output_embed: false
encoder:
embed_dim: 1024
ffn_embed_dim: 4096
layers: 12
attention_heads: 16
normalize_before: false
learned_pos: true
layerdrop: ${model.encoder_layerdrop}
decoder:
layerdrop: 0.1
embed_dim: 768
ffn_embed_dim: 3072
layers: 6
attention_heads: 12
normalize_before: false
learned_pos: false
output_dim: 768
hydra:
job:
config:
override_dirname:
kv_sep: '-'
item_sep: '__'
exclude_keys:
- run
- task.data
- task.label_dir
run:
dir: ???
sweep:
dir: ???
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
import importlib
import os
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
criterion_name = file[: file.find(".py")]
importlib.import_module(
"speechut.criterions." + criterion_name
)
# ----------------------------------------------------------------------------
# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import math
from argparse import Namespace
from dataclasses import dataclass, field
from omegaconf import II
from typing import Optional
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
from fairseq.dataclass import FairseqDataclass
from fairseq.data.data_utils import post_process
from fairseq.tasks import FairseqTask
from fairseq.logging.meters import safe_round
@dataclass
class CtcCeCriterionConfig(FairseqDataclass):
zero_infinity: bool = field(
default=False,
metadata={"help": "zero inf loss when source length <= target length"},
)
sentence_avg: bool = II("optimization.sentence_avg")
post_process: str = field(
default="letter",
metadata={
"help": "how to post process predictions into words. can be letter, "
"wordpiece, BPE symbols, etc. "
"See fairseq.data.data_utils.post_process() for full list of options"
},
)
wer_kenlm_model: Optional[str] = field(
default=None,
metadata={
"help": "if this is provided, use kenlm to compute wer (along with other wer_* args)"
},
)
wer_lexicon: Optional[str] = field(
default=None,
metadata={"help": "lexicon to use with wer_kenlm_model"},
)
wer_lm_weight: float = field(
default=2.0,
metadata={"help": "lm weight to use with wer_kenlm_model"},
)
wer_word_score: float = field(
default=-1.0,
metadata={"help": "lm word score to use with wer_kenlm_model"},
)
wer_args: Optional[str] = field(
default=None,
metadata={
"help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)"
},
)
dec_weight: float = field(
default=0.5,
metadata={"help": "weights for decoder CE Loss, loss will be ((1 - dec_weight) * hubert_loss + dec_weight * CE_Loss)"},
)
report_accuracy: bool = field(
default=True,
metadata={"help": "report decoder accuracy metric"},
)
ignore_prefix_size: int = field(
default=0,
metadata={"help": "Ignore first N tokens"},
)
label_smoothing: float = field(
default=0.1,
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
)
@register_criterion("ctc_ce", dataclass=CtcCeCriterionConfig)
class CtcCeCriterion(FairseqCriterion):
def __init__(self, cfg: CtcCeCriterionConfig, task: FairseqTask):
super().__init__(task)
self.blank_idx = (
task.target_dictionary.index(task.blank_symbol)
if hasattr(task, "blank_symbol")
else 0
)
self.pad_idx = task.target_dictionary.pad()
self.eos_idx = task.target_dictionary.eos()
self.post_process = cfg.post_process
if cfg.wer_args is not None:
(
cfg.wer_kenlm_model,
cfg.wer_lexicon,
cfg.wer_lm_weight,
cfg.wer_word_score,
) = eval(cfg.wer_args)
if cfg.wer_kenlm_model is not None:
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
dec_args = Namespace()
dec_args.nbest = 1
dec_args.criterion = "ctc"
dec_args.kenlm_model = cfg.wer_kenlm_model
dec_args.lexicon = cfg.wer_lexicon
dec_args.beam = 50
dec_args.beam_size_token = min(50, len(task.target_dictionary))
dec_args.beam_threshold = min(50, len(task.target_dictionary))
dec_args.lm_weight = cfg.wer_lm_weight
dec_args.word_score = cfg.wer_word_score
dec_args.unk_weight = -math.inf
dec_args.sil_weight = 0
self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
else:
self.w2l_decoder = None
self.zero_infinity = cfg.zero_infinity
self.sentence_avg = cfg.sentence_avg
self.dec_weight = cfg.dec_weight
self.report_accuracy = cfg.report_accuracy
self.ignore_prefix_size = cfg.ignore_prefix_size
self.eps = cfg.label_smoothing
def forward(self, model, sample, reduce=True):
net_output = model(**sample["net_input"])
lprobs = model.get_normalized_probs(
net_output, log_probs=True
).contiguous() # (T, B, C) from the encoder
if "src_lengths" in sample["net_input"]:
input_lengths = sample["net_input"]["src_lengths"]
else:
if net_output["padding_mask"] is not None:
non_padding_mask = ~net_output["padding_mask"]
input_lengths = non_padding_mask.long().sum(-1)
else:
input_lengths = lprobs.new_full(
(lprobs.size(1),), lprobs.size(0), dtype=torch.long
)
pad_mask = (sample["target"] != self.pad_idx) & (
sample["target"] != self.eos_idx
)
targets_flat = sample["target"].masked_select(pad_mask)
if "target_lengths" in sample:
target_lengths = sample["target_lengths"]
else:
target_lengths = pad_mask.sum(-1)
with torch.backends.cudnn.flags(enabled=False):
loss = F.ctc_loss(
lprobs,
targets_flat,
input_lengths,
target_lengths,
blank=self.blank_idx,
reduction="sum",
zero_infinity=self.zero_infinity,
)
ntokens = (
sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
)
sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
logging_output = {}
if "decoder_target" in sample:
if net_output["decoder_out"] is not None:
dec_sample_size = sample["target"].size(0) if self.sentence_avg else sample["dec_ntokens"]
dec_loss, dec_nll_loss = self.compute_ce_loss(model, net_output["decoder_out"], sample, reduce=reduce)
logging_output["ctc_loss"] = loss.item()
loss = (1 - self.dec_weight) * loss + (self.dec_weight * dec_loss * sample_size / dec_sample_size)
logging_output["dec_loss"] = dec_loss.item()
logging_output["dec_nll_loss"] = dec_nll_loss.item()
logging_output["dec_sample_size"] = dec_sample_size
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, net_output["decoder_out"], sample)
logging_output["dec_n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
else:
logging_output["ctc_loss"] = loss.item()
loss = (1 - self.dec_weight) * loss
logging_output["dec_loss"] = 0
logging_output["dec_nll_loss"] = 0
logging_output["dec_sample_size"] = 1
if self.report_accuracy:
logging_output["dec_n_correct"] = 0
logging_output["total"] = 1
logging_output = {
"loss": utils.item(loss.data), # * sample['ntokens'],
"ntokens": ntokens,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
**logging_output,
}
if not model.training and self.dec_weight < 1.0:
import editdistance
with torch.no_grad():
lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
c_err = 0
c_len = 0
w_errs = 0
w_len = 0
wv_errs = 0
for lp, t, inp_l in zip(
lprobs_t,
sample["target_label"]
if "target_label" in sample
else sample["target"],
input_lengths,
):
lp = lp[:inp_l].unsqueeze(0)
decoded = None
if self.w2l_decoder is not None:
decoded = self.w2l_decoder.decode(lp)
if len(decoded) < 1:
decoded = None
else:
decoded = decoded[0]
if len(decoded) < 1:
decoded = None
else:
decoded = decoded[0]
p = (t != self.task.target_dictionary.pad()) & (
t != self.task.target_dictionary.eos()
)
targ = t[p]
targ_units = self.task.target_dictionary.string(targ)
targ_units_arr = targ.tolist()
toks = lp.argmax(dim=-1).unique_consecutive()
pred_units_arr = toks[toks != self.blank_idx].tolist()
c_err += editdistance.eval(pred_units_arr, targ_units_arr)
c_len += len(targ_units_arr)
targ_words = post_process(targ_units, self.post_process).split()
pred_units = self.task.target_dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split()
if decoded is not None and "words" in decoded:
pred_words = decoded["words"]
w_errs += editdistance.eval(pred_words, targ_words)
wv_errs += editdistance.eval(pred_words_raw, targ_words)
else:
dist = editdistance.eval(pred_words_raw, targ_words)
w_errs += dist
wv_errs += dist
w_len += len(targ_words)
logging_output["wv_errors"] = wv_errs
logging_output["w_errors"] = w_errs
logging_output["w_total"] = w_len
logging_output["c_errors"] = c_err
logging_output["c_total"] = c_len
return loss, sample_size, logging_output
def compute_ce_loss(self, model, net_output, sample, reduce=True):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
loss, nll_loss = label_smoothed_nll_loss(
lprobs,
target,
self.eps,
ignore_index=self.pad_idx,
reduce=reduce,
)
return loss, nll_loss
def compute_accuracy(self, model, net_output, sample):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
mask = target.ne(self.pad_idx)
n_correct = torch.sum(
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
)
total = torch.sum(mask)
return n_correct, total
def get_lprobs_and_target(self, model, net_output, sample):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
target = sample["decoder_target"]
if self.ignore_prefix_size > 0:
if getattr(lprobs, "batch_first", False):
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
target = target[:, self.ignore_prefix_size :].contiguous()
else:
lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
target = target[self.ignore_prefix_size :, :].contiguous()
return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
nsentences = utils.item(
sum(log.get("nsentences", 0) for log in logging_outputs)
)
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar("ntokens", ntokens)
metrics.log_scalar("nsentences", nsentences)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
metrics.log_scalar("_c_errors", c_errors)
c_total = sum(log.get("c_total", 0) for log in logging_outputs)
metrics.log_scalar("_c_total", c_total)
w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
metrics.log_scalar("_w_errors", w_errors)
wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
metrics.log_scalar("_wv_errors", wv_errors)
w_total = sum(log.get("w_total", 0) for log in logging_outputs)
metrics.log_scalar("_w_total", w_total)
if c_total > 0:
metrics.log_derived(
"uer",
lambda meters: safe_round(
meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
)
if meters["_c_total"].sum > 0
else float("nan"),
)
if w_total > 0:
metrics.log_derived(
"wer",
lambda meters: safe_round(
meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
)
if meters["_w_total"].sum > 0
else float("nan"),
)
metrics.log_derived(
"raw_wer",
lambda meters: safe_round(
meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
)
if meters["_w_total"].sum > 0
else float("nan"),
)
if "dec_loss" in logging_outputs[0]:
ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in logging_outputs)
dec_loss_sum = sum(log.get("dec_loss", 0) for log in logging_outputs)
dec_nll_loss_sum = sum(log.get("dec_nll_loss", 0) for log in logging_outputs)
dec_sample_size = sum(log.get("dec_sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"dec_loss", dec_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
)
metrics.log_scalar(
"ctc_loss", ctc_loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar(
"dec_nll_loss", dec_nll_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
)
metrics.log_derived(
"dec_ppl", lambda meters: utils.get_perplexity(meters["dec_nll_loss"].avg)
)
total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
if total > 0:
metrics.log_scalar("total", total)
n_correct = utils.item(
sum(log.get("dec_n_correct", 0) for log in logging_outputs)
)
metrics.log_scalar("dec_n_correct", n_correct)
metrics.log_derived(
"dec_accuracy",
lambda meters: round(
meters["dec_n_correct"].sum * 100.0 / meters["total"].sum, 3
)
if meters["total"].sum > 0
else float("nan"),
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
# ----------------------------------------------------------------------------
# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import logging
import math
import re
from dataclasses import dataclass, field
from typing import List, Optional
import numpy as np
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
from fairseq.dataclass import FairseqDataclass
logger = logging.getLogger(__name__)
@dataclass
class SpeechUTCriterionConfig(FairseqDataclass):
pred_masked_weight: float = field(
default=1.0,
metadata={"help": "weight for predictive loss for masked frames"},
)
pred_nomask_weight: float = field(
default=0.0,
metadata={"help": "weight for predictive loss for unmasked frames"},
)
loss_weights: Optional[List[float]] = field(
default=None,
metadata={"help": "weights for additional loss terms (not first one)"},
)
log_keys: List[str] = field(
default_factory=lambda: [],
metadata={"help": "output keys to log"},
)
u2t_ed_weight: float = field(
default=0.1,
metadata={"help": "weights for text ED Loss, loss will be (hubert_loss + text_mum_weight * MUM_Loss + u2t_ed_weight * CE_Loss + u2t_ctc_weight * CTC_loss)"},
)
u2t_ctc_weight: float = field(
default=0.0,
metadata={"help": "weights for text ED Loss, loss will be (hubert_loss + text_mum_weight * MUM_Loss + u2t_ed_weight * CE_Loss + u2t_ctc_weight * CTC_loss)"},
)
text_mum_weight: float = field(
default=0.0,
metadata={"help": "masked unit modeling weight from the text end"},
)
report_accuracy: bool = field(
default=True,
metadata={"help": "report decoder accuracy metric"},
)
ignore_prefix_size: int = field(
default=0,
metadata={"help": "Ignore first N tokens"},
)
label_smoothing: float = field(
default=0.0,
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
)
no_ctc_blank: bool = field(
default=False,
metadata={"help": "mask out the blank of ctc, only when dec_loss_type=ctc"},
)
label_smoothing: float = field(
default=0.0,
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
)
@register_criterion("speechut_criterion", dataclass=SpeechUTCriterionConfig)
class SpeechUTCriterion(FairseqCriterion):
def __init__(
self,
task,
pred_masked_weight,
pred_nomask_weight,
loss_weights=None,
log_keys=None,
u2t_ed_weight=0.1,
u2t_ctc_weight=0,
text_mum_weight=0,
report_accuracy=False,
ignore_prefix_size=0,
label_smoothing=0,
no_ctc_blank=False,
):
super().__init__(task)
self.pred_masked_weight = pred_masked_weight
self.pred_nomask_weight = pred_nomask_weight
self.loss_weights = loss_weights
self.log_keys = [] if log_keys is None else log_keys
self.u2t_ed_weight = u2t_ed_weight
self.u2t_ctc_weight = u2t_ctc_weight
self.text_mum_weight = text_mum_weight
self.report_accuracy = report_accuracy
self.ignore_prefix_size = ignore_prefix_size
self.eps = label_smoothing
self.no_ctc_blank = no_ctc_blank
self.padding_idx = task.dictionaries[0].pad()
self.eos_idx = task.dictionaries[0].eos()
self.blank_idx = task.dictionaries[0].bos()
def compute_hubert_loss(self, model, net_output, reduction, preffix='', suffix=''):
loss = 0
sample_size = []
logging_output = {}
loss_m_list = []
logp_m_list = model.get_logits(net_output, True)
targ_m_list = model.get_targets(net_output, True)
assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
loss_m_list.append(loss_m)
logging_output[f"{preffix}loss_m_{i}"] = loss_m.detach().item()
if self.pred_masked_weight > 0:
loss += self.pred_masked_weight * sum(loss_m_list)
sample_size.append(targ_m_list[0].numel())
loss_u_list = []
logp_u_list = model.get_logits(net_output, False)
targ_u_list = model.get_targets(net_output, False)
assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
loss_u_list.append(loss_u)
logging_output[f"{preffix}loss_u_{i}"] = loss_u.detach().item()
if self.pred_nomask_weight > 0:
loss += self.pred_nomask_weight * sum(loss_u_list)
sample_size.append(targ_u_list[0].numel())
sample_size = np.mean(sample_size)
def compute_correct(logits, targets):
if logits.numel() == 0:
return 0, 0
else:
assert logits.dim() > 1, logits.shape
max = logits.argmax(-1) == targets
min = logits.argmin(-1) == targets
both = max & min
corr = max.long().sum().item() - both.long().sum().item()
count = max.numel()
return corr, count
with torch.no_grad():
for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
corr_m, count_m = compute_correct(logp_m, targ_m)
logging_output[f"correct_m_{i}{suffix}"] = corr_m
logging_output[f"count_m_{i}{suffix}"] = count_m
for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
corr_u, count_u = compute_correct(logp_u, targ_u)
logging_output[f"correct_u_{i}{suffix}"] = corr_u
logging_output[f"count_u_{i}{suffix}"] = count_u
return loss, sample_size, logging_output
def forward(self, model, sample, reduce=True, log_pred=False):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
reduction = "sum" if reduce else "none"
if "net_input" in sample:
unit_sample = text_sample = None
else:
unit_sample = sample.get("text_mono", None)
text_sample = sample.get("text_paired", None)
assert unit_sample is not None or text_sample is not None
sample = sample.get("speech")
### 1. S2U: do hubert forward and loss computation
sample["modality"] = "speech"
net_output = model(target_list=sample["target_list"], **sample["net_input"])
loss, sample_size, logging_output = self.compute_hubert_loss(
model,
net_output,
reduction,
)
if self.loss_weights is not None:
assert hasattr(model, "get_extra_losses")
extra_losses, names = model.get_extra_losses(net_output)
if torch.is_tensor(extra_losses):
extra_losses = [extra_losses]
names = [names]
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
assert len(extra_losses) == len(
self.loss_weights
), f"{len(extra_losses)}, {len(self.loss_weights)}"
for p, n, coef in zip(extra_losses, names, self.loss_weights):
if coef != 0 and p is not None:
p = coef * p.float() * sample_size
loss += p
logging_output[f"loss_{n}"] = p.item()
for lk in self.log_keys:
if lk in net_output:
logging_output[lk] = float((net_output[lk]))
### 2. do text U2T forward and loss computation
if text_sample is not None and (self.u2t_ctc_weight + self.u2t_ed_weight) > 0:
## 2.1 re-loading "target_list", in default case, target_list = [src_tokens],
## while in case of using "unit-phone-char" structure, target_list will be [ref_tokens]
text_sample["net_input"]["target_list"] = [
text_sample.get("ref_tokens", text_sample["net_input"]["src_tokens"].clone()),
]
text_net_output = model(**text_sample["net_input"])
text_sample_size = text_sample["ntokens"]
### 2.1 U2T_UCTC
if self.u2t_ctc_weight > 0:
text_ctc_loss = self.compute_ctc_loss(model, text_net_output, text_sample["target"], reduction=reduction)
loss += self.u2t_ctc_weight * text_ctc_loss * sample_size / text_sample_size
logging_output["text_ctc_loss"] = utils.item(text_ctc_loss)
logging_output["text_sample_size"] = text_sample_size
### 2.2 U2T_ED
if self.u2t_ed_weight > 0:
text_dec_loss, text_dec_nll_loss = self.compute_ce_loss(model, text_net_output["decoder_out"], text_sample, reduce=reduce)
loss += self.u2t_ed_weight * text_dec_loss * sample_size / text_sample_size
logging_output["text_dec_loss"] = utils.item(text_dec_loss)
logging_output["text_dec_nll_loss"] = utils.item(text_dec_nll_loss)
logging_output["text_sample_size"] = text_sample_size
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, text_net_output["decoder_out"], text_sample)
logging_output["correct_text_dec"] = utils.item(n_correct.data)
logging_output["count_text_dec"] = utils.item(total.data)
### 3. do unit MUM forward and loss computation
if unit_sample is not None and self.text_mum_weight > 0:
src_tokens = unit_sample["net_input"]["src_tokens"]
target = unit_sample.get("target", None)
target = src_tokens.clone() if target is None else target
unit_net_output = model.forward_mum(src_tokens, target)
loss_num, sample_size_mum, logging_output_mum = self.compute_hubert_loss(
model,
unit_net_output,
reduction,
preffix="mum_",
suffix="_mum",
)
loss += self.text_mum_weight * loss_num * sample_size / sample_size_mum
logging_output["unit_sample_size"] = sample_size_mum
logging_output.update(logging_output_mum)
logging_output = {
"loss": utils.item(loss) if reduce else loss,
"ntokens": sample_size,
"nsentences": sample["id"].numel() + (text_sample["id"].numel() if text_sample is not None else 0),
"sample_size": sample_size,
**logging_output,
}
return loss, sample_size, logging_output
def compute_ctc_loss(self, model, net_output, target, reduction):
logits = net_output["encoder_out_ctc"][0] # (T, B, C) from the code-encoder
if self.no_ctc_blank:
## set prob of <blank> to -inf
logits = logits.float()
logits[:, :, self.blank_idx] = -1000000.0
lprobs = F.log_softmax(logits.float(), dim=-1)
encoder_padding_mask = net_output["encoder_padding_mask"][0]
non_padding_mask = ~encoder_padding_mask
input_lengths = non_padding_mask.long().sum(-1)
pad_mask = (target != self.padding_idx) & (target != self.eos_idx)
targets_flat = target.masked_select(pad_mask)
target_lengths = pad_mask.sum(-1)
with torch.backends.cudnn.flags(enabled=False):
loss = F.ctc_loss(
lprobs,
targets_flat,
input_lengths,
target_lengths,
blank=self.blank_idx,
reduction=reduction,
zero_infinity=True,
)
return loss
def compute_ce_loss(self, model, net_output, sample, reduce=True):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
loss, nll_loss = label_smoothed_nll_loss(
lprobs,
target,
self.eps,
ignore_index=self.padding_idx,
reduce=reduce,
)
return loss, nll_loss
def compute_accuracy(self, model, net_output, sample):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
mask = target.ne(self.padding_idx)
n_correct = torch.sum(
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
)
total = torch.sum(mask)
return n_correct, total
def get_lprobs_and_target(self, model, net_output, sample):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
target = sample["target"]
return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
else:
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
counts = {}
for lk in logging_outputs[0].keys():
if lk.startswith("count_"):
val = sum(log.get(lk, 0) for log in logging_outputs)
metrics.log_scalar(lk, val)
counts[lk] = val
for lk in logging_outputs[0].keys():
if lk.startswith("loss_"):
val = sum(log.get(lk, 0) for log in logging_outputs)
metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
elif lk.startswith("correct_"):
val = sum(log.get(lk, 0) for log in logging_outputs)
metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])
if "text_sample_size" in logging_outputs[0]:
text_sample_size = sum(log.get("text_sample_size", 0) for log in logging_outputs)
for lk in logging_outputs[0].keys():
if lk.startswith("text_") and lk.endswith("_loss"):
val = sum(log.get(lk, 0) for log in logging_outputs)
metrics.log_scalar(lk, val / text_sample_size / math.log(2), round=3)
if "unit_sample_size" in logging_outputs[0]:
unit_sample_size = sum(log.get("unit_sample_size", 0) for log in logging_outputs)
for lk in logging_outputs[0].keys():
if lk.startswith("mum_loss_"):
val = sum(log.get(lk, 0) for log in logging_outputs)
metrics.log_scalar(lk, val / unit_sample_size / math.log(2), round=3)
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
raise NotImplementedError()
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return False
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
import bisect
import numpy as np
from torch.utils.data.dataloader import default_collate
from fairseq.data import FairseqDataset
class ConcatDataset(FairseqDataset):
@staticmethod
def cumsum(sequence, sample_ratios):
r, s = [], 0
for e, ratio in zip(sequence, sample_ratios):
curr_len = int(ratio * len(e))
r.append(curr_len + s)
s += curr_len
return r
def __init__(self, datasets, sample_ratios=1):
super(ConcatDataset, self).__init__()
assert len(datasets) > 0, "datasets should not be an empty iterable"
self.datasets = list(datasets)
if isinstance(sample_ratios, int):
sample_ratios = [sample_ratios] * len(self.datasets)
self.sample_ratios = sample_ratios
self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
self.real_sizes = [len(d) for d in self.datasets]
def __len__(self):
return self.cumulative_sizes[-1]
def __getitem__(self, idx):
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
return self.datasets[dataset_idx][sample_idx]
def _get_dataset_and_sample_index(self, idx: int):
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
sample_idx = sample_idx % self.real_sizes[dataset_idx]
return dataset_idx, sample_idx
def collater(self, samples, **extra_args):
# For now only supports datasets with same underlying collater implementations
if hasattr(self.datasets[0], "collater"):
return self.datasets[0].collater(samples, **extra_args)
else:
return default_collate(samples, **extra_args)
def size(self, idx: int):
"""
Return an example's size as a float or tuple.
"""
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
return self.datasets[dataset_idx].size(sample_idx)
def num_tokens(self, index: int):
return np.max(self.size(index))
def attr(self, attr: str, index: int):
dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
return getattr(self.datasets[dataset_idx], attr, None)
@property
def sizes(self):
_dataset_sizes = []
for ds, sr in zip(self.datasets, self.sample_ratios):
if isinstance(ds.sizes, np.ndarray):
_dataset_sizes.append(np.tile(ds.sizes, sr))
else:
# Only support underlying dataset with single size array.
assert isinstance(ds.sizes, list)
_dataset_sizes.append(np.tile(ds.sizes[0], sr))
return np.concatenate(_dataset_sizes)
@property
def supports_prefetch(self):
return all(d.supports_prefetch for d in self.datasets)
def ordered_indices(self):
"""
Returns indices sorted by length. So less padding is needed.
"""
if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
# special handling for concatenating lang_pair_datasets
if getattr(self.datasets[0], "shuffle", False):
indices = np.random.permutation(len(self)).astype(np.int64)
else:
indices = np.arange(len(self), dtype=np.int64)
sizes = self.sizes
tgt_sizes = (
sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
)
src_sizes = (
sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
)
# sort by target length, then source length
if tgt_sizes is not None:
indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
return indices[np.argsort(src_sizes[indices], kind="mergesort")]
else:
return np.argsort(self.sizes)
def prefetch(self, indices):
frm = 0
for to, ds in zip(self.cumulative_sizes, self.datasets):
real_size = len(ds)
if getattr(ds, "supports_prefetch", False):
ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
frm = to
@property
def can_reuse_epoch_itr_across_epochs(self):
return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)
def set_epoch(self, epoch):
super().set_epoch(epoch)
for ds in self.datasets:
if hasattr(ds, "set_epoch"):
ds.set_epoch(epoch)
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
import itertools
import logging
import io
import os
import sys
import time
from pathlib import Path
from typing import Any, List, Optional, Union, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.data import data_utils, Dictionary
from fairseq.data.fairseq_dataset import FairseqDataset
from fairseq.data.audio.audio_utils import (
read_from_stored_zip,
is_sf_audio_data,
)
FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}
logger = logging.getLogger(__name__)
def parse_path(path: str) -> Tuple[str, List[int]]:
"""Parse data path which is either a path to
1. a .npy/.wav/.flac/.ogg file
2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]"
Args:
path (str): the data path to parse
Returns:
file_path (str): the file path
slice_ptr (list of int): empty in case 1;
byte offset and length for the slice in case 2
"""
if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
_path, slice_ptr = path, []
else:
_path, *slice_ptr = path.split(":")
if not Path(_path).is_file():
raise FileNotFoundError(f"File not found: {_path}")
assert len(slice_ptr) in {0, 1, 2}, f"Invalid path: {path}"
slice_ptr = [int(i) for i in slice_ptr]
return _path, slice_ptr
def load_audio(manifest_path, max_keep, min_keep, retry_times=5):
n_long, n_short = 0, 0
names, inds, sizes, chunk_names, chunk_indices = [], [], [], [], []
for i in range(retry_times):
with open(manifest_path) as f:
root = f.readline().strip()
for ind, line in enumerate(f):
items = line.strip().split("\t")
assert len(items) == 2, line
sz = int(items[1])
if min_keep is not None and sz < min_keep:
n_short += 1
elif max_keep is not None and sz > max_keep:
n_long += 1
else:
fname = items[0].split(":")
if len(fname) > 2:
if len(chunk_names) == 0 or fname[0] != chunk_names[-1]:
chunk_names.append(fname[0])
chunk_indices.append(len(names))
names.append(items[0])
inds.append(ind)
sizes.append(sz)
if len(names) == 0:
logger.warn(f"Fail to load manifest for the {i} time")
time.sleep(1)
continue
else:
break
tot = ind + 1
logger.info(
(
f"max_keep={max_keep}, min_keep={min_keep}, "
f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
)
)
return root, names, inds, tot, sizes, chunk_names, chunk_indices
def load_label(label_path, inds, tot, retry_times=5):
for i in range(retry_times):
with open(label_path) as f:
labels = [line.rstrip() for line in f]
if len(labels) == 0:
logger.warn(f"Fail to load label for the {i} time")
time.sleep(1)
continue
else:
break
assert (
len(labels) == tot
), f"number of labels does not match ({len(labels)} != {tot})"
labels = [labels[i] for i in inds]
return labels
def load_label_offset(label_path, inds, tot, retry_times=5):
for i in range(retry_times):
with open(label_path) as f:
code_lengths = [len(line.encode("utf-8")) for line in f]
if len(code_lengths) == 0:
logger.warn(f"Fail to load label for the {i} time")
time.sleep(1)
continue
else:
break
assert (
len(code_lengths) == tot
), f"number of labels does not match ({len(code_lengths)} != {tot})"
offsets = list(itertools.accumulate([0] + code_lengths))
offsets = [(offsets[i], offsets[i + 1]) for i in inds]
return offsets
def verify_label_lengths(
audio_sizes,
audio_rate,
label_path,
label_rate,
inds,
tot,
tol=0.1, # tolerance in seconds
):
if label_rate < 0:
logger.info(f"{label_path} is sequence label. skipped")
return
with open(label_path) as f:
lengths = [len(line.rstrip().split()) for line in f]
assert len(lengths) == tot
lengths = [lengths[i] for i in inds]
num_invalid = 0
for i, ind in enumerate(inds):
dur_from_audio = audio_sizes[i] / audio_rate
dur_from_label = lengths[i] / label_rate
if abs(dur_from_audio - dur_from_label) > tol:
logger.warning(
(
f"audio and label duration differ too much "
f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
f"in line {ind+1} of {label_path}. Check if `label_rate` "
f"is correctly set (currently {label_rate}). "
f"num. of samples = {audio_sizes[i]}; "
f"label length = {lengths[i]}"
)
)
num_invalid += 1
if num_invalid > 0:
logger.warning(
f"total {num_invalid} (audio, label) pairs with mismatched lengths"
)
class HubertDataset(FairseqDataset):
def __init__(
self,
manifest_path: str,
sample_rate: float,
label_paths: List[str],
label_rates: Union[List[float], float], # -1 for sequence labels
pad_list: List[str],
eos_list: List[str],
label_processors: Optional[List[Any]] = None,
max_keep_sample_size: Optional[int] = None,
min_keep_sample_size: Optional[int] = None,
max_sample_size: Optional[int] = None,
shuffle: bool = True,
pad_audio: bool = False,
normalize: bool = False,
store_labels: bool = True,
random_crop: bool = False,
single_target: bool = False,
tgt_dict: Optional[Dictionary] = None,
add_decoder_target: bool = False,
fine_tuning: bool = False,
tgt_lang_idx: int = None,
tokenizer = None,
mbart_style_lang_id: bool = False,
retry_times: int = 5,
reduce_label_for_dec: bool = True,
):
self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.chunk_names, self.chunk_indices = load_audio(
manifest_path, max_keep_sample_size, min_keep_sample_size, retry_times
)
self.sample_rate = sample_rate
self.shuffle = shuffle
self.random_crop = random_crop
self.tgt_dict = tgt_dict
self.add_decoder_target = add_decoder_target
self.fine_tuning = fine_tuning
self.num_labels = len(label_paths)
self.pad_list = pad_list
self.eos_list = eos_list
self.label_processors = label_processors
self.single_target = single_target
self.epoch = 0
self.label_rates = (
[label_rates for _ in range(len(label_paths))]
if isinstance(label_rates, int)
else label_rates
)
self.store_labels = store_labels
if store_labels:
self.label_list = [load_label(p, inds, tot, retry_times) for p in label_paths]
else:
self.label_paths = label_paths
self.label_offsets_list = [
load_label_offset(p, inds, tot, retry_times) for p in label_paths
]
assert label_processors is None or len(label_processors) == self.num_labels
for label_path, label_rate in zip(label_paths, self.label_rates):
verify_label_lengths(
self.wav_sizes, sample_rate, label_path, label_rate, inds, tot
)
self.max_sample_size = (
max_sample_size if max_sample_size is not None else sys.maxsize
)
self.pad_audio = pad_audio
self.normalize = normalize
self.tgt_lang_idx = tgt_lang_idx
self.tokenizer = tokenizer
self.mbart_style_lang_id = mbart_style_lang_id
self.retry_times = retry_times
self.reduce_label_for_dec = reduce_label_for_dec
logger.info(
f"pad_audio={pad_audio}, random_crop={random_crop}, tgt_lang_idx={self.tgt_lang_idx}, reduce_label_for_dec={reduce_label_for_dec}, "
f"mbart_style_lang_id={mbart_style_lang_id}, normalize={normalize}, max_sample_size={self.max_sample_size}"
)
def set_epoch(self, epoch):
self.epoch = epoch
def batch_by_size(self, indices, max_tokens=None, max_sentences=None, required_batch_size_multiple=1):
self.max_tokens = max_tokens
self.max_sentences = max_sentences
self.required_batch_size_multiple = required_batch_size_multiple
if isinstance(indices[0], np.ndarray):
batch_list = []
for indice in indices:
batch = super(HubertDataset, self).batch_by_size(indice, max_tokens, max_sentences, required_batch_size_multiple)
batch_list.append(batch)
return batch_list
else:
return super(HubertDataset, self).batch_by_size(indices, max_tokens, max_sentences, required_batch_size_multiple)
def shuffle_batches(self, batches, seed):
if isinstance(batches[0], list):
new_batches = []
with data_utils.numpy_seed(seed):
np.random.shuffle(batches)
for batch in batches:
np.random.shuffle(batch)
new_batches.extend(batch)
return new_batches
else:
with data_utils.numpy_seed(seed):
np.random.shuffle(batches)
return batches
def get_audio(self, index):
import soundfile as sf
wav_path = os.path.join(self.audio_root, self.audio_names[index])
_path, slice_ptr = parse_path(wav_path)
if len(slice_ptr) == 1:
import kaldiio
feat = kaldiio.load_mat(wav_path)
feat = torch.from_numpy(feat).float()
if self.normalize:
with torch.no_grad():
feat = F.layer_norm(feat, feat.shape[-1])
return feat
else:
if len(slice_ptr) == 2:
byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
assert is_sf_audio_data(byte_data)
wav_path = io.BytesIO(byte_data)
for i in range(self.retry_times):
if i < self.retry_times - 1:
try:
wav, cur_sample_rate = sf.read(wav_path)
break
except Exception as e:
logger.warn(f"Fail to load wav for the {i} time")
logger.warn(e)
time.sleep(1)
continue
else:
wav, cur_sample_rate = sf.read(wav_path)
wav = torch.from_numpy(wav).float()
wav = self.postprocess(wav, cur_sample_rate)
return wav
def get_label(self, index, label_idx):
if self.store_labels:
label = self.label_list[label_idx][index]
else:
with open(self.label_paths[label_idx]) as f:
offset_s, offset_e = self.label_offsets_list[label_idx][index]
f.seek(offset_s)
label = f.read(offset_e - offset_s)
if self.tokenizer is not None and self.fine_tuning:
label = self.tokenizer.encode(label)
if self.label_processors is not None:
label = self.label_processors[label_idx](label)
return label
def get_labels(self, index):
return [self.get_label(index, i) for i in range(self.num_labels)]
def __getitem__(self, index):
wav = self.get_audio(index)
labels = self.get_labels(index)
return {"id": index, "source": wav, "label_list": labels}
def __len__(self):
return len(self.wav_sizes)
def crop_to_max_size(self, wav, target_size):
size = len(wav)
diff = size - target_size
if diff <= 0:
return wav, 0
start, end = 0, target_size
if self.random_crop:
start = np.random.randint(0, diff + 1)
end = size - diff + start
return wav[start:end], start
def collater(self, samples):
# target = max(sizes) -> random_crop not used
# target = max_sample_size -> random_crop used for long
samples = [s for s in samples if s["source"] is not None]
if len(samples) == 0:
return {}
audios = [s["source"] for s in samples]
audio_sizes = [len(s) for s in audios]
if self.pad_audio:
audio_size = min(max(audio_sizes), self.max_sample_size)
else:
audio_size = min(min(audio_sizes), self.max_sample_size)
feat_dim = audios[0].size(-1) if audios[0].dim() > 1 else 1
collated_audios, padding_mask, audio_starts = self.collater_audio(
audios, audio_size, feat_dim,
)
targets_by_label = [
[s["label_list"][i] for s in samples] for i in range(self.num_labels)
]
targets_list, lengths_list, ntokens_list = self.collater_label(
targets_by_label, audio_size, audio_starts
)
if self.add_decoder_target:
if self.fine_tuning:
decoder_label = [
torch.cat((targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0).long()
for i in range(targets_list[0].size(0))
]
else:
if self.tokenizer is not None:
decoder_label = [
# Set 48 for translate int to char and avoid \n
torch.cat(
(
torch.tensor(
self.tokenizer.sp.Encode(
"".join(
[chr(j + 48) for j in (
targets_list[0][i, :lengths_list[0][i]].unique_consecutive() if self.reduce_label_for_dec else targets_list[0][i, :lengths_list[0][i]]
).tolist()]
), out_type=int
)
),
torch.tensor([self.tgt_dict.eos()])
), dim=0
).long()
for i in range(targets_list[0].size(0))
]
else:
decoder_label = [
torch.cat((targets_list[0][i, :lengths_list[0][i]].unique_consecutive() if self.reduce_label_for_dec else targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0).long()
for i in range(targets_list[0].size(0))
]
if self.mbart_style_lang_id:
decoder_label = [
torch.cat((decoder_label[i], torch.tensor([self.tgt_lang_idx])), 0).long()
for i in range(targets_list[0].size(0))
]
dec_ntokens = sum(x.size(0) for x in decoder_label)
decoder_target = data_utils.collate_tokens(
decoder_label,
self.tgt_dict.pad(),
self.tgt_dict.eos() if not self.mbart_style_lang_id else self.tgt_lang_idx,
left_pad=False,
move_eos_to_beginning=False,
)
decoder_target_lengths = torch.tensor(
[x.size(0) for x in decoder_label], dtype=torch.long
)
prev_output_tokens = data_utils.collate_tokens(
decoder_label,
self.tgt_dict.pad(),
self.tgt_dict.eos() if not self.mbart_style_lang_id else self.tgt_lang_idx,
left_pad=False,
move_eos_to_beginning=True,
)
if self.tgt_lang_idx is not None and not self.mbart_style_lang_id:
assert (prev_output_tokens[:, 0] != self.tgt_dict.eos()).sum() == 0
prev_output_tokens[:, 0] = self.tgt_lang_idx
net_input = {
"source": collated_audios,
"padding_mask": padding_mask,
"prev_output_tokens": prev_output_tokens,
}
batch = {
"id": torch.LongTensor([s["id"] for s in samples]),
"net_input": net_input,
"decoder_target": decoder_target,
"decoder_target_lengths": decoder_target_lengths,
"dec_ntokens": dec_ntokens,
"lang_idx": self.tgt_lang_idx,
}
else:
net_input = {"source": collated_audios, "padding_mask": padding_mask}
batch = {
"id": torch.LongTensor([s["id"] for s in samples]),
"net_input": net_input,
}
if self.single_target:
batch["target_lengths"] = lengths_list[0]
batch["ntokens"] = ntokens_list[0]
batch["target"] = targets_list[0]
else:
batch["target_lengths_list"] = lengths_list
batch["ntokens_list"] = ntokens_list
batch["target_list"] = targets_list
return batch
def collater_audio(self, audios, audio_size, feat_dim=1):
collated_audios = audios[0].new_zeros(len(audios), audio_size, feat_dim)
padding_mask = (
torch.BoolTensor(collated_audios.shape[0:2]).fill_(False)
# if self.pad_audio else None
)
audio_starts = [0 for _ in audios]
for i, audio in enumerate(audios):
audio = audio.view(-1, feat_dim)
diff = len(audio) - audio_size
if diff == 0:
collated_audios[i] = audio
elif diff < 0:
assert self.pad_audio
collated_audios[i] = torch.cat([audio, audio.new_full((-diff, feat_dim), 0.0)])
padding_mask[i, diff:] = True
else:
collated_audios[i], audio_starts[i] = self.crop_to_max_size(
audio, audio_size
)
return collated_audios.squeeze(-1), padding_mask, audio_starts
def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
assert label_rate > 0
s2f = label_rate / self.sample_rate
frm_starts = [int(round(s * s2f)) for s in audio_starts]
frm_size = int(round(audio_size * s2f))
if not self.pad_audio:
rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
frm_size = min(frm_size, *rem_size)
targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
logger.debug(f"audio_starts={audio_starts}")
logger.debug(f"frame_starts={frm_starts}")
logger.debug(f"frame_size={frm_size}")
lengths = torch.LongTensor([len(t) for t in targets])
ntokens = lengths.sum().item()
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
return targets, lengths, ntokens
def collater_seq_label(self, targets, pad):
lengths = torch.LongTensor([len(t) for t in targets])
ntokens = lengths.sum().item()
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
return targets, lengths, ntokens
def collater_label(self, targets_by_label, audio_size, audio_starts):
targets_list, lengths_list, ntokens_list = [], [], []
itr = zip(targets_by_label, self.label_rates, self.pad_list)
for targets, label_rate, pad in itr:
if label_rate == -1:
targets, lengths, ntokens = self.collater_seq_label(targets, pad)
else:
targets, lengths, ntokens = self.collater_frm_label(
targets, audio_size, audio_starts, label_rate, pad
)
targets_list.append(targets)
lengths_list.append(lengths)
ntokens_list.append(ntokens)
return targets_list, lengths_list, ntokens_list
def num_tokens(self, index):
return self.size(index)
def size(self, index):
if self.pad_audio:
return self.wav_sizes[index]
return min(self.wav_sizes[index], self.max_sample_size)
@property
def sizes(self):
return np.array(self.wav_sizes)
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
if len(self.chunk_names) > 0:
logger.info(f"ordered indices for epoch {self.epoch}")
with data_utils.numpy_seed(self.epoch):
self.chunk_order = np.random.permutation(len(self.chunk_names))
chunk_count = 0
tmp_sizes = []
tmp_indices = []
indice = []
for i in self.chunk_order:
chunk_count += 1
start = self.chunk_indices[i]
end = self.chunk_indices[i+1] if i < len(self.chunk_names) - 1 else len(self)
size = list(self.sizes[start:end])
tmp_indices.extend(list(np.arange(start, end)))
tmp_sizes.extend(size)
if chunk_count % 10 == 0 or i == self.chunk_order[0]:
order = [np.random.permutation(len(tmp_indices))]
order.append(
np.minimum(
np.array(tmp_sizes),
self.max_sample_size,
)
)
sort_idx = np.lexsort(order)[::-1]
indice.append(np.array([tmp_indices[k] for k in sort_idx]))
tmp_indices = []
tmp_sizes =[]
return indice
else:
order = [np.random.permutation(len(self))]
order.append(
np.minimum(
np.array(self.sizes),
self.max_sample_size,
)
)
return np.lexsort(order)[::-1]
else:
return np.arange(len(self))
def postprocess(self, wav, cur_sample_rate):
if wav.dim() == 2:
wav = wav.mean(-1)
assert wav.dim() == 1, wav.dim()
if cur_sample_rate != self.sample_rate:
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
if self.normalize:
with torch.no_grad():
wav = F.layer_norm(wav, wav.shape)
return wav
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
import logging
import numpy as np
import torch
import os
import itertools
from fairseq.data import FairseqDataset, data_utils
from fairseq.data import (
AppendTokenDataset,
ConcatDataset,
PrependTokenDataset,
data_utils,
indexed_dataset,
)
logger = logging.getLogger(__name__)
def load_langtriple_dataset(
data_path,
split,
src,
src_dict,
ref,
ref_dict,
tgt,
tgt_dict,
combine,
dataset_impl,
upsample_primary,
left_pad_source,
left_pad_target,
max_source_positions,
max_target_positions,
prepend_bos=False,
load_alignments=False,
truncate_source=False,
append_source_id=False,
num_buckets=0,
shuffle=True,
pad_to_multiple=1,
prepend_bos_src=None,
lang_format="[{}]",
):
assert not truncate_source
def split_exists(split, src, ref, tgt, lang, data_path):
filename = os.path.join(data_path, "{}.{}-{}-{}.{}".format(split, src, ref, tgt, lang))
return indexed_dataset.dataset_exists(filename, impl=dataset_impl)
src_datasets = []
ref_datasets = []
tgt_datasets = []
for k in itertools.count():
split_k = split + (str(k) if k > 0 else "")
# infer langcode
if split_exists(split_k, src, ref, tgt, src, data_path):
prefix = os.path.join(data_path, "{}.{}-{}-{}.".format(split_k, src, ref, tgt))
elif split_exists(split_k, tgt, ref, src, src, data_path):
prefix = os.path.join(data_path, "{}.{}-{}-{}.".format(split_k, tgt, ref, src))
else:
if k > 0:
break
else:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, data_path)
)
src_dataset = data_utils.load_indexed_dataset(
prefix + src, src_dict, dataset_impl
)
src_datasets.append(src_dataset)
ref_dataset = data_utils.load_indexed_dataset(
prefix + ref, ref_dict, dataset_impl
)
ref_datasets.append(ref_dataset)
tgt_dataset = data_utils.load_indexed_dataset(
prefix + tgt, tgt_dict, dataset_impl
)
if tgt_dataset is not None:
tgt_datasets.append(tgt_dataset)
logger.info(
"{} {} {}-{}-{} {} examples".format(
data_path, split_k, src, ref, tgt, len(src_datasets[-1])
)
)
if not combine:
break
assert len(src_datasets) == len(ref_datasets)
assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0
if len(src_datasets) == 1:
src_dataset = src_datasets[0]
ref_dataset = ref_datasets[0]
tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
else:
sample_ratios = [1] * len(src_datasets)
sample_ratios[0] = upsample_primary
src_dataset = ConcatDataset(src_datasets, sample_ratios)
ref_dataset = ConcatDataset(ref_datasets, sample_ratios)
if len(tgt_datasets) > 0:
tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
else:
tgt_dataset = None
if prepend_bos:
assert hasattr(src_dict, "bos_index") and hasattr(ref_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
ref_dataset = PrependTokenDataset(ref_dataset, ref_dict.bos())
if tgt_dataset is not None:
tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
elif prepend_bos_src is not None:
logger.info(f"prepending src bos: {prepend_bos_src}")
src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)
ref_dataset = PrependTokenDataset(ref_dataset, prepend_bos_src)
eos = None
if append_source_id:
src_dataset = AppendTokenDataset(
src_dataset, src_dict.index(lang_format.format(src))
)
ref_dataset = AppendTokenDataset(
ref_dataset, ref_dict.index(lang_format.format(ref))
)
if tgt_dataset is not None:
tgt_dataset = AppendTokenDataset(
tgt_dataset, tgt_dict.index(lang_format.format(tgt))
)
eos = tgt_dict.index(lang_format.format(tgt))
align_dataset = None
if load_alignments:
align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
align_dataset = data_utils.load_indexed_dataset(
align_path, None, dataset_impl
)
tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
return LanguageTripleDataset(
src_dataset,
src_dataset.sizes,
src_dict,
ref_dataset,
ref_dataset.sizes,
ref_dict,
tgt_dataset,
tgt_dataset_sizes,
tgt_dict,
left_pad_source=left_pad_source,
left_pad_target=left_pad_target,
align_dataset=align_dataset,
eos=eos,
num_buckets=num_buckets,
shuffle=shuffle,
pad_to_multiple=pad_to_multiple,
)
def collate(
samples,
pad_idx,
eos_idx,
left_pad_source=True,
left_pad_target=False,
input_feeding=True,
pad_to_length=None,
pad_to_multiple=1,
):
if len(samples) == 0:
return {}
def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx,
None,
left_pad,
move_eos_to_beginning,
pad_to_length=pad_to_length,
pad_to_multiple=pad_to_multiple,
)
def check_alignment(alignment, src_len, tgt_len):
if alignment is None or len(alignment) == 0:
return False
if (
alignment[:, 0].max().item() >= src_len - 1
or alignment[:, 1].max().item() >= tgt_len - 1
):
logger.warning("alignment size mismatch found, skipping alignment!")
return False
return True
def compute_alignment_weights(alignments):
"""
Given a tensor of shape [:, 2] containing the source-target indices
corresponding to the alignments, a weight vector containing the
inverse frequency of each target index is computed.
For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
index 3 is repeated twice)
"""
align_tgt = alignments[:, 1]
_, align_tgt_i, align_tgt_c = torch.unique(
align_tgt, return_inverse=True, return_counts=True
)
align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]]
return 1.0 / align_weights.float()
id = torch.LongTensor([s["id"] for s in samples])
src_tokens = merge(
"source",
left_pad=left_pad_source,
pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
)
ref_tokens = merge(
"reference",
left_pad=left_pad_source,
pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
)
# sort by descending source length
src_lengths = torch.LongTensor(
[s["source"].ne(pad_idx).long().sum() for s in samples]
)
ref_lengths = torch.LongTensor(
[s["reference"].ne(pad_idx).long().sum() for s in samples]
)
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
ref_lengths = ref_lengths.index_select(0, sort_order)
ref_tokens = ref_tokens.index_select(0, sort_order)
prev_output_tokens = None
target = None
if samples[0].get("target", None) is not None:
target = merge(
"target",
left_pad=left_pad_target,
pad_to_length=pad_to_length["target"]
if pad_to_length is not None
else None,
)
target = target.index_select(0, sort_order)
tgt_lengths = torch.LongTensor(
[s["target"].ne(pad_idx).long().sum() for s in samples]
).index_select(0, sort_order)
ntokens = tgt_lengths.sum().item()
if samples[0].get("prev_output_tokens", None) is not None:
prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target)
elif input_feeding:
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
"target",
left_pad=left_pad_target,
move_eos_to_beginning=True,
pad_to_length=pad_to_length["target"]
if pad_to_length is not None
else None,
)
else:
ntokens = src_lengths.sum().item()
batch = {
"id": id,
"nsentences": len(samples),
"ntokens": ntokens,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
},
"target": target,
"ref_tokens": ref_tokens,
"ref_lengths": ref_lengths,
}
if prev_output_tokens is not None:
batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(
0, sort_order
)
if samples[0].get("alignment", None) is not None:
bsz, tgt_sz = batch["target"].shape
src_sz = batch["net_input"]["src_tokens"].shape[1]
offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz
if left_pad_source:
offsets[:, 0] += src_sz - src_lengths
if left_pad_target:
offsets[:, 1] += tgt_sz - tgt_lengths
alignments = [
alignment + offset
for align_idx, offset, src_len, tgt_len in zip(
sort_order, offsets, src_lengths, tgt_lengths
)
for alignment in [samples[align_idx]["alignment"].view(-1, 2)]
if check_alignment(alignment, src_len, tgt_len)
]
if len(alignments) > 0:
alignments = torch.cat(alignments, dim=0)
align_weights = compute_alignment_weights(alignments)
batch["alignments"] = alignments
batch["align_weights"] = align_weights
if samples[0].get("constraints", None) is not None:
# Collate the packed constraints across the samples, padding to
# the length of the longest sample.
lens = [sample.get("constraints").size(0) for sample in samples]
max_len = max(lens)
constraints = torch.zeros((len(samples), max(lens))).long()
for i, sample in enumerate(samples):
constraints[i, 0 : lens[i]] = samples[i].get("constraints")
batch["constraints"] = constraints.index_select(0, sort_order)
return batch
class LanguageTripleDataset(FairseqDataset):
"""
A pair of torch.utils.data.Datasets.
Args:
src (torch.utils.data.Dataset): source dataset to wrap
src_sizes (List[int]): source sentence lengths
src_dict (~fairseq.data.Dictionary): source vocabulary
tgt (torch.utils.data.Dataset, optional): target dataset to wrap
tgt_sizes (List[int], optional): target sentence lengths
tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
left_pad_source (bool, optional): pad source tensors on the left side
(default: True).
left_pad_target (bool, optional): pad target tensors on the left side
(default: False).
shuffle (bool, optional): shuffle dataset elements before batching
(default: True).
input_feeding (bool, optional): create a shifted version of the targets
to be passed into the model for teacher forcing (default: True).
remove_eos_from_source (bool, optional): if set, removes eos from end
of source if it's present (default: False).
append_eos_to_target (bool, optional): if set, appends eos to end of
target if it's absent (default: False).
align_dataset (torch.utils.data.Dataset, optional): dataset
containing alignments.
constraints (Tensor, optional): 2d tensor with a concatenated, zero-
delimited list of constraints for each sentence.
append_bos (bool, optional): if set, appends bos to the beginning of
source/target sentence.
num_buckets (int, optional): if set to a value greater than 0, then
batches will be bucketed into the given number of batch shapes.
src_lang_id (int, optional): source language ID, if set, the collated batch
will contain a field 'src_lang_id' in 'net_input' which indicates the
source language of the samples.
tgt_lang_id (int, optional): target language ID, if set, the collated batch
will contain a field 'tgt_lang_id' which indicates the target language
of the samples.
"""
def __init__(
self,
src,
src_sizes,
src_dict,
ref,
ref_sizes,
ref_dict,
tgt=None,
tgt_sizes=None,
tgt_dict=None,
left_pad_source=True,
left_pad_target=False,
shuffle=True,
input_feeding=True,
remove_eos_from_source=False,
append_eos_to_target=False,
align_dataset=None,
constraints=None,
append_bos=False,
eos=None,
num_buckets=0,
src_lang_id=None,
tgt_lang_id=None,
pad_to_multiple=1,
):
if tgt_dict is not None:
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
assert src_dict.unk() == tgt_dict.unk()
if tgt is not None:
assert len(src) == len(
tgt
), "Source and target must contain the same number of examples"
assert len(src) == len(
ref
), "Source and reference must contain the same number of examples"
self.src = src
self.ref = ref
self.tgt = tgt
self.src_sizes = np.array(src_sizes)
self.ref_sizes = np.array(ref_sizes)
self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
self.sizes = (
np.vstack((self.src_sizes, self.tgt_sizes)).T
if self.tgt_sizes is not None
else self.src_sizes
)
self.src_dict = src_dict
self.ref_dict = ref_dict
self.tgt_dict = tgt_dict
self.left_pad_source = left_pad_source
self.left_pad_target = left_pad_target
self.shuffle = shuffle
self.input_feeding = input_feeding
self.remove_eos_from_source = remove_eos_from_source
self.append_eos_to_target = append_eos_to_target
self.align_dataset = align_dataset
if self.align_dataset is not None:
assert (
self.tgt_sizes is not None
), "Both source and target needed when alignments are provided"
self.constraints = constraints
self.append_bos = append_bos
self.eos = eos if eos is not None else src_dict.eos()
self.src_lang_id = src_lang_id
self.tgt_lang_id = tgt_lang_id
if num_buckets > 0:
from fairseq.data import BucketPadLengthDataset
self.src = BucketPadLengthDataset(
self.src,
sizes=self.src_sizes,
num_buckets=num_buckets,
pad_idx=self.src_dict.pad(),
left_pad=self.left_pad_source,
)
self.src_sizes = self.src.sizes
logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
self.ref = BucketPadLengthDataset(
self.ref,
sizes=self.ref_sizes,
num_buckets=num_buckets,
pad_idx=self.ref_dict.pad(),
left_pad=self.left_pad_source,
)
self.ref_sizes = self.ref.sizes
logger.info("bucketing reference lengths: {}".format(list(self.src.buckets)))
if self.tgt is not None:
self.tgt = BucketPadLengthDataset(
self.tgt,
sizes=self.tgt_sizes,
num_buckets=num_buckets,
pad_idx=self.tgt_dict.pad(),
left_pad=self.left_pad_target,
)
self.tgt_sizes = self.tgt.sizes
logger.info(
"bucketing target lengths: {}".format(list(self.tgt.buckets))
)
# determine bucket sizes using self.num_tokens, which will return
# the padded lengths (thanks to BucketPadLengthDataset)
num_tokens = np.vectorize(self.num_tokens, otypes=[np.compat.long])
self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
self.buckets = [
(None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens)
]
else:
self.buckets = None
self.pad_to_multiple = pad_to_multiple
def get_batch_shapes(self):
return self.buckets
def __getitem__(self, index):
tgt_item = self.tgt[index] if self.tgt is not None else None
src_item = self.src[index]
ref_item = self.ref[index]
# Append EOS to end of tgt sentence if it does not have an EOS and remove
# EOS from end of src sentence if it exists. This is useful when we use
# use existing datasets for opposite directions i.e., when we want to
# use tgt_dataset as src_dataset and vice versa
if self.append_eos_to_target:
eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
if self.tgt and self.tgt[index][-1] != eos:
tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])])
if self.append_bos:
bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
if self.tgt and self.tgt[index][0] != bos:
tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]])
bos = self.src_dict.bos()
if self.src[index][0] != bos:
src_item = torch.cat([torch.LongTensor([bos]), self.src[index]])
if self.ref[index][0] != bos:
ref_item = torch.cat([torch.LongTensor([bos]), self.ref[index]])
if self.remove_eos_from_source:
eos = self.src_dict.eos()
if self.src[index][-1] == eos:
src_item = self.src[index][:-1]
if self.ref[index][-1] == eos:
ref_item = self.ref[index][:-1]
example = {
"id": index,
"source": src_item,
"reference": ref_item,
"target": tgt_item,
}
if self.align_dataset is not None:
example["alignment"] = self.align_dataset[index]
if self.constraints is not None:
example["constraints"] = self.constraints[index]
return example
def __len__(self):
return len(self.src)
def collater(self, samples, pad_to_length=None):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
pad_to_length (dict, optional): a dictionary of
{'source': source_pad_to_length, 'target': target_pad_to_length}
to indicate the max length to pad to in source and target respectively.
Returns:
dict: a mini-batch with the following keys:
- `id` (LongTensor): example IDs in the original input order
- `ntokens` (int): total number of tokens in the batch
- `net_input` (dict): the input to the Model, containing keys:
- `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
the source sentence of shape `(bsz, src_len)`. Padding will
appear on the left if *left_pad_source* is ``True``.
- `src_lengths` (LongTensor): 1D Tensor of the unpadded
lengths of each source sentence of shape `(bsz)`
- `prev_output_tokens` (LongTensor): a padded 2D Tensor of
tokens in the target sentence, shifted right by one
position for teacher forcing, of shape `(bsz, tgt_len)`.
This key will not be present if *input_feeding* is
``False``. Padding will appear on the left if
*left_pad_target* is ``True``.
- `src_lang_id` (LongTensor): a long Tensor which contains source
language IDs of each sample in the batch
- `target` (LongTensor): a padded 2D Tensor of tokens in the
target sentence of shape `(bsz, tgt_len)`. Padding will appear
on the left if *left_pad_target* is ``True``.
- `tgt_lang_id` (LongTensor): a long Tensor which contains target language
IDs of each sample in the batch
"""
res = collate(
samples,
pad_idx=self.src_dict.pad(),
eos_idx=self.eos,
left_pad_source=self.left_pad_source,
left_pad_target=self.left_pad_target,
input_feeding=self.input_feeding,
pad_to_length=pad_to_length,
pad_to_multiple=self.pad_to_multiple,
)
if self.src_lang_id is not None or self.tgt_lang_id is not None:
src_tokens = res["net_input"]["src_tokens"]
bsz = src_tokens.size(0)
if self.src_lang_id is not None:
res["net_input"]["src_lang_id"] = (
torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
)
if self.tgt_lang_id is not None:
res["tgt_lang_id"] = (
torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
)
return res
def num_tokens(self, index):
"""Return the number of tokens in a sample. This value is used to
enforce ``--max-tokens`` during batching."""
return max(
self.src_sizes[index],
self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
)
def num_tokens_vec(self, indices):
"""Return the number of tokens for a set of positions defined by indices.
This value is used to enforce ``--max-tokens`` during batching."""
sizes = self.src_sizes[indices]
if self.tgt_sizes is not None:
sizes = np.maximum(sizes, self.tgt_sizes[indices])
return sizes
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
return (
self.src_sizes[index],
self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
)
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
indices = np.random.permutation(len(self)).astype(np.int64)
else:
indices = np.arange(len(self), dtype=np.int64)
if self.buckets is None:
# sort by target length, then source length
if self.tgt_sizes is not None:
indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
else:
# sort by bucketed_num_tokens, which is:
# max(padded_src_len, padded_tgt_len)
return indices[
np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")
]
@property
def supports_prefetch(self):
return getattr(self.src, "supports_prefetch", False) and (
getattr(self.tgt, "supports_prefetch", False) or self.tgt is None
)
def prefetch(self, indices):
self.src.prefetch(indices)
if self.tgt is not None:
self.tgt.prefetch(indices)
if self.align_dataset is not None:
self.align_dataset.prefetch(indices)
def filter_indices_by_size(self, indices, max_sizes):
"""Filter a list of sample indices. Remove those that are longer
than specified in max_sizes.
Args:
indices (np.array): original array of sample indices
max_sizes (int or list[int] or tuple[int]): max sample size,
can be defined separately for src and tgt (then list or tuple)
Returns:
np.array: filtered sample array
list: list of removed indices
"""
return data_utils.filter_paired_dataset_indices_by_size(
self.src_sizes,
self.tgt_sizes,
indices,
max_sizes,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment