Commit 12c90639 authored by “change”'s avatar “change”
Browse files

init

parent 417b607b
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
from typing import Dict, List, Optional
import torch
from torch import Tensor
from fairseq.modules.transformer_layer import TransformerDecoderLayerBase as FairseqTransformerDecoderLayerBase
from fairseq.modules import LayerNorm
from speech2c.models.modules.multihead_attention import MultiheadAttention
class TransformerDecoderLayerBase(FairseqTransformerDecoderLayerBase):
"""Decoder layer block.
In the original paper each operation (multi-head attention, encoder
attention or FFN) is postprocessed with: `dropout -> add residual ->
layernorm`. In the tensor2tensor code they suggest that learning is more
robust when preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*cfg.decoder.normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(
self, cfg, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, has_relative_attention_bias=False
):
super().__init__(
cfg,
no_encoder_attn,
add_bias_kv,
add_zero_attn,
)
if has_relative_attention_bias:
self.norm_k = LayerNorm(self.embed_dim // cfg.decoder.attention_heads)
def build_self_attention(
self, embed_dim, cfg, add_bias_kv=False, add_zero_attn=False
):
return MultiheadAttention(
embed_dim,
cfg.decoder.attention_heads,
dropout=cfg.attention_dropout,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
self_attention=not cfg.cross_self_attention,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
)
def forward(
self,
x,
encoder_out: Optional[torch.Tensor] = None,
encoder_padding_mask: Optional[torch.Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
prev_self_attn_state: Optional[List[torch.Tensor]] = None,
prev_attn_state: Optional[List[torch.Tensor]] = None,
self_attn_mask: Optional[torch.Tensor] = None,
self_attn_padding_mask: Optional[torch.Tensor] = None,
need_attn: bool = False,
need_head_weights: bool = False,
pos_bias=None,
):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor, optional): binary
ByteTensor of shape `(batch, src_len)` where padding
elements are indicated by ``1``.
need_attn (bool, optional): return attention weights
need_head_weights (bool, optional): return attention weights
for each head (default: return average over heads).
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
if need_head_weights:
need_attn = True
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
if pos_bias is not None:
pos_bias = self.norm_k(pos_bias)
if prev_self_attn_state is not None:
prev_key, prev_value = prev_self_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_self_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
assert incremental_state is not None
self.self_attn._set_input_buffer(incremental_state, saved_state)
_self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
if self.cross_self_attention and not (
incremental_state is not None
and _self_attn_input_buffer is not None
and "prev_key" in _self_attn_input_buffer
):
if self_attn_mask is not None:
assert encoder_out is not None
self_attn_mask = torch.cat(
(x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
)
if self_attn_padding_mask is not None:
if encoder_padding_mask is None:
assert encoder_out is not None
encoder_padding_mask = self_attn_padding_mask.new_zeros(
encoder_out.size(1), encoder_out.size(0)
)
self_attn_padding_mask = torch.cat(
(encoder_padding_mask, self_attn_padding_mask), dim=1
)
assert encoder_out is not None
y = torch.cat((encoder_out, x), dim=0)
else:
y = x
x, attn = self.self_attn(
query=x,
key=y,
value=y,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
attn_mask=self_attn_mask,
position_bias=pos_bias,
)
if self.c_attn is not None:
tgt_len, bsz = x.size(0), x.size(1)
x = x.view(tgt_len, bsz, self.nh, self.head_dim)
x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)
x = x.reshape(tgt_len, bsz, self.embed_dim)
if self.attn_ln is not None:
x = self.attn_ln(x)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
if self.encoder_attn is not None and encoder_out is not None:
residual = x
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
if prev_attn_state is not None:
prev_key, prev_value = prev_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
assert incremental_state is not None
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
if self.ffn_layernorm is not None:
x = self.ffn_layernorm(x)
x = self.fc2(x)
x = self.dropout_module(x)
if self.w_resid is not None:
residual = torch.mul(self.w_resid, residual)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
if self.onnx_trace and incremental_state is not None:
saved_state = self.self_attn._get_input_buffer(incremental_state)
assert saved_state is not None
if self_attn_padding_mask is not None:
self_attn_state = [
saved_state["prev_key"],
saved_state["prev_value"],
saved_state["prev_key_padding_mask"],
]
else:
self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
return x, attn, self_attn_state
return x, attn, None
def make_generation_fast_(self, need_attn: bool = False, **kwargs):
self.need_attn = need_attn
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.dataclass import ChoiceEnum
from fairseq.modules import (
LayerNorm,
MultiheadAttention,
SamePad,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.transformer_sentence_encoder import init_bert_params
from fairseq.utils import index_put
from fairseq.distributed import fsdp_wrap
from fairseq.models.wav2vec.utils import pad_to_multiple
from fairseq.models.wav2vec.wav2vec2 import TransformerEncoder as W2vTransformerEncoder
from speech2c.models.modules.relative_pos_enc import RelativePositionalEncoding
from speech2c.models.modules.multihead_attention import MultiheadAttention
EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])
class TransformerEncoder(W2vTransformerEncoder):
def __init__(self, args):
super().__init__(args)
self.dropout = args.dropout
self.embedding_dim = args.encoder_embed_dim
self.required_seq_len_multiple = args.required_seq_len_multiple
self.use_rel_pos_enc = getattr(args, "use_rel_pos_enc", False)
self.pos_conv = nn.Conv1d(
self.embedding_dim,
self.embedding_dim,
kernel_size=args.conv_pos,
padding=args.conv_pos // 2,
groups=args.conv_pos_groups,
)
dropout = 0
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
nn.init.constant_(self.pos_conv.bias, 0)
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
layers = []
for _ in range(args.encoder_layers):
layer = TransformerSentenceEncoderLayer(
embedding_dim=self.embedding_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=self.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
activation_fn=args.activation_fn,
layer_norm_first=args.layer_norm_first,
has_relative_attention_bias=self.use_rel_pos_enc,
)
if args.checkpoint_activations:
layer = fsdp_wrap(layer)
layer = checkpoint_wrapper(layer)
layers.append(layer)
self.layers = nn.ModuleList(layers)
self.layer_norm_first = args.layer_norm_first
self.layer_norm = LayerNorm(self.embedding_dim)
self.layerdrop = args.encoder_layerdrop
if self.use_rel_pos_enc:
self.pos_emb = RelativePositionalEncoding(args.encoder_embed_dim // args.encoder_attention_heads, 160)
self.apply(init_bert_params)
def forward(self, x, padding_mask=None, layer=None):
x, layer_results = self.extract_features(x, padding_mask, layer)
if self.layer_norm_first and layer is None:
x = self.layer_norm(x)
return x, layer_results
def extract_features(self, x, padding_mask=None, tgt_layer=None):
if padding_mask is not None:
x = index_put(x, padding_mask, 0)
x_conv = self.pos_conv(x.transpose(1, 2))
x_conv = x_conv.transpose(1, 2)
x = x + x_conv
if not self.layer_norm_first:
x = self.layer_norm(x)
# pad to the sequence length dimension
x, pad_length = pad_to_multiple(
x, self.required_seq_len_multiple, dim=-2, value=0
)
if pad_length > 0 and padding_mask is None:
padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
padding_mask[:, -pad_length:] = True
else:
padding_mask, _ = pad_to_multiple(
padding_mask, self.required_seq_len_multiple, dim=-1, value=True
)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
if self.use_rel_pos_enc:
x_len = x.shape[0]
pos_seq = torch.arange(0, x_len).long().to(x.device)
pos_seq = pos_seq[:, None] - pos_seq[None, :]
pos_k, pos_v = self.pos_emb(pos_seq)
else:
pos_k = None
layer_results = []
r = None
for i, layer in enumerate(self.layers):
dropout_probability = np.random.random()
if not self.training or (dropout_probability > self.layerdrop):
x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_k)
if tgt_layer is not None:
# unpad if needed
if pad_length > 0:
layer_results.append(
(
x[:-pad_length],
z[:, :-pad_length, :-pad_length]
if z is not None
else z,
)
)
else:
layer_results.append((x, z))
if i == tgt_layer:
r = x
break
if r is not None:
x = r
# T x B x C -> B x T x C
x = x.transpose(0, 1)
# undo paddding
if pad_length > 0:
x = x[:, :-pad_length]
return x, layer_results
class TransformerSentenceEncoderLayer(nn.Module):
"""
Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
models.
"""
def __init__(
self,
embedding_dim: float = 768,
ffn_embedding_dim: float = 3072,
num_attention_heads: float = 8,
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
activation_fn: str = "relu",
layer_norm_first: bool = False,
has_relative_attention_bias: bool = False,
) -> None:
super().__init__()
# Initialize parameters
self.embedding_dim = embedding_dim
self.dropout = dropout
self.activation_dropout = activation_dropout
# Initialize blocks
self.activation_fn = utils.get_activation_fn(activation_fn)
self.self_attn = MultiheadAttention(
self.embedding_dim,
num_attention_heads,
dropout=attention_dropout,
self_attention=True,
)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(self.activation_dropout)
self.dropout3 = nn.Dropout(dropout)
self.layer_norm_first = layer_norm_first
# layer norm associated with the self attention layer
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
# layer norm associated with the position wise feed-forward NN
self.final_layer_norm = LayerNorm(self.embedding_dim)
if has_relative_attention_bias:
self.norm_k = LayerNorm(self.embedding_dim//num_attention_heads)
def forward(
self,
x: torch.Tensor,
self_attn_mask: torch.Tensor = None,
self_attn_padding_mask: torch.Tensor = None,
need_weights: bool = False,
att_args=None,
pos_bias=None,
):
"""
LayerNorm is applied either before or after the self-attention/ffn
modules similar to the original Transformer imlementation.
"""
residual = x
if self.layer_norm_first:
x = self.self_attn_layer_norm(x)
if pos_bias is not None:
pos_bias = self.norm_k(pos_bias)
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
attn_mask=self_attn_mask,
position_bias=pos_bias,
)
x = self.dropout1(x)
x = residual + x
residual = x
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual + x
else:
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
position_bias=pos_bias,
)
x = self.dropout1(x)
x = residual + x
x = self.self_attn_layer_norm(x)
residual = x
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual + x
x = self.final_layer_norm(x)
return x, attn
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import logging
import copy
import contextlib
from typing import Dict, List, Optional, Tuple
import torch
from dataclasses import dataclass, field
from fairseq.data.dictionary import Dictionary
from fairseq.models import register_model
from fairseq.models.hubert import HubertConfig, HubertModel
from fairseq.models.transformer import Embedding
from torch import Tensor
from speech2c.tasks.speech2c_pretraining import (
Speech2cPretrainingConfig,
Speech2cPretrainingTask,
)
from speech2c.models.modules.transformer_decoder import TransformerDecoderScriptable
from speech2c.models.modules.transformer_encoder import TransformerEncoder
logger = logging.getLogger(__name__)
@dataclass
class Speech2cConfig(HubertConfig):
use_rel_pos_enc: bool = field(
default=False,
metadata={"help": "whether to use relative positional encoding"},
)
# decoder
decoder_layers: int = field(
default=6, metadata={"help": "num decoder layers in the transformer"}
)
decoder_embed_dim: int = field(
default=768, metadata={"help": "decoder embedding dimension"}
)
decoder_ffn_embed_dim: int = field(
default=3072, metadata={"help": "decoder embedding dimension for FFN"}
)
decoder_attention_heads: int = field(
default=12, metadata={"help": "num decoder attention heads"}
)
decoder_normalize_before: bool = field(
default=False,
metadata={"help": "apply layernorm before each decoder block"},
)
decoder_layerdrop: float = field(
default=0.0,
metadata={"help": "probability of dropping a tarnsformer layer"},
)
share_decoder_input_output_embed: bool = field(
default=False,
metadata={"help": "share decoder input and output embeddings"},
)
decoder_output_dim: int = field(
default=768, metadata={"help": "decoder output dimension"}
)
max_target_positions: int = field(
default=3000, metadata={"help": "max target position"}
)
no_scale_embedding: bool = field(
default=False,
metadata={"help": "not scale embedding"},
)
adaptive_input: bool = field(
default=False,
metadata={"help": "adaptive input"},
)
quant_noise_pq: int = field(
default=0, metadata={"help": "quant noise pq"}
)
decoder_learned_pos: bool = field(
default=False,
metadata={"help": "decoder learnable positional embedding"},
)
no_token_positional_embeddings: bool = field(
default=False,
metadata={"help": "no token positional embeddings"},
)
decoder_dict_size: int = field(
default=-1,
metadata={"help": "decoder dictionary dimension, only used for fine-tuning"},
)
# FP16 optimization
required_seq_len_multiple: int = field(
default=1,
metadata={
"help": "pad the input to encoder such that the sequence length is divisible by multiple"
},
)
crop_seq_to_multiple: int = field(
default=1,
metadata={
"help": "crop convolutional feature extractor output such that the sequence length is divisible by multiple"
},
)
@register_model("speech2c", dataclass=Speech2cConfig)
class Speech2cModel(HubertModel):
def __init__(
self,
cfg: Speech2cConfig,
task_cfg: Speech2cPretrainingConfig,
dictionaries: List[Dictionary],
) -> None:
super().__init__(cfg, task_cfg, dictionaries)
logger.info(f"Speech2cModel Config: {cfg}")
self.encoder = TransformerEncoder(cfg)
self.add_decoder = task_cfg.add_decoder
if task_cfg.add_decoder:
def build_embedding(dictionary, embed_dim):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
return Embedding(num_embeddings, embed_dim, padding_idx)
# To make sure that the decoder dict size is the same as the fine-tuning tgt_dict size
cut_dictionary = copy.deepcopy(dictionaries[0])
if cfg.decoder_dict_size != -1:
cut_dictionary.symbols = cut_dictionary.symbols[:cfg.decoder_dict_size]
decoder_embed_tokens = build_embedding(
cut_dictionary, cfg.decoder_embed_dim
)
self.decoder = TransformerDecoderScriptable(cfg, cut_dictionary, decoder_embed_tokens)
@classmethod
def build_model(cls, cfg: Speech2cConfig, task: Speech2cPretrainingTask):
"""Build a new model instance."""
model = Speech2cModel(cfg, task.cfg, task.dictionaries)
return model
def get_normalized_probs(
self,
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
log_probs: bool,
sample: Optional[Dict[str, Tensor]] = None,
):
# net_output['encoder_out'] is a (B, T, D) tensor
lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
lprobs.batch_first = True
return lprobs
def forward(
self,
source: torch.Tensor,
target_list: Optional[List[torch.Tensor]] = None,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = True,
features_only: bool = False,
output_layer: Optional[int] = None,
prev_output_tokens: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
"""output layer is 1-based"""
features = self.forward_features(source)
if target_list is not None:
features, target_list = self.forward_targets(features, target_list)
features_pen = features.float().pow(2).mean()
features = features.transpose(1, 2)
features = self.layer_norm(features)
unmasked_features = features.clone()
if padding_mask is not None:
padding_mask = self.forward_padding_mask(features, padding_mask)
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
features = self.dropout_input(features)
unmasked_features = self.dropout_features(unmasked_features)
if mask:
x, mask_indices = self.apply_mask(features, padding_mask, target_list)
else:
x = features
mask_indices = None
# feature: (B, T, D), float
# target: (B, T), long
# x: (B, T, D), float
# padding_mask: (B, T), bool
# mask_indices: (B, T), bool
x, _ = self.encoder(
x,
padding_mask=padding_mask,
layer=None if output_layer is None else output_layer - 1,
)
if features_only:
return {"x": x, "padding_mask": padding_mask, "features": features}
def compute_pred(proj_x, target, label_embs):
# compute logits for the i-th label set
y = torch.index_select(label_embs, 0, target.long())
negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
if self.target_glu:
y = self.target_glu(y)
negs = self.target_glu(negs)
# proj_x: (S, D)
# y: (S, D)
# negs: (Neg, S, D)
return self.compute_nce(proj_x, y, negs)
label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
if not self.skip_masked:
masked_indices = torch.logical_and(~padding_mask, mask_indices)
proj_x_m = self.final_proj(x[masked_indices])
if self.untie_final_proj:
proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
else:
proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
logit_m_list = [
compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list))
]
else:
logit_m_list = [None for _ in target_list]
if not self.skip_nomask:
nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
proj_x_u = self.final_proj(x[nomask_indices])
if self.untie_final_proj:
proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1)
else:
proj_x_u_list = [proj_x_u for _ in range(len(target_list))]
logit_u_list = [
compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list))
]
else:
logit_u_list = [None for _ in target_list]
result = {
"logit_m_list": logit_m_list,
"logit_u_list": logit_u_list,
"padding_mask": padding_mask,
"features_pen": features_pen,
}
if self.add_decoder:
encoder_out = {
"encoder_out": [x.transpose(0, 1)], # T x B x C
"encoder_padding_mask": [padding_mask], # B x T
}
assert prev_output_tokens is not None
decoder_out = self.decoder(
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
)
result['decoder_out'] = decoder_out
return result
def forward_torchscript(self, net_input: Dict[str, Tensor]):
"""A TorchScript-compatible version of forward.
Encoders which use additional arguments may want to override
this method for TorchScript compatibility.
"""
res = self.forward(
net_input["source"],
padding_mask=net_input["padding_mask"],
mask=False,
features_only=True
)
encoder_out = {
"encoder_out": [res["x"].transpose(0, 1)], # T x B x C
"encoder_padding_mask": [res["padding_mask"]], # B x T
}
return encoder_out
def extract_features(
self,
source: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = False,
ret_conv: bool = False,
output_layer: Optional[int] = None,
prev_output_tokens: Optional[torch.Tensor] = None,
ft: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
with torch.no_grad() if not ft else contextlib.ExitStack():
res = self.forward(
source,
padding_mask=padding_mask,
mask=mask,
features_only=True,
output_layer=output_layer,
)
feature = res["features"] if ret_conv else res["x"]
if self.add_decoder:
encoder_out = {
"encoder_out": [feature.transpose(0, 1)], # T x B x C
"encoder_padding_mask": [res["padding_mask"]], # B x T
}
assert prev_output_tokens is not None
decoder_out = self.decoder(
prev_output_tokens=prev_output_tokens,
encoder_out=encoder_out,
)
else:
decoder_out = None
return feature, res["padding_mask"], decoder_out
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
from argparse import Namespace
from omegaconf import II
import torch.nn as nn
from dataclasses import dataclass, field
from fairseq import checkpoint_utils, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models import BaseFairseqModel, FairseqEncoder, register_model
from fairseq.models.hubert.hubert_asr import HubertAsrConfig, Linear
from fairseq.tasks import FairseqTask
@dataclass
class Speech2cAsrConfig(HubertAsrConfig):
# for decoder
decoder_layerdrop: float = field(
default=0.0,
metadata={"help": "probability of dropping a decoder layer in hubert"},
)
add_decoder: bool = II("task.add_decoder")
@dataclass
class Speech2cCtcConfig(Speech2cAsrConfig):
pass
@register_model("speech2c_ctc", dataclass=Speech2cCtcConfig)
class Speech2cCtc(BaseFairseqModel):
def __init__(self, cfg: Speech2cCtcConfig, w2v_encoder: BaseFairseqModel):
super().__init__()
self.cfg = cfg
self.w2v_encoder = w2v_encoder
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
return state_dict
@classmethod
def build_model(cls, cfg: Speech2cCtcConfig, task: FairseqTask):
"""Build a new model instance."""
w2v_encoder = Speech2cEncoder(cfg, task.target_dictionary)
return cls(cfg, w2v_encoder)
def get_normalized_probs(self, net_output, log_probs, sample=None):
"""Get normalized probabilities (or log probs) from a net's output."""
if "encoder_out" not in net_output:
return self.w2v_encoder.get_normalized_probs_decoder(net_output, log_probs, sample)
if "encoder_out_for_ctc" in net_output:
logits = net_output["encoder_out_for_ctc"]
else:
logits = net_output["encoder_out"]
if isinstance(logits, list):
logits = logits[0]
if log_probs:
return utils.log_softmax(logits.float(), dim=-1)
else:
return utils.softmax(logits.float(), dim=-1)
def get_logits(self, net_output):
logits = net_output["encoder_out"]
padding = net_output["encoder_padding_mask"]
if padding is not None and padding.any():
padding = padding.T
logits[padding][..., 0] = 0
logits[padding][..., 1:] = float("-inf")
return logits
def forward(self, **kwargs):
x = self.w2v_encoder(**kwargs)
return x
@property
def encoder(self):
return self.w2v_encoder
def reorder_encoder_out(self, encoder_out, new_order):
return self.encoder.reorder_encoder_out(encoder_out, new_order)
@property
def decoder(self):
return self.w2v_encoder.w2v_model.decoder
class Speech2cEncoder(FairseqEncoder):
def __init__(self, cfg: Speech2cAsrConfig, tgt_dict=None):
self.apply_mask = cfg.apply_mask
arg_overrides = {
"dropout": cfg.dropout,
"activation_dropout": cfg.activation_dropout,
"dropout_input": cfg.dropout_input,
"attention_dropout": cfg.attention_dropout,
"mask_length": cfg.mask_length,
"mask_prob": cfg.mask_prob,
"mask_selection": cfg.mask_selection,
"mask_other": cfg.mask_other,
"no_mask_overlap": cfg.no_mask_overlap,
"mask_channel_length": cfg.mask_channel_length,
"mask_channel_prob": cfg.mask_channel_prob,
"mask_channel_selection": cfg.mask_channel_selection,
"mask_channel_other": cfg.mask_channel_other,
"no_mask_channel_overlap": cfg.no_mask_channel_overlap,
"encoder_layerdrop": cfg.layerdrop,
"decoder_layerdrop": cfg.decoder_layerdrop,
"feature_grad_mult": cfg.feature_grad_mult,
"decoder_dict_size": len(tgt_dict) if cfg.add_decoder else -1,
}
if cfg.w2v_args is None:
state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
w2v_args = state.get("cfg", None)
if w2v_args is None:
w2v_args = convert_namespace_to_omegaconf(state["args"])
cfg.w2v_args = w2v_args
else:
state = None
w2v_args = cfg.w2v_args
if isinstance(w2v_args, Namespace):
cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)
assert cfg.normalize == w2v_args.task.normalize, (
"Fine-tuning works best when data normalization is the same. "
"Please check that --normalize is set or unset for "
"both pre-training and here"
)
w2v_args.task.data = cfg.data
w2v_args.task.add_decoder = cfg.add_decoder
task = tasks.setup_task(w2v_args.task)
if state is not None and "task_state" in state:
# This will load the stored "dictionaries" object
task.load_state_dict(state["task_state"])
model = task.build_model(w2v_args.model)
if state is not None and not cfg.no_pretrained_weights:
if "decoder.embed_tokens.weight" in state["model"]:
del state["model"]["decoder.embed_tokens.weight"]
if "decoder.output_projection.weight" in state["model"]:
del state["model"]["decoder.output_projection.weight"]
# set strict=False because we omit some modules
model.load_state_dict(state["model"], strict=False)
model.remove_pretraining_modules()
super().__init__(task.source_dictionary)
d = model.mask_emb.size(0)
self.w2v_model = model
self.final_dropout = nn.Dropout(cfg.final_dropout)
self.freeze_finetune_updates = cfg.freeze_finetune_updates
self.num_updates = 0
if tgt_dict is not None:
self.proj = Linear(d, len(tgt_dict))
elif getattr(cfg, "decoder_embed_dim", d) != d:
self.proj = Linear(d, cfg.decoder_embed_dim)
else:
self.proj = None
def set_num_updates(self, num_updates):
"""Set the number of parameters updates."""
super().set_num_updates(num_updates)
self.num_updates = num_updates
def forward(self, source, padding_mask, prev_output_tokens=None, tbc=True, **kwargs):
ft = self.freeze_finetune_updates <= self.num_updates
w2v_args = {
"source": source,
"padding_mask": padding_mask,
"mask": self.apply_mask and self.training,
"prev_output_tokens": prev_output_tokens,
"ft": ft,
}
x, padding_mask, decoder_out = self.w2v_model.extract_features(**w2v_args)
if tbc:
# B x T x C -> T x B x C
x = x.transpose(0, 1)
x = self.final_dropout(x)
if self.proj:
x = self.proj(x)
return {
"encoder_out": x, # T x B x C
"encoder_padding_mask": padding_mask, # B x T
"padding_mask": padding_mask,
"decoder_out": decoder_out,
}
def get_normalized_probs_decoder(self, net_output, log_probs, sample=None):
# net_output['encoder_out'] is a (B, T, D) tensor
return self.w2v_model.get_normalized_probs(net_output, log_probs, sample)
def reorder_encoder_out(self, encoder_out, new_order):
if encoder_out["encoder_out"] is not None:
if isinstance(encoder_out["encoder_out"], list):
encoder_out["encoder_out"] = (
[] if len(encoder_out["encoder_out"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
)
else:
encoder_out["encoder_out"] = encoder_out[
"encoder_out"
].index_select(1, new_order)
if encoder_out["encoder_padding_mask"] is not None:
if isinstance(encoder_out["encoder_padding_mask"], list):
encoder_out["encoder_padding_mask"] = (
[] if len(encoder_out["encoder_padding_mask"]) == 0
else [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]]
)
else:
encoder_out["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask"
].index_select(0, new_order)
if "decoder_out" in encoder_out and encoder_out["decoder_out"] is not None:
if isinstance(encoder_out["decoder_out"], list):
encoder_out["decoder_out"] = (
[] if len(encoder_out["decoder_out"]) == 0
else [x.index_select(0, new_order) for x in encoder_out["decoder_out"]]
)
else:
encoder_out["decoder_out"] = encoder_out[
"decoder_out"
].index_select(0, new_order)
if "encoder_out_for_ctc" in encoder_out and encoder_out["encoder_out_for_ctc"] is not None:
if isinstance(encoder_out["encoder_out_for_ctc"], list):
encoder_out["encoder_out_for_ctc"] = (
[] if len(encoder_out["encoder_out_for_ctc"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["encoder_out_for_ctc"]]
)
else:
encoder_out["encoder_out_for_ctc"] = encoder_out[
"encoder_out_for_ctc"
].index_select(1, new_order)
return encoder_out
def forward_torchscript(self, net_input):
"""A TorchScript-compatible version of forward.
Encoders which use additional arguments may want to override
this method for TorchScript compatibility.
"""
encoder_out = self.w2v_model.forward_torchscript(net_input)
assert self.proj is not None
encoder_out['encoder_out_for_ctc'] = [self.proj(encoder_out['encoder_out'][0])]
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
return None
def upgrade_state_dict_named(self, state_dict, name):
return state_dict
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
from fairseq.models import (
register_model_architecture,
)
from fairseq.models.transformer_lm import base_lm_architecture
@register_model_architecture(model_name="transformer_lm", arch_name="transformer_lm_t5")
def transformer_lm_t5(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6144)
args.decoder_layers = getattr(args, "decoder_layers", 20)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
args.activation_fn = getattr(args, "activation_fn", "gelu")
base_lm_architecture(args)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
from . import data, tasks, criterions, models
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import importlib
import os
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
criterion_name = file[: file.find(".py")]
importlib.import_module(
"speechut.criterions." + criterion_name
)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment