Commit 112bf76b authored by chenzk

v1.0
import torch
import torch.nn as nn
from transformers import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
from vita.util.s2wrapper import forward as multiscale_forward
class SiglipVisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__()
self.is_loaded = False
self.vision_tower_name = vision_tower
self.select_layer = -2
if not delay_load:
self.load_model()
else:
self.cfg_only = SiglipVisionConfig.from_pretrained(self.vision_tower_name)
def load_model(self):
self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
self.image_processor.crop_size = self.image_processor.size
self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
self.vision_tower.requires_grad_(False)
self.is_loaded = True
def feature_select(self, image_forward_outs):
image_features = image_forward_outs.hidden_states[self.select_layer]
return image_features
@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(
image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
output_hidden_states=True,
)
image_feature = self.feature_select(image_forward_out).to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(
images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
)
image_features = self.feature_select(image_forward_outs).to(images.dtype)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_tower.dtype
@property
def device(self):
return self.vision_tower.device
@property
def config(self):
if self.is_loaded:
return self.vision_tower.config
else:
return self.cfg_only
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
class SiglipVisionTowerS2(SiglipVisionTower):
def __init__(self, vision_tower, args, delay_load=False):
self.s2_scales = getattr(args, "s2_scales", "384,768,1152")
self.s2_scales = list(map(int, self.s2_scales.split(",")))
self.s2_scales.sort()
self.s2_split_size = self.s2_scales[0]
self.s2_image_size = self.s2_scales[-1]
super().__init__(vision_tower, args, delay_load)
self.multiscale_forward = multiscale_forward
if not delay_load:
self.image_processor.size["height"] = self.image_processor.size[
"width"
] = self.s2_image_size
self.image_processor.crop_size["height"] = self.image_processor.crop_size[
"width"
] = self.s2_image_size
def load_model(self):
self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
self.image_processor.crop_size = self.image_processor.size
self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
self.vision_tower.requires_grad_(False)
self.image_processor.size["height"] = self.image_processor.size[
"width"
] = self.s2_image_size
self.image_processor.crop_size["height"] = self.image_processor.crop_size[
"width"
] = self.s2_image_size
self.is_loaded = True
@torch.no_grad()
def forward_feature(self, images):
image_forward_outs = self.vision_tower(
images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
)
image_features = self.feature_select(image_forward_outs).to(images.dtype)
return image_features
@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_feature = self.multiscale_forward(
self.forward_feature,
image.unsqueeze(0),
img_sizes=self.s2_scales,
max_split_size=self.s2_split_size,
)
image_features.append(image_feature)
else:
image_features = self.multiscale_forward(
self.forward_feature,
images,
img_sizes=self.s2_scales,
max_split_size=self.s2_split_size,
)
return image_features
@property
def hidden_size(self):
return self.config.hidden_size * len(self.s2_scales)
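# Usage sketch (illustrative): how the S2 tower might be constructed and what
# shapes to expect. The checkpoint id and the `args` namespace below are
# assumptions for demonstration; feature extraction can be slow on CPU.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(s2_scales="384,768,1152")
    tower = SiglipVisionTowerS2("google/siglip-so400m-patch14-384", args)  # hypothetical checkpoint
    images = torch.zeros(1, 3, tower.s2_image_size, tower.s2_image_size)
    feats = tower(images.to(tower.device, dtype=tower.dtype))
    # Multi-scale features are concatenated along the channel axis, so
    # tower.hidden_size == config.hidden_size * len(s2_scales).
    print(feats.shape, tower.hidden_size)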
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
class CNNAdapter(torch.nn.Module):
def __init__(
self,
enc_out_dim: int = 512,
llm_embed_dim: int = 4096,
kernel_size: int = 5,
):
super().__init__()
self.left_padding1 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
self.conv1d1 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 1, 0)
self.bn1 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
self.relu1 = nn.ReLU()
self.left_padding2 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
self.conv1d2 = nn.Conv1d(2 * enc_out_dim, 4 * enc_out_dim, kernel_size, 1, 0)
self.bn2 = nn.BatchNorm1d(4 * enc_out_dim, eps=1e-3, momentum=0.99)
self.relu2 = nn.ReLU()
self.project = nn.Linear(4 * enc_out_dim, llm_embed_dim)
def forward(self, x, mask_pad):
"""
x: B, T, enc_out_dim
        mask_pad: (B, 1, T)
"""
x = x.transpose(1, 2) # B, channels, T
# mask batch padding
if mask_pad.size(2) > 0: # time > 0
x.masked_fill_(~mask_pad, 0.0)
x = self.left_padding1(x)
x = self.conv1d1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.left_padding2(x)
x = self.conv1d2(x)
x = self.bn2(x)
x = self.relu2(x)
x = x.transpose(1, 2)
x = self.project(x)
return x, mask_pad
class LinearAdapter(torch.nn.Module):
def __init__(
self,
enc_out_dim: int = 512,
llm_embed_dim: int = 4096,
):
super().__init__()
self.adpter = torch.nn.Linear(enc_out_dim, llm_embed_dim)
def forward(self, x, mask_pad):
return self.adpter(x), mask_pad
class CNNSubsampling(torch.nn.Module):
def __init__(
self,
enc_out_dim: int = 512,
llm_embed_dim: int = 4096,
kernel_size: int = 5,
activation_func: str = "relu",
norm: str = "batch",
):
super().__init__()
if enc_out_dim * 4 < llm_embed_dim:
self.left_padding1 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
self.conv1d1 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 1, 0)
self.bn1 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
self.relu1 = nn.ReLU()
self.left_padding2 = nn.ConstantPad1d((0, kernel_size - 1), 0.0)
self.conv1d2 = nn.Conv1d(2 * enc_out_dim, 4 * enc_out_dim, kernel_size, 2, 0)
self.bn2 = nn.BatchNorm1d(4 * enc_out_dim, eps=1e-3, momentum=0.99)
self.relu2 = nn.ReLU()
self.project = nn.Linear(4 * enc_out_dim, llm_embed_dim)
self.cnn_num = 2
else:
self.left_padding2 = nn.ConstantPad1d((0, kernel_size - 1), 0.0)
self.conv1d2 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 2, 0)
if norm == "batch":
self.bn2 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
elif norm == "layer":
self.bn2 = nn.LayerNorm(2 * enc_out_dim, eps=1e-3)
if activation_func == "gelu":
self.relu2 = nn.GELU()
else:
self.relu2 = nn.ReLU()
self.project = nn.Linear(2 * enc_out_dim, llm_embed_dim)
self.cnn_num = 1
def forward(self, x, mask_pad):
"""
x: B, T, enc_out_dim
        mask_pad: (B, 1, T)
"""
x = x.transpose(1, 2) # B, channels, T
# mask batch padding
if mask_pad.size(2) > 0: # time > 0
x.masked_fill_(~mask_pad, 0.0)
if self.cnn_num == 2:
x = self.left_padding1(x)
x = self.conv1d1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.left_padding2(x)
x = self.conv1d2(x)
if isinstance(self.bn2, nn.LayerNorm):
x = x.transpose(1, 2)
x = self.bn2(x)
if isinstance(self.bn2, nn.LayerNorm):
x = x.transpose(1, 2)
x = self.relu2(x)
x = x.transpose(1, 2)
x = self.project(x)
return x, mask_pad[:, :, 0::2]
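# Shape sketch (illustrative): CNNSubsampling halves the time axis and projects
# to the LLM embedding size; the padding mask is subsampled with the same
# stride-2 pattern. The dimensions below are assumptions for demonstration.
if __name__ == "__main__":
    B, T, D = 2, 100, 512
    adapter = CNNSubsampling(enc_out_dim=D, llm_embed_dim=4096, kernel_size=5)
    x = torch.randn(B, T, D)
    mask_pad = torch.ones(B, 1, T, dtype=torch.bool)
    y, mask = adapter(x, mask_pad)
    print(y.shape, mask.shape)  # (2, 50, 4096), (2, 1, 50)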
import numpy as np
import torch
import json
import logging
import math
import sys
class GlobalCMVN(torch.nn.Module):
def __init__(self, mean: torch.Tensor, istd: torch.Tensor, norm_var: bool = True):
"""
Args:
mean (torch.Tensor): mean stats
istd (torch.Tensor): inverse std, std which is 1.0 / std
"""
super().__init__()
assert mean.shape == istd.shape
self.norm_var = norm_var
# The buffer can be accessed from this module using self.mean
self.register_buffer("mean", mean)
self.register_buffer("istd", istd)
def forward(self, x: torch.Tensor):
"""
Args:
x (torch.Tensor): (batch, max_len, feat_dim)
Returns:
(torch.Tensor): normalized feature
"""
x = x - self.mean
if self.norm_var:
x = x * self.istd
return x
def load_cmvn_json(json_cmvn_file):
with open(json_cmvn_file) as f:
cmvn_json = json.load(f)
avg = cmvn_json["mean_stat"]
var = cmvn_json["var_stat"]
count = cmvn_json["frame_num"]
for i in range(len(avg)):
avg[i] /= count
var[i] = var[i] / count - avg[i] * avg[i]
if var[i] < 1.0e-20:
var[i] = 1.0e-20
var[i] = 1.0 / math.sqrt(var[i])
cmvn = np.array([avg, var])
return cmvn
def load_cmvn_kaldi(kaldi_cmvn_file):
avg = []
var = []
with open(kaldi_cmvn_file, "r") as file:
# kaldi binary file start with '\0B'
if file.read(2) == "\0B":
            logging.error(
                "kaldi cmvn binary file is not supported, "
                "please use a text-format cmvn file"
            )
            sys.exit(1)
file.seek(0)
arr = file.read().split()
assert arr[0] == "["
assert arr[-2] == "0"
assert arr[-1] == "]"
feat_dim = int((len(arr) - 2 - 2) / 2)
for i in range(1, feat_dim + 1):
avg.append(float(arr[i]))
count = float(arr[feat_dim + 1])
for i in range(feat_dim + 2, 2 * feat_dim + 2):
var.append(float(arr[i]))
for i in range(len(avg)):
avg[i] /= count
var[i] = var[i] / count - avg[i] * avg[i]
if var[i] < 1.0e-20:
var[i] = 1.0e-20
var[i] = 1.0 / math.sqrt(var[i])
cmvn = np.array([avg, var])
return cmvn
def load_cmvn(filename, is_json):
if is_json:
file = load_cmvn_json(filename)
else:
file = load_cmvn_kaldi(filename)
return file[0], file[1]
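# Usage sketch (illustrative): load CMVN statistics from a JSON file and apply
# GlobalCMVN to a batch of fbank features. The file name below is a
# hypothetical path; the feature dimension comes from the loaded stats.
if __name__ == "__main__":
    mean, istd = load_cmvn("global_cmvn.json", is_json=True)  # hypothetical path
    cmvn = GlobalCMVN(torch.from_numpy(mean).float(), torch.from_numpy(istd).float())
    feats = torch.randn(4, 120, mean.shape[0])  # (batch, frames, feat_dim)
    normalized = cmvn(feats)
    print(normalized.shape)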
# Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from .adapter import CNNAdapter, CNNSubsampling, LinearAdapter
from .cmvn import GlobalCMVN, load_cmvn
from .module.encoder.encoder import whaleEncoder
class audioEncoderProcessor:
def __init__(
self,
dataset_conf: dict = None,
):
self.dataset_conf = dataset_conf
def process(self, wav_path):
try:
print("#################", wav_path)
waveform, sample_rate = torchaudio.load(wav_path)
        except Exception:
            print(f"cannot open {wav_path}")
            raise
if sample_rate != self.dataset_conf["resample_conf"]["resample_rate"]:
# sample_rate = self.dataset_conf['resample_conf']['resample_rate']
waveform = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=self.dataset_conf["resample_conf"]["resample_rate"]
)(waveform)
waveform = waveform * (1 << 15)
# Only keep key, feat, label
mat = kaldi.fbank(
waveform,
num_mel_bins=self.dataset_conf["fbank_conf"]["num_mel_bins"],
frame_length=self.dataset_conf["fbank_conf"]["frame_length"],
frame_shift=self.dataset_conf["fbank_conf"]["frame_shift"],
dither=self.dataset_conf["fbank_conf"]["dither"],
energy_floor=0.0,
sample_frequency=sample_rate,
)
attn_mask = torch.ones(mat.shape[0])
attn_mask = attn_mask[2::2][2::2][0::2]
return mat, attn_mask.shape[0]
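# Shape sketch (illustrative): the chained slicing above mirrors the encoder's
# two stride-2 conv subsamplings plus the stride-2 adapter, i.e. roughly one
# audio embedding per 8 fbank frames. The frame count below is an assumption.
if __name__ == "__main__":
    num_frames = 418  # e.g. a few seconds of audio at a 10 ms frame shift
    dummy = torch.ones(num_frames)
    print(dummy[2::2][2::2][0::2].shape[0])  # -> 52 audio embedding slots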
class audioEncoder(torch.nn.Module):
def __init__(
self,
encoder: torch.nn.Module,
llm_path: str,
freeze_llm: bool = True,
enc_out_dim: int = 512,
llm_embed_dim: int = 4096,
kernel_size: int = 3,
IGNORE_ID: int = -100,
adpter_type: str = "cnn",
add_audio_bos_eos: bool = False,
task_num: int = 10,
task_before_audio: bool = False,
task_type: str = "prompt",
freeze_encoder: bool = False,
freeze_adpter: bool = False,
activation_func: str = "relu",
norm: str = "batch",
chat_template=None,
):
super().__init__()
self.encoder = encoder
self.enc_out_dim = enc_out_dim
self.llm_embed_dim = llm_embed_dim
self.IGNORE_ID = IGNORE_ID
self.add_audio_bos_eos = add_audio_bos_eos
self.task_before_audio = task_before_audio
self.task_type = task_type
self.freeze_encoder = freeze_encoder
self.freeze_adpter = freeze_adpter
if adpter_type == "cnn":
self.adpter = CNNAdapter(enc_out_dim, llm_embed_dim, kernel_size)
elif adpter_type == "linear":
self.adpter = LinearAdapter(enc_out_dim, llm_embed_dim)
elif adpter_type == "subsampling":
self.adpter = CNNSubsampling(
enc_out_dim, llm_embed_dim, kernel_size, activation_func, norm
)
if self.freeze_encoder:
self.encoder.eval()
for (name, param) in self.encoder.named_parameters():
param.requires_grad = False
if self.freeze_adpter:
self.adpter.eval()
for (name, param) in self.adpter.named_parameters():
param.requires_grad = False
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
) -> Dict[str, Optional[torch.Tensor]]:
speech = speech.to(next(self.parameters()).dtype)
# 1. Encoder
encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
inputs_embeds, encoder_mask = self.adpter(encoder_out, encoder_mask) # B, T, D
attention_mask = encoder_mask.squeeze(1) # B, T
assert inputs_embeds.size(1) == attention_mask.size(1)
# audio bos/eos
if self.add_audio_bos_eos:
            inputs_embeds, attention_mask, target = self._add_bos_eos(
                "audio", "/audio", inputs_embeds, attention_mask, None
            )
outputs = {
"inputs_embeds": inputs_embeds,
"attention_mask": attention_mask,
}
return outputs
def _add_bos_eos(self, bos, eos, inputs_embeds, attention_mask, target=None):
B = len(inputs_embeds)
bos_embed = self.task_embeddings(
torch.full([B, 1], self.task_ids[bos]).to(inputs_embeds.device)
) # B, 1, D
eos_embed = self.task_embeddings(
torch.full([B, 1], self.task_ids[eos]).to(inputs_embeds.device)
) # B, 1, D
bos_eos_target = torch.full([B, 2], self.IGNORE_ID).to(inputs_embeds.device) # B, 2
bos_eos_mask = torch.full([B, 1], True).to(inputs_embeds.device) # B, 1
inputs_embeds = torch.cat((bos_embed, inputs_embeds), 1) # B, (1+T), D
inputs_embeds = torch.cat((inputs_embeds, eos_embed), 1) # B, (1+T+1), D
attention_mask = torch.cat((bos_eos_mask, attention_mask), 1) # B, (1+T)
attention_mask = torch.cat((attention_mask, bos_eos_mask), 1) # B, (1+T+1)
if target is not None:
target = torch.cat((target, bos_eos_target), 1) # B, (T+2), D
return inputs_embeds, attention_mask, target
def init_model(configs):
if configs["cmvn_file"] is not None:
mean, istd = load_cmvn(configs["cmvn_file"], configs["is_json_cmvn"])
global_cmvn = GlobalCMVN(torch.from_numpy(mean).float(), torch.from_numpy(istd).float())
else:
global_cmvn = None
input_dim = configs["input_dim"]
encoder = whaleEncoder(input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"])
model = audioEncoder(encoder=encoder, **configs["model_conf"])
processor = audioEncoderProcessor(dataset_conf=configs["dataset_conf"])
model.audio_processor = processor
return model
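# Config sketch (illustrative): the kind of dict init_model expects. The keys
# mirror the accesses above; the concrete values (layer config, dims, paths)
# are assumptions and would normally come from the checkpoint's config file.
if __name__ == "__main__":
    configs = {
        "cmvn_file": None,  # or a path consumed by load_cmvn
        "is_json_cmvn": True,
        "input_dim": 80,  # fbank dimension fed to whaleEncoder
        "encoder_conf": {
            "overview_conf": {
                "encoder-layer-config": "subsampling-transformer",
                "encoder-input-dim": 80,
                "encoder-output-dim": 256,
            },
            "para_conf": {
                "subsampling": {"subsampling-input-dim": 80, "subsampling-output-dim": 256},
                "transformer": {"transformer-input-dim": 256, "transformer-output-dim": 256},
            },
        },
        "model_conf": {
            "llm_path": "path/to/llm",  # hypothetical placeholder required by audioEncoder
            "enc_out_dim": 256,
            "llm_embed_dim": 4096,
            "adpter_type": "subsampling",
        },
        "dataset_conf": {
            "resample_conf": {"resample_rate": 16000},
            "fbank_conf": {"num_mel_bins": 80, "frame_length": 25, "frame_shift": 10, "dither": 0.0},
        },
    }
    model = init_model(configs)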
"""Encoder self-attention layer definition."""
import math
import pdb
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from vita.model.multimodal_encoder.whale.utils import IGNORE_ID, strtobool
try:
from mamba_ssm.modules.mamba_simple import Mamba, Block
from mamba_ssm.models.mixer_seq_simple import _init_weights
from mamba_ssm.ops.triton.layernorm import RMSNorm
except ImportError:
print("Please install mamba_ssm to use MambaSSM component.")
class MambaBlock(nn.Module):
def __init__(self, in_channels, n_layer=1, d_state=16, d_conv=4, expand=4, bidirectional=False):
super(MambaBlock, self).__init__()
self.forward_blocks = nn.ModuleList([])
self.forward_norm_f = RMSNorm(in_channels, eps=1e-5)
for i in range(n_layer):
self.forward_blocks.append(
Block(
in_channels,
mixer_cls=partial(
Mamba, layer_idx=i, d_state=d_state, d_conv=d_conv, expand=expand
),
norm_cls=partial(RMSNorm, eps=1e-5),
fused_add_norm=True,
residual_in_fp32=True,
)
)
if bidirectional:
self.backward_blocks = nn.ModuleList([])
for i in range(n_layer):
self.backward_blocks.append(
Block(
in_channels,
mixer_cls=partial(
Mamba, layer_idx=i, d_state=d_state, d_conv=d_conv, expand=expand
),
norm_cls=partial(RMSNorm, eps=1e-5),
fused_add_norm=True,
residual_in_fp32=True,
)
)
self.backward_norm_f = RMSNorm(in_channels, eps=1e-5)
else:
self.backward_blocks = None
self.apply(partial(_init_weights, n_layer=n_layer))
def forward(self, input):
for_residual = None
forward_f = input.clone()
for block in self.forward_blocks:
forward_f, for_residual = block(forward_f, for_residual, inference_params=None)
residual = (forward_f + for_residual) if for_residual is not None else forward_f
residual = self.forward_norm_f(residual)
if self.backward_blocks is not None:
back_residual = None
backward_f = torch.flip(input, [1])
for block in self.backward_blocks:
backward_f, back_residual = block(backward_f, back_residual, inference_params=None)
back_residual = (
(backward_f + back_residual) if back_residual is not None else backward_f
)
back_residual = torch.flip(back_residual, [1])
back_residual = self.backward_norm_f(back_residual)
residual = torch.cat([residual, back_residual], -1)
return residual
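# Usage sketch (illustrative, requires the optional mamba_ssm package and a
# CUDA device): a bidirectional MambaBlock concatenates the forward and the
# time-reversed backward pass, so the channel dim doubles. Sizes are assumptions.
if __name__ == "__main__":
    block = MambaBlock(in_channels=256, n_layer=2, bidirectional=True).cuda()
    x = torch.randn(2, 100, 256, device="cuda")
    y = block(x)
    print(y.shape)  # (2, 100, 512)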
class MambaSSM(torch.nn.Module):
@staticmethod
    def add_arguments(group):
        """Add MambaSSM common arguments."""
        group.add_argument(
            "--mamba-num-layers", default=4, type=int, help="Number of MambaSSM layers."
        )
group.add_argument(
"--mamba-input-dim", default=256, type=int, help="Input dim of MambaSSM."
)
group.add_argument(
"--mamba-output-dim", default=256, type=int, help="Output dim of MambaSSM."
)
group.add_argument("--mamba-d-state", default=16, type=int, help="d-state of MambaSSM.")
group.add_argument("--mamba-d-conv", default=4, type=int, help="d-conv of MambaSSM.")
group.add_argument("--mamba-expand", default=4, type=int, help="expand of MambaSSM.")
return group
def __init__(self, args):
"""Construct an Encoder object."""
super(MambaSSM, self).__init__()
self.mamb_num_layers = args.mamba_num_layers
self.mamba_input_dim = args.mamba_input_dim
self.mamba_output_dim = args.mamba_output_dim
self.mamba_d_state = args.mamba_d_state
self.mamba_d_conv = args.mamba_d_conv
self.mamba_expand = args.mamba_expand
self.mamba = MambaBlock(
self.mamba_input_dim,
self.mamb_num_layers,
self.mamba_d_state,
self.mamba_d_conv,
self.mamba_expand,
)
@torch.jit.unused
def forward(self, xs, ilens=None, masks=None):
"""Embed positions in tensor.
:param torch.Tensor xs: input tensor
:param torch.Tensor masks: input mask
:return: position embedded tensor and mask
:rtype Tuple[torch.Tensor, torch.Tensor]:
"""
xs_out = self.mamba(xs)
return xs_out.to(xs.dtype), ilens, masks
import torch
from typing import Tuple, Union
class BaseSubsampling(torch.nn.Module):
def __init__(self):
super().__init__()
self.subsampling_rate = 1
self.right_context = 0
def position_encoding(self, offset: Union[int, torch.Tensor], size: int) -> torch.Tensor:
return self.pos_enc.position_encoding(offset, size)
class Conv2dSubsampling4(BaseSubsampling):
"""Convolutional 2D subsampling (to 1/4 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
def __init__(self, idim: int, odim: int, dropout_rate: float):
"""Construct an Conv2dSubsampling4 object."""
super().__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 3, 2),
torch.nn.ReLU(),
)
self.out = torch.nn.Sequential(torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
self.right_context = 6
self.subsampling_rate = 4
    def forward(self, x: torch.Tensor, x_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
x = x.unsqueeze(1) # (b, c=1, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
return x, x_mask[:, :, 2::2][:, :, 2::2]
class Subsampling(torch.nn.Module):
@staticmethod
def add_arguments(group):
"""Add Subsampling common arguments."""
group.add_argument("--subsampling-rate", default=4, type=int)
group.add_argument("--subsampling-input-dim", default=256, type=int)
group.add_argument("--subsampling-output-dim", default=256, type=int)
group.add_argument("--subsampling-dropout-rate", default=0.1, type=float)
return group
def __init__(self, args):
super().__init__()
self.subsampling_rate = args.subsampling_rate
self.subsampling_input_dim = args.subsampling_input_dim
self.subsampling_output_dim = args.subsampling_output_dim
self.subsampling_dropout_rate = args.subsampling_dropout_rate
if self.subsampling_rate == 4:
self.core = Conv2dSubsampling4(
self.subsampling_input_dim,
self.subsampling_output_dim,
self.subsampling_dropout_rate,
)
def forward(self, xs, ilens, masks):
xs, masks = self.core(xs, masks)
ilens = masks.squeeze(1).sum(1)
return xs, ilens, masks
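# Shape sketch (illustrative): Conv2dSubsampling4 reduces the time axis to
# roughly T/4; with two stride-2, kernel-3 convolutions the exact length is
# ((T - 1) // 2 - 1) // 2, which matches the mask slicing above. Sizes below
# are assumptions for demonstration.
if __name__ == "__main__":
    B, T, D = 2, 100, 80
    sub = Conv2dSubsampling4(idim=D, odim=256, dropout_rate=0.1)
    xs = torch.randn(B, T, D)
    masks = torch.ones(B, 1, T, dtype=torch.bool)
    ys, out_masks = sub(xs, masks)
    print(ys.shape, out_masks.shape)  # (2, 24, 256), (2, 1, 24)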
"""Encoder self-attention layer definition."""
import math
import pdb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from vita.model.multimodal_encoder.whale.module.layer.attention import (
Conv1dLinear,
MultiHeadedAttention,
MultiLayeredConv1d,
PositionalEncoding,
PositionwiseFeedForward,
RelPositionalEncoding,
)
# from vita.model.multimodal_encoder.whale.module.component.utils import *
from vita.model.multimodal_encoder.whale.utils import IGNORE_ID, add_optional_chunk_mask, strtobool
def repeat(N, fn):
"""Repeat module N times.
:param int N: repeat time
:param function fn: function to generate module
:return: repeated modules
:rtype: MultiSequential
"""
return MultiSequential(*[fn(n) for n in range(N)])
class MultiSequential(torch.nn.Sequential):
"""Multi-input multi-output torch.nn.Sequential."""
def forward(self, x, masks, pos_emb):
"""Repeat."""
for m in self:
x, masks, pos_emb = m(x, masks, pos_emb)
return x, masks, pos_emb
@torch.jit.export
def infer(self, x, pos_emb, buffer, buffer_index, buffer_out):
# type: (Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]
"""Repeat."""
for m in self:
x, pos_emb, buffer, buffer_index, buffer_out = m.infer(
x, pos_emb, buffer, buffer_index, buffer_out
)
return x, pos_emb, buffer, buffer_index, buffer_out
@torch.jit.export
def infer_hidden(self, x, pos_emb, buffer, buffer_index, buffer_out, hidden_out):
# type: (Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]
"""Repeat."""
for m in self:
x, pos_emb, buffer, buffer_index, buffer_out = m.infer(
x, pos_emb, buffer, buffer_index, buffer_out
)
hidden_out.append(x)
return x, pos_emb, buffer, buffer_index, buffer_out, hidden_out
class TransformerLayer(nn.Module):
"""Transformer layer module.
:param int size: input dim
:param self_attn: self attention module
:param feed_forward: feed forward module
:param float dropout_rate: dropout rate
:param bool normalize_before: whether to use layer_norm before the first block
:param bool concat_after: whether to concat attention layer's input and output
if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""
def __init__(
self, size, self_attn, feed_forward, dropout_rate, normalize_before=True, concat_after=False
):
"""Construct an TransformerLayer object."""
super(TransformerLayer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = torch.nn.LayerNorm(size)
self.norm2 = torch.nn.LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
else:
self.concat_linear = nn.Identity()
@torch.jit.unused
def forward(self, x, mask, pos_emb):
"""Compute encoded features.
:param torch.Tensor x: encoded source features (batch, max_time_in, size)
:param torch.Tensor mask: mask for x (batch, max_time_in)
:rtype: Tuple[torch.Tensor, torch.Tensor]
"""
residual = x
if self.normalize_before:
x = self.norm1(x)
if self.concat_after:
x_concat = torch.cat((x, self.self_attn(x, x, x, mask, pos_emb)), dim=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(self.self_attn(x, x, x, mask, pos_emb))
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
return x, mask, pos_emb
@torch.jit.export
def infer(self, x, pos_emb, buffer, buffer_index, buffer_out):
# type: (Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]
residual = x.clone()
if self.normalize_before:
x = self.norm1(x)
if self.concat_after:
x_att, buffer, buffer_index, buffer_out = self.self_attn.infer(
x, x, x, pos_emb, buffer, buffer_index, buffer_out
)
x_concat = torch.cat((x, x_att), dim=-1)
x = residual + self.concat_linear(x_concat)
else:
x_att, buffer, buffer_index, buffer_out = self.self_attn.infer(
x, x, x, pos_emb, buffer, buffer_index, buffer_out
)
x = residual + x_att
if not self.normalize_before:
x = self.norm1(x)
residual = x.clone()
if self.normalize_before:
x = self.norm2(x)
x_feed, buffer, buffer_index, buffer_out = self.feed_forward.infer(
x, buffer, buffer_index, buffer_out
)
x = residual + x_feed
if not self.normalize_before:
x = self.norm2(x)
return x, pos_emb, buffer, buffer_index, buffer_out
class Transformer(torch.nn.Module):
@staticmethod
    def add_arguments(group):
        """Add Transformer common arguments."""
group.add_argument(
"--transformer-input-dim", default=256, type=int, help="Input dim of Transformer."
)
group.add_argument(
"--transformer-output-dim", default=4, type=int, help="Output dim of Transformer."
)
        group.add_argument(
            "--transformer-attention-dim", default=256, type=int, help="Dimension of attention."
        )
group.add_argument(
"--transformer-attention-heads",
default=4,
type=int,
help="The number of heads of multi head attention.",
)
group.add_argument(
"--transformer-linear-units",
default=1024,
type=int,
help="The number of units of position-wise feed forward.",
)
group.add_argument(
"--transformer-num-blocks", default=6, type=int, help="The number of attention blocks."
)
group.add_argument(
"--transformer-dropout-rate",
default=0.1,
type=float,
help="Dropout rate in Transformer.",
)
group.add_argument(
"--transformer-attention-dropout-rate",
default=0.0,
type=float,
help="Dropout rate in attention.",
)
group.add_argument(
"--transformer-positional-dropout-rate",
default=0.1,
type=float,
help="Dropout rate after adding positional encoding.",
)
group.add_argument(
"--transformer-input-layer", default="linear", type=str, help="Type of input layer"
)
group.add_argument("--transformer-pos-enc-class", default="abs-enc", type=str, help="")
group.add_argument(
"--transformer-normalize-before",
default=True,
type=strtobool,
help="Whether to use layer-norm before the first block.",
)
group.add_argument(
"--transformer-concat-after",
default=False,
type=strtobool,
help="Whether to concat attention layer's input and output.",
)
group.add_argument(
"--transformer-positionwise-layer-type",
default="linear",
type=str,
            help="Positionwise layer type: linear or conv1d.",
)
group.add_argument(
"--transformer-positionwise-conv-kernel_size",
default=1,
type=int,
help="Kernel size of positionwise conv1d layer.",
)
group.add_argument("--transformer-chunk_size", default=-1, type=int, help="")
group.add_argument("--transformer-left_chunks", default=-1, type=int, help="")
group.add_argument("--transformer-dynamic-chunks", default=True, type=strtobool, help="")
return group
def __init__(
self,
args,
input_dim=None,
output_dim=None,
attention_dim=None,
attention_heads=None,
linear_units=None,
num_blocks=None,
dropout_rate=None,
positional_dropout_rate=None,
attention_dropout_rate=None,
input_layer=None,
pos_enc_class=None,
normalize_before=None,
concat_after=None,
positionwise_layer_type=None,
positionwise_conv_kernel_size=None,
chunk_size=None,
left_chunks=None,
):
"""Construct an Encoder object."""
super(Transformer, self).__init__()
if args is None:
self.input_dim = input_dim
self.output_dim = output_dim
self.attention_dim = attention_dim
self.attention_heads = attention_heads
self.linear_units = linear_units
self.num_blocks = num_blocks
self.dropout_rate = dropout_rate
self.positional_dropout_rate = positional_dropout_rate
self.attention_dropout_rate = attention_dropout_rate
self.input_layer = input_layer
self.pos_enc_class = pos_enc_class
self.normalize_before = normalize_before
self.concat_after = concat_after
self.positionwise_layer_type = positionwise_layer_type
self.positionwise_conv_kernel_size = positionwise_conv_kernel_size
self.chunk_size = chunk_size
self.left_chunks = left_chunks
else:
self.input_dim = args.transformer_input_dim
self.output_dim = args.transformer_output_dim
self.attention_dim = args.transformer_attention_dim
self.attention_heads = args.transformer_attention_heads
self.linear_units = args.transformer_linear_units
self.num_blocks = args.transformer_num_blocks
self.dropout_rate = args.transformer_dropout_rate
self.positional_dropout_rate = args.transformer_positional_dropout_rate
self.attention_dropout_rate = args.transformer_attention_dropout_rate
self.input_layer = args.transformer_input_layer
self.pos_enc_class = args.transformer_pos_enc_class
self.normalize_before = args.transformer_normalize_before
self.concat_after = args.transformer_concat_after
self.positionwise_layer_type = args.transformer_positionwise_layer_type
self.positionwise_conv_kernel_size = args.transformer_positionwise_conv_kernel_size
self.chunk_size = args.transformer_chunk_size
self.left_chunks = args.transformer_left_chunks
self.transformer_dynamic_chunks = args.transformer_dynamic_chunks
if self.pos_enc_class == "abs-enc":
pos_enc_args = (self.attention_dim, self.positional_dropout_rate)
pos_enc_class = PositionalEncoding
elif self.pos_enc_class == "rel-enc":
pos_enc_args = (
self.attention_dim,
self.positional_dropout_rate,
self.chunk_size,
self.left_chunks,
)
pos_enc_class = RelPositionalEncoding
if self.input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(self.input_dim, self.attention_dim),
torch.nn.LayerNorm(self.attention_dim),
torch.nn.Dropout(self.dropout_rate),
torch.nn.ReLU(),
)
elif self.input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(self.input_dim, self.attention_dim, padding_idx=IGNORE_ID)
)
elif self.input_layer == "none":
self.embed = torch.nn.Sequential(torch.nn.Identity())
else:
raise ValueError("unknown input_layer: " + self.input_layer)
self.pe = pos_enc_class(*pos_enc_args)
self.embed_layer_num = len(self.embed)
if self.positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (self.attention_dim, self.linear_units, self.dropout_rate)
elif self.positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
self.attention_dim,
self.linear_units,
self.positionwise_conv_kernel_size,
self.dropout_rate,
)
elif self.positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
self.attention_dim,
self.linear_units,
self.positionwise_conv_kernel_size,
self.dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
self.encoders = repeat(
self.num_blocks,
lambda lnum: TransformerLayer(
self.attention_dim,
MultiHeadedAttention(
self.attention_heads,
self.attention_dim,
self.attention_dropout_rate,
self.chunk_size,
self.left_chunks,
self.pos_enc_class,
),
positionwise_layer(*positionwise_layer_args),
self.dropout_rate,
self.normalize_before,
self.concat_after,
),
)
if self.normalize_before:
self.after_norm = torch.nn.LayerNorm(self.attention_dim)
@torch.jit.unused
def forward(self, xs, ilens=None, masks=None):
"""Embed positions in tensor.
:param torch.Tensor xs: input tensor
:param torch.Tensor masks: input mask
:return: position embedded tensor and mask
:rtype Tuple[torch.Tensor, torch.Tensor]:
"""
        if self.transformer_dynamic_chunks:  # and self.training:
chunk_masks = add_optional_chunk_mask(xs, masks, True, True, 0, 0, -1)
else:
chunk_masks = add_optional_chunk_mask(
xs, masks, False, False, self.chunk_size, self.chunk_size, self.left_chunks
).to(xs.device)
xs = self.embed(xs)
xs, pos_emb = self.pe(xs)
xs, chunk_masks, pos_emb = self.encoders(xs, chunk_masks, pos_emb)
if self.normalize_before:
xs = self.after_norm(xs)
return xs, ilens, masks
@torch.jit.export
def infer(self, xs, buffer, buffer_index, buffer_out):
xs = self.embed(xs)
# pe_index = buffer[buffer_index: buffer_index + 1].reshape([1]).to(torch.int64)
# xs, pos_emb, pe_index[0] = self.pe.infer(xs, pe_index[0])
# buffer_out.append(pe_index.reshape(-1).to(torch.float32))
# buffer_index = buffer_index + 1
xs, pos_emb, _ = self.pe.infer(xs, 0)
xs, pos_emb, buffer, buffer_index, buffer_out = self.encoders.infer(
xs, pos_emb, buffer, buffer_index, buffer_out
)
if self.normalize_before:
xs = self.after_norm(xs)
return xs, buffer, buffer_index, buffer_out
@torch.jit.export
def infer_hidden(self, xs, buffer, buffer_index, buffer_out, hidden_out):
xs = self.embed(xs)
# pe_index = buffer[buffer_index: buffer_index + 1].reshape([1]).to(torch.int64)
# xs, pos_emb, pe_index[0] = self.pe.infer(xs, pe_index[0])
# buffer_out.append(pe_index.reshape(-1).to(torch.float32))
# buffer_index = buffer_index + 1
xs, pos_emb, _ = self.pe.infer(xs, 0)
xs, pos_emb, buffer, buffer_index, buffer_out, hidden_out = self.encoders.infer_hidden(
xs, pos_emb, buffer, buffer_index, buffer_out, hidden_out
)
if self.normalize_before:
xs = self.after_norm(xs)
return xs, buffer, buffer_index, buffer_out, hidden_out
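# Construction sketch (illustrative): the Transformer component is normally
# configured through whaleEncoder's config dicts; here it is built directly
# from the argparse defaults registered above. Input sizes are assumptions.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    Transformer.add_arguments(parser)
    args, _ = parser.parse_known_args([])
    encoder = Transformer(args)
    xs = torch.randn(2, 50, args.transformer_input_dim)
    masks = torch.ones(2, 1, 50, dtype=torch.bool)
    ys, _, _ = encoder(xs, ilens=None, masks=masks)
    print(ys.shape)  # (2, 50, 256) with the default attention dim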
import argparse
import logging
import sys
import time
from typing import Dict, Optional, Tuple
import numpy as np
import six
import torch
from vita.model.multimodal_encoder.whale.module.component.mamba import MambaSSM
from vita.model.multimodal_encoder.whale.module.component.subsampling import Subsampling
from vita.model.multimodal_encoder.whale.module.component.transformer import Transformer
from vita.model.multimodal_encoder.whale.utils import make_pad_mask
def add_encoder_args(group):
"""Add Encoder common arguments."""
group.add_argument(
"--encoder-layer-config",
type=str,
default="tdnn-dtc",
        help="Layer config of the encoder, formatted as layername-layername-...",
)
group.add_argument(
"--encoder-input-dim",
type=int,
default=256,
        help="Input dim of the encoder. Must equal the input dim of the first component (default=256).",
)
group.add_argument(
"--encoder-output-dim",
type=int,
default=256,
        help="Output dim of the encoder. Must equal the output dim of the last component (default=256).",
)
# Add args of all kinds of components.
# If you add a new component, DO NOT forget to add args to add_component_args func.
group = Transformer.add_arguments(group)
group = Subsampling.add_arguments(group)
group = MambaSSM.add_arguments(group)
return group
def assign_args_from_dict(args, dict, prefix_key=None):
if prefix_key is not None:
dict = dict[prefix_key]
for k, v in dict.items():
k_args = k.replace("-", "_")
if hasattr(args, k_args):
setattr(args, k_args, dict[k])
return args
class whaleEncoder(torch.nn.Module):
def __init__(self, input_dim, overview_conf=None, para_conf=None, global_cmvn=None):
super(whaleEncoder, self).__init__()
parser = argparse.ArgumentParser()
add_encoder_args(parser)
args, _ = parser.parse_known_args()
assign_args_from_dict(args, overview_conf)
# assign_args_from_dict(args, para_conf)
self.config = args.encoder_layer_config.split("-")
encoder_input_dim = args.encoder_input_dim
encoder_output_dim = args.encoder_output_dim
prev_output_dim = encoder_input_dim
prev_component_name = "encoder"
self.enc = torch.nn.ModuleList([])
for name in self.config:
assign_args_from_dict(args, para_conf[name])
if len(name.split("_")) == 2:
name = name.split("_")[0]
elif len(name.split("_")) == 1:
name = name
else:
                logging.error("WRONG CONFIG! encoder component {} is not valid".format(name))
sys.exit()
if name == "transformer":
self.enc.append(Transformer(args))
elif name == "subsampling":
self.enc.append(Subsampling(args))
elif name == "mamba":
self.enc.append(MambaSSM(args))
else:
                raise NotImplementedError("{} is not supported now!".format(name))
component_input_dim = getattr(args, name + "_input_dim")
if component_input_dim != prev_output_dim:
                # The previous component's output dim must match this component's input dim.
                logging.error(
                    "WRONG CONFIG! --{}-output-dim ({}) does not equal --{}-input-dim ({})".format(
                        prev_component_name, prev_output_dim, name, component_input_dim
                    )
                )
sys.exit()
prev_output_dim = getattr(args, name + "_output_dim")
prev_component_name = name
self.global_cmvn = global_cmvn
if prev_output_dim != encoder_output_dim:
            logging.error(
                "WRONG CONFIG! --{}-output-dim ({}) does not equal --{}-output-dim ({}, the last component)".format(
                    "encoder", encoder_output_dim, name, prev_output_dim
                )
            )
sys.exit()
self._output_size = encoder_output_dim
num_params = sum(p.numel() for p in self.parameters())
print("the number of whale encoder params: {}M".format(num_params / 1024 / 1024))
def output_size(self) -> int:
return self._output_size
@torch.jit.unused
def forward(self, xs, ilens, decoding_chunk_size=None, num_decoding_left_chunks=None):
# type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Optional[List[int]], Optional[Tensor]]
"""Encoder forward
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:return: batch of hidden state sequences (B, Tmax, eprojs)
:rtype: torch.Tensor
"""
if decoding_chunk_size is not None and num_decoding_left_chunks is not None:
for layer in self.enc:
if hasattr(layer, "chunk_size"):
layer.chunk_size = decoding_chunk_size
if hasattr(layer, "left_chunks"):
layer.left_chunks = num_decoding_left_chunks
if hasattr(layer, "transformer_dynamic_chunks"):
layer.transformer_dynamic_chunks = False
assert (len(xs.shape)) == 3
T = xs.size(1)
masks = ~make_pad_mask(ilens, T).unsqueeze(1) # (B, 1, T)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
for module in self.enc:
xs, ilens, masks = module(xs, ilens, masks)
return xs, masks
@torch.jit.export
def infer(self, xs_pad, buffer, buffer_index, buffer_out):
        if self.global_cmvn is not None:
            xs_pad = self.global_cmvn(xs_pad)
for module in self.enc:
xs_pad, buffer, buffer_index, buffer_out = module.infer(
xs_pad, buffer, buffer_index, buffer_out
)
return xs_pad, buffer, buffer_index, buffer_out
@torch.jit.export
def infer_hidden(self, xs_pad, buffer, buffer_index, buffer_out, hidden_out):
        if self.global_cmvn is not None:
            xs_pad = self.global_cmvn(xs_pad)
for module in self.enc:
xs_pad, buffer, buffer_index, buffer_out, hidden_out = module.infer_hidden(
xs_pad, buffer, buffer_index, buffer_out, hidden_out
)
return xs_pad, buffer, buffer_index, buffer_out, hidden_out
@torch.jit.ignore(drop=True)
def get_extra_loss(self) -> Dict[str, torch.Tensor]:
return None
import math
import pdb
import numpy
import torch
import torch.nn as nn
class PositionalEncoding(torch.nn.Module):
"""Positional encoding.
:param int d_model: embedding dim
:param float dropout_rate: dropout rate
:param int max_len: maximum input length
PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
"""
def __init__(
self, d_model: int, dropout_rate: float, max_len: int = 1500, reverse: bool = False
):
"""Construct an PositionalEncoding object."""
super().__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.max_len = max_len
self.pe = torch.zeros(self.max_len, self.d_model)
position = torch.arange(0, self.max_len, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
self.pe[:, 0::2] = torch.sin(position * div_term)
self.pe[:, 1::2] = torch.cos(position * div_term)
self.pe = self.pe.unsqueeze(0)
def forward(self, x: torch.Tensor, offset: int = 0):
"""Add positional encoding.
Args:
x (torch.Tensor): Input. Its shape is (batch, time, ...)
offset (int): position offset
Returns:
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
torch.Tensor: for compatibility to RelPositionalEncoding
"""
assert offset + x.size(1) < self.max_len
self.pe = self.pe.to(x.device)
pos_emb = self.pe[:, offset : offset + x.size(1)]
x = x * self.xscale + pos_emb
return self.dropout(x), self.dropout(pos_emb)
def position_encoding(self, offset: int, size: int):
"""For getting encoding in a streaming fashion
        Note: dropout is applied only once at the whole-utterance level in the
        non-streaming case, but this function is called several times with an
        increasing input size in a streaming scenario, so dropout will be
        applied several times.
Args:
offset (int): start offset
            size (int): required size of position encoding
Returns:
torch.Tensor: Corresponding encoding
"""
assert offset + size < self.max_len
return self.dropout(self.pe[:, offset : offset + size])
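# Worked example (illustrative): the table built in __init__ follows the closed
# form PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and
# PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)); the check below recomputes one
# entry directly. Sizes are assumptions for demonstration.
if __name__ == "__main__":
    enc = PositionalEncoding(d_model=8, dropout_rate=0.0, max_len=100)
    pos, i = 5, 2
    expected = math.sin(pos / (10000 ** (2 * i / 8)))
    assert torch.isclose(enc.pe[0, pos, 2 * i], torch.tensor(expected), atol=1e-6)
    x = torch.zeros(1, 10, 8)
    y, pos_emb = enc(x)
    print(y.shape, pos_emb.shape)  # (1, 10, 8), (1, 10, 8)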
class RelPositionalEncoding(PositionalEncoding):
"""Relative positional encoding module.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def __init__(
self,
d_model: int,
dropout_rate: float,
chunk_size: int,
left_chunks: int,
max_len: int = 5000,
):
"""Initialize class."""
super().__init__(d_model, dropout_rate, max_len, reverse=True)
self.chunk_size = chunk_size
self.left_chunks = left_chunks
self.full_chunk_size = (self.left_chunks + 1) * self.chunk_size
self.div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
self.max_len = self.chunk_size * (max_len // self.chunk_size) - self.full_chunk_size
@torch.jit.export
def forward(self, x: torch.Tensor, offset: int = 0):
"""Compute positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Positional embedding tensor (1, time, `*`).
"""
self.pe = self.pe.to(x.device)
x = x * self.xscale
pos_emb = self.pe[:, offset : offset + x.size(1)]
return self.dropout(x), self.dropout(pos_emb)
@torch.jit.export
def infer(self, xs, pe_index):
# type: (Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
pe_index = pe_index % self.max_len
xs = xs * self.xscale
pe = torch.zeros(self.full_chunk_size, self.d_model)
position = torch.arange(
pe_index, pe_index + self.full_chunk_size, dtype=torch.float32
).unsqueeze(1)
pe[:, 0::2] = torch.sin(position * self.div_term)
pe[:, 1::2] = torch.cos(position * self.div_term)
pos_emb = pe.unsqueeze(0)
pe_index = pe_index + self.chunk_size
return xs, pos_emb, pe_index
class PositionwiseFeedForward(torch.nn.Module):
"""Positionwise feed forward layer.
    :param int idim: input dimension
:param int hidden_units: number of hidden units
:param float dropout_rate: dropout rate
"""
def __init__(self, idim, hidden_units, dropout_rate):
"""Construct an PositionwiseFeedForward object."""
super(PositionwiseFeedForward, self).__init__()
self.w_1 = torch.nn.Linear(idim, hidden_units)
self.w_2 = torch.nn.Linear(hidden_units, idim)
self.dropout = torch.nn.Dropout(dropout_rate)
    def forward(self, x):
        """Forward function."""
return self.w_2(self.dropout(torch.relu(self.w_1(x))))
@torch.jit.export
def infer(self, xs, buffer, buffer_index, buffer_out):
# type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
return self.w_2(torch.relu(self.w_1(xs))), buffer, buffer_index, buffer_out
class MultiLayeredConv1d(torch.nn.Module):
"""Multi-layered conv1d for Transformer block.
    This is a module of multi-layered conv1d designed
to replace positionwise feed-forward network
in Transformer block, which is introduced in
`FastSpeech: Fast, Robust and Controllable Text to Speech`_.
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
https://arxiv.org/pdf/1905.09263.pdf
"""
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
"""Initialize MultiLayeredConv1d module.
Args:
in_chans (int): Number of input channels.
hidden_chans (int): Number of hidden channels.
kernel_size (int): Kernel size of conv1d.
dropout_rate (float): Dropout rate.
"""
super(MultiLayeredConv1d, self).__init__()
self.w_1 = torch.nn.Conv1d(
in_chans,
hidden_chans,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
)
self.w_2 = torch.nn.Conv1d(
hidden_chans,
in_chans,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
)
self.dropout = torch.nn.Dropout(dropout_rate)
@torch.jit.unused
def forward(self, x):
"""Calculate forward propagation.
Args:
x (Tensor): Batch of input tensors (B, ..., in_chans).
Returns:
Tensor: Batch of output tensors (B, ..., hidden_chans).
"""
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
class Conv1dLinear(torch.nn.Module):
"""Conv1D + Linear for Transformer block.
A variant of MultiLayeredConv1d, which replaces second conv-layer to linear.
"""
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
"""Initialize Conv1dLinear module.
Args:
in_chans (int): Number of input channels.
hidden_chans (int): Number of hidden channels.
kernel_size (int): Kernel size of conv1d.
dropout_rate (float): Dropout rate.
"""
super(Conv1dLinear, self).__init__()
self.lorder = kernel_size - 1
self.left_padding = nn.ConstantPad1d((self.lorder, 0), 0.0)
self.w_1 = torch.nn.Sequential(
torch.nn.Conv1d(in_chans, in_chans, kernel_size, stride=1, padding=0, groups=in_chans),
torch.nn.Conv1d(in_chans, hidden_chans, 1, padding=0),
)
self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
self.dropout = torch.nn.Dropout(dropout_rate)
self.in_chans = in_chans
# cnn_buffer = 1, in_chans, self.lorder
self.buffer_size = 1 * self.in_chans * self.lorder
@torch.jit.unused
def forward(self, x):
"""Calculate forward propagation.
Args:
x (Tensor): Batch of input tensors (B, ..., in_chans).
Returns:
Tensor: Batch of output tensors (B, ..., hidden_chans).
"""
x = torch.relu(self.w_1(self.left_padding(x.transpose(-1, 1)))).transpose(-1, 1)
return self.w_2(self.dropout(x))
@torch.jit.export
def infer(self, x, buffer, buffer_index, buffer_out):
# type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
x = x.transpose(-1, 1)
cnn_buffer = buffer[buffer_index : buffer_index + self.buffer_size].reshape(
[1, self.in_chans, self.lorder]
)
x = torch.cat([cnn_buffer, x], dim=2)
buffer_out.append(x[:, :, -self.lorder :].reshape(-1))
buffer_index = buffer_index + self.buffer_size
x = self.w_1(x)
x = torch.relu(x).transpose(-1, 1)
x = self.w_2(x)
return x, buffer, buffer_index, buffer_out
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
    :param int n_head: the number of heads
:param int n_feat: the number of features
:param float dropout_rate: dropout rate
"""
def __init__(self, n_head, n_feat, dropout_rate, chunk_size, left_chunks, pos_enc_class):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadedAttention, self).__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.dropout = nn.Dropout(p=dropout_rate)
# self.min_value = float(numpy.finfo(torch.tensor(0, dtype=torch.float16).numpy().dtype).min)
self.min_value = float(torch.finfo(torch.float16).min)
# chunk par
if chunk_size > 0 and left_chunks > 0: # for streaming mode
self.buffersize = chunk_size * (left_chunks)
self.left_chunk_size = chunk_size * left_chunks
else: # for non-streaming mode
self.buffersize = 1
self.left_chunk_size = 1
self.chunk_size = chunk_size
# encoding setup
if pos_enc_class == "rel-enc":
self.rel_enc = True
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
else:
self.rel_enc = False
self.linear_pos = nn.Identity()
self.pos_bias_u = torch.tensor([0])
self.pos_bias_v = torch.tensor([0])
# buffer
# key_buffer = 1, self.h, self.buffersize, self.d_k
self.key_buffer_size = 1 * self.h * self.buffersize * self.d_k
# value_buffer = 1, self.h, self.buffersize, self.d_k
self.value_buffer_size = 1 * self.h * self.buffersize * self.d_k
if self.chunk_size > 0:
# buffer_mask_size = 1, self.h, self.chunk_size, self.buffersize
self.buffer_mask_size = 1 * self.h * self.chunk_size * self.buffersize
# self.buffer_mask = torch.ones([1, self.h, self.chunk_size, self.buffersize], dtype=torch.bool)
else:
self.buffer_mask = torch.ones([1, self.h, 1, 1], dtype=torch.bool)
@torch.jit.unused
    def rel_shift(self, x, zero_triu: bool = False):
        """Compute relative positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, size).
zero_triu (bool): If true, return the lower triangular part of
the matrix.
Returns:
torch.Tensor: Output tensor.
"""
zero_pad = torch.zeros(
(x.size()[0], x.size()[1], x.size()[2], 1), device=x.device, dtype=x.dtype
)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(x.size()[0], x.size()[1], x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(x)
if zero_triu:
ones = torch.ones((x.size(2), x.size(3)))
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
return x
@torch.jit.export
def forward(self, query, key, value, mask=None, pos_emb=torch.tensor(1.0)):
# type: (Tensor, Tensor, Tensor, Optional[Tensor], Tensor) -> Tensor
"""Compute 'Scaled Dot Product Attention'.
:param torch.Tensor query: (batch, time1, size)
:param torch.Tensor key: (batch, time2, size)
:param torch.Tensor value: (batch, time2, size)
:param torch.Tensor mask: (batch, time1, time2)
:param torch.nn.Dropout dropout:
:return torch.Tensor: attentined and transformed `value` (batch, time1, d_model)
weighted by the query dot key attention (batch, head, time1, time2)
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)
if self.rel_enc:
q = q.transpose(1, 2) # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb.to(query.dtype)).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, time1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, time2)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
# Remove rel_shift since it is useless in speech recognition,
# and it requires special attention for streaming.
# matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2)
else:
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(
self.d_k
) # (batch, head, time1, time2)
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
scores = scores.masked_fill(mask, self.min_value)
attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(attn)
x = torch.matmul(p_attn, v) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
@torch.jit.export
def infer(self, query, key, value, pos_emb, buffer, buffer_index, buffer_out):
# type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
n_batch = query.size(0)
q = (
self.linear_q(query).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
) # (batch, head, len_q, d_k)
k = (
self.linear_k(key).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
) # (batch, head, len_k, d_k)
v = (
self.linear_v(value).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
) # (batch, head, len_v, d_k)
key_value_buffer = buffer[
buffer_index : buffer_index + self.key_buffer_size + self.value_buffer_size
].reshape([1, self.h, self.buffersize * 2, self.d_k])
key_buffer = torch.cat([key_value_buffer[:, :, : self.buffersize, :], k], dim=2)
value_buffer = torch.cat([key_value_buffer[:, :, self.buffersize :, :], v], dim=2)
buffer_out.append(
torch.cat(
[key_buffer[:, :, self.chunk_size :, :], value_buffer[:, :, self.chunk_size :, :]],
dim=2,
).reshape(-1)
)
buffer_index = buffer_index + self.key_buffer_size + self.value_buffer_size
if self.rel_enc:
q = q.transpose(1, 2) # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, time1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, key_buffer.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, time2)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
# Remove rel_shift since it is useless in speech recognition,
# and it requires special attention for streaming.
# matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2)
else:
scores = torch.matmul(q, key_buffer.transpose(-2, -1)) / math.sqrt(
self.d_k
) # (batch, head, len_q, buffersize)
attn = torch.softmax(scores, dim=-1)
x = torch.matmul(attn, value_buffer) # (batch, head, len_q, d_k)
x = x.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
return self.linear_out(x), buffer, buffer_index, buffer_out # (batch, time1, d_model)
@torch.jit.export
def infer_mask(self, query, key, value, mask, buffer, buffer_index, buffer_out, is_static):
n_batch = query.size(0)
q = (
self.linear_q(query).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
) # (batch, head, len_q, d_k)
k = (
self.linear_k(key).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
) # (batch, head, len_k, d_k)
v = (
self.linear_v(value).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
) # (batch, head, len_v, d_k)
if is_static:
key_buffer = k
value_buffer = v
else:
key_value_buffer = buffer[
buffer_index : buffer_index + self.key_buffer_size + self.value_buffer_size
].reshape([1, self.h, self.buffersize * 2, self.d_k])
key_buffer = torch.cat([key_value_buffer[:, :, : self.buffersize, :], k], dim=2)
value_buffer = torch.cat([key_value_buffer[:, :, self.buffersize :, :], v], dim=2)
buffer_out.append(
torch.cat(
[
key_buffer[:, :, self.chunk_size :, :],
value_buffer[:, :, self.chunk_size :, :],
],
dim=2,
).reshape(-1)
)
buffer_index = buffer_index + self.key_buffer_size + self.value_buffer_size
scores = torch.matmul(q, key_buffer.transpose(-2, -1)) / math.sqrt(
self.d_k
) # (batch, head, len_q, buffersize)
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
scores = scores.masked_fill(mask, self.min_value)
attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
x = torch.matmul(attn, value_buffer) # (batch, head, len_q, d_k)
x = x.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
return self.linear_out(x), buffer_index, buffer_out # (batch, time1, d_model)
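# Usage sketch (illustrative): multi-head self-attention in the non-streaming,
# absolute-position configuration (chunk_size=-1, left_chunks=-1, "abs-enc").
# Sizes are assumptions for demonstration.
if __name__ == "__main__":
    mha = MultiHeadedAttention(
        n_head=4, n_feat=256, dropout_rate=0.0, chunk_size=-1, left_chunks=-1, pos_enc_class="abs-enc"
    )
    x = torch.randn(2, 20, 256)
    mask = torch.ones(2, 20, 20, dtype=torch.bool)  # (batch, time1, time2)
    out = mha(x, x, x, mask=mask)
    print(out.shape)  # (2, 20, 256)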
class SoftAttention(nn.Module):
def __init__(self, in_dim, hidden_dim):
super(SoftAttention, self).__init__()
self.q = torch.nn.Parameter(torch.rand([hidden_dim]), requires_grad=True)
self.wb = nn.Linear(in_dim, hidden_dim)
self.min_value = float(numpy.finfo(torch.tensor(0, dtype=torch.float32).numpy().dtype).min)
# buffer
self.window_size = 50
self.buffer_in = torch.zeros([1, self.window_size, in_dim], dtype=torch.float32)
self.buffer = torch.zeros([1, self.window_size], dtype=torch.float32)
self.buffer[:, :] = float(
numpy.finfo(torch.tensor(0, dtype=torch.float32).numpy().dtype).min
)
@torch.jit.unused
def forward(self, x, mask=None):
hidden = torch.tanh(self.wb(x)) # B T D
hidden = torch.einsum("btd,d->bt", hidden, self.q)
score = torch.softmax(hidden, dim=-1) # B T
if mask is not None:
score = score.masked_fill(mask, 0.0)
output = torch.einsum("bt,btd->bd", score, x)
return output
@torch.jit.export
def infer(self, x):
# type: (Tensor) -> Tensor
hidden = torch.tanh(self.wb(x)) # B T D
hidden = torch.einsum("btd,d->bt", hidden, self.q)
size = hidden.shape[1]
output = torch.zeros([size, x.shape[-1]])
for i in range(size):
self.buffer = torch.cat([self.buffer, hidden[:, i : i + 1]], dim=-1)
self.buffer = self.buffer[:, 1:]
score = torch.softmax(self.buffer, dim=-1) # B T
self.buffer_in = torch.cat([self.buffer_in, x[:, i : i + 1, :]], dim=1)
self.buffer_in = self.buffer_in[:, 1:]
output[i : i + 1] = torch.einsum("bt,btd->bd", score, self.buffer_in)
return output
import torch
import torch.nn as nn
import torch.nn.functional as F
class Conv1dLayer(nn.Module):
def __init__(
self,
input_dim,
output_dim,
kernel_size,
stride,
causal_conv,
dilation,
dropout_rate,
residual=True,
):
super(Conv1dLayer, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.causal_conv = causal_conv
if causal_conv:
self.lorder = (kernel_size - 1) * self.dilation
self.left_padding = nn.ConstantPad1d((self.lorder, 0), 0.0)
else:
assert (kernel_size - 1) % 2 == 0
self.lorder = ((kernel_size - 1) // 2) * self.dilation
self.left_padding = nn.ConstantPad1d((self.lorder, self.lorder), 0.0)
self.conv1d = nn.Conv1d(
self.input_dim, self.output_dim, self.kernel_size, self.stride, 0, self.dilation
)
self.bn = nn.BatchNorm1d(self.output_dim, eps=1e-3, momentum=0.99)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(p=dropout_rate)
self.residual = residual
if self.input_dim != self.output_dim:
self.residual = False
# buffer = 1, self.input_dim, self.lorder
self.lorder = (kernel_size - 1) * self.dilation - (self.stride - 1)
self.buffer_size = 1 * self.input_dim * self.lorder
self.x_data_chache_size = self.lorder
self.x_data_buffer_size = self.input_dim * self.x_data_chache_size
@torch.jit.unused
def forward(self, x):
x_data = x
x = self.left_padding(x)
x = self.conv1d(x)
x = self.bn(x)
if self.stride == 1 and self.residual:
x = self.relu(x + x_data)
else:
x = self.relu(x)
x = self.dropout(x)
return x
@torch.jit.export
def infer(self, x, buffer, buffer_index, buffer_out):
# type: (Tensor) -> Tensor
x_data = x.clone()
cnn_buffer = buffer[buffer_index : buffer_index + self.buffer_size].reshape(
[1, self.input_dim, self.lorder]
)
x = torch.cat([cnn_buffer, x], dim=2)
buffer_out.append(x[:, :, -self.lorder :].reshape(-1))
buffer_index = buffer_index + self.buffer_size
x = self.conv1d(x)
x = self.bn(x)
if self.stride == 1 and self.residual:
x_data_cnn_buffer = buffer[
buffer_index : buffer_index + self.x_data_buffer_size
].reshape([1, self.input_dim, self.x_data_chache_size])
x_data = torch.cat([x_data_cnn_buffer, x_data], dim=2)
buffer_out.append(x_data[:, :, -self.x_data_chache_size :].reshape(-1))
buffer_index = buffer_index + self.x_data_buffer_size
x_data = x_data[:, :, : -self.x_data_chache_size]
x = self.relu(x + x_data)
else:
x = self.relu(x)
return x, buffer, buffer_index, buffer_out
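# Usage sketch (illustrative): a causal Conv1dLayer left-pads by
# (kernel_size - 1) * dilation, so with stride 1 the output keeps the input
# length and the residual path applies. Sizes are assumptions for demonstration.
if __name__ == "__main__":
    layer = Conv1dLayer(
        input_dim=64, output_dim=64, kernel_size=3, stride=1,
        causal_conv=True, dilation=1, dropout_rate=0.0,
    )
    x = torch.randn(2, 64, 100)  # (batch, channels, time)
    y = layer(x)
    print(y.shape)  # (2, 64, 100)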