Commit d75c2dc7 authored by burchim

Efficient Conformer
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
# Models
from models.transducer import Transducer
from models.model_ctc import ModelCTC, InterCTC
from models.lm import LanguageModel
# Datasets
from utils.datasets import (
LibriSpeechDataset,
LibriSpeechCorpusDataset
)
# Preprocessing
from utils.preprocessing import (
collate_fn_pad
)
def create_model(config):
# Create Model
if config["model_type"] == "Transducer":
model = Transducer(
encoder_params=config["encoder_params"],
decoder_params=config["decoder_params"],
joint_params=config["joint_params"],
tokenizer_params=config["tokenizer_params"],
training_params=config["training_params"],
decoding_params=config["decoding_params"],
name=config["model_name"]
)
elif config["model_type"] == "CTC":
model = ModelCTC(
encoder_params=config["encoder_params"],
tokenizer_params=config["tokenizer_params"],
training_params=config["training_params"],
decoding_params=config["decoding_params"],
name=config["model_name"]
)
elif config["model_type"] == "InterCTC":
model = InterCTC(
encoder_params=config["encoder_params"],
tokenizer_params=config["tokenizer_params"],
training_params=config["training_params"],
decoding_params=config["decoding_params"],
name=config["model_name"]
)
elif config["model_type"] == "LM":
model = LanguageModel(
lm_params=config["lm_params"],
tokenizer_params=config["tokenizer_params"],
training_params=config["training_params"],
decoding_params=config["decoding_params"],
name=config["model_name"]
)
else:
raise Exception("Unknown model type")
return model
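# Illustrative usage (not part of the original commit): create_model expects a config
# dict shaped like the JSON files under configs/, e.g. (with `import json`):
#
#   config = json.load(open("configs/EfficientConformerCTCSmall.json"))
#   model = create_model(config)  # returns a ModelCTC instance when model_type == "CTC"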
def load_datasets(training_params, tokenizer_params, args):
# Training Datasets
training_datasets = {
"LibriSpeech": {
"class": LibriSpeechDataset,
"split": {
"training": "train",
"training-clean": "train-clean",
"validation-clean": None,
"validation-other": None,
"test-clean": None,
"test-other": None,
"eval_time": None,
"eval_time_encoder": None,
"eval_time_decoder": None,
}
},
"LibriSpeechCorpus": {
"class": LibriSpeechCorpusDataset,
"split": {
"training": "train",
"validation-clean": None,
"validation-other": None,
"test-clean": None,
"test-other": None,
"eval_time": None,
"eval_time_encoder": None,
"eval_time_decoder": None,
}
}
}
# Evaluation Datasets
evaluation_datasets = {
"LibriSpeech": {
"class": LibriSpeechDataset,
"split": {
"training": ["dev-clean", "dev-other"],
"training-clean": ["dev-clean", "dev-other"],
"validation-clean": "dev-clean",
"validation-other": "dev-other",
"test-clean": "test-clean",
"test-other": "test-other",
"eval_time": "dev-clean",
"eval_time_encoder": "dev-clean",
"eval_time_decoder": "dev-clean",
}
},
"LibriSpeechCorpus": {
"class": LibriSpeechCorpusDataset,
"split": {
"training": "val",
"validation-clean": "val",
"validation-other": "val",
"test-clean": "test",
"test-other": "test",
"eval_time": "val",
"eval_time_encoder": "val",
"eval_time_decoder": "val",
}
}
}
# Select Dataset and Split
training_dataset = training_datasets[training_params["training_dataset"]]["class"]
training_split = training_datasets[training_params["training_dataset"]]["split"][args.mode]
evaluation_dataset = evaluation_datasets[training_params["evaluation_dataset"]]["class"]
evaluation_split = evaluation_datasets[training_params["evaluation_dataset"]]["split"][args.mode]
# Training Dataset
if training_split:
if args.rank == 0:
print("Loading training dataset : {} {}".format(training_params["training_dataset"], training_split))
dataset_train = training_dataset(training_params["training_dataset_path"], training_params, tokenizer_params, training_split, args)
if args.distributed:
sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, num_replicas=args.world_size,rank=args.rank)
else:
sampler = None
dataset_train = torch.utils.data.DataLoader(dataset_train, batch_size=training_params["batch_size"], shuffle=(not args.distributed), num_workers=args.num_workers, collate_fn=collate_fn_pad, drop_last=True, sampler=sampler, pin_memory=False)
if args.rank == 0:
print("Loaded :", dataset_train.dataset.__len__(), "samples", "/", dataset_train.__len__(), "batches")
else:
dataset_train = None
# Evaluation Dataset
if evaluation_split:
# Multiple Evaluation datasets
if isinstance(evaluation_split, list):
dataset_eval = {}
for split in evaluation_split:
if args.rank == 0:
print("Loading evaluation dataset : {} {}".format(training_params["evaluation_dataset"], split))
dataset = evaluation_dataset(training_params["evaluation_dataset_path"], training_params, tokenizer_params, split, args)
if args.distributed:
sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=args.world_size,rank=args.rank)
else:
sampler = None
dataset = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size_eval, shuffle=(not args.distributed), num_workers=args.num_workers, collate_fn=collate_fn_pad, sampler=sampler, pin_memory=False)
if args.rank == 0:
print("Loaded :", dataset.dataset.__len__(), "samples", "/", dataset.__len__(), "batches")
dataset_eval[split] = dataset
# One Evaluation dataset
else:
if args.rank == 0:
print("Loading evaluation dataset : {} {}".format(training_params["evaluation_dataset"], evaluation_split))
dataset_eval = evaluation_dataset(training_params["evaluation_dataset_path"], training_params, tokenizer_params, evaluation_split, args)
if args.distributed:
sampler = torch.utils.data.distributed.DistributedSampler(dataset_eval, num_replicas=args.world_size,rank=args.rank)
else:
sampler = None
dataset_eval = torch.utils.data.DataLoader(dataset_eval, batch_size=args.batch_size_eval, shuffle=(not args.distributed), num_workers=args.num_workers, collate_fn=collate_fn_pad, sampler=sampler, pin_memory=False)
if args.rank == 0:
print("Loaded :", dataset_eval.dataset.__len__(), "samples", "/", dataset_eval.__len__(), "batches")
else:
dataset_eval = None
return dataset_train, dataset_eval
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
# Functions and Utils
from functions import *
from utils.preprocessing import *
# Other
import json
import argparse
import os
def main(rank, args):
# Process rank
args.rank = rank
# Distributed Computing
if args.distributed:
torch.cuda.set_device(args.rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=args.rank)
# Load Config
with open(args.config_file) as json_config:
config = json.load(json_config)
# Device
device = torch.device("cuda:" + str(args.rank) if torch.cuda.is_available() and not args.cpu else "cpu")
print("Device:", device)
# Create Tokenizer
if args.create_tokenizer:
if args.rank == 0:
print("Creating Tokenizer")
create_tokenizer(config["training_params"], config["tokenizer_params"])
if args.distributed:
torch.distributed.barrier()
# Create Model
model = create_model(config).to(device)
# Load Model
if args.initial_epoch is not None:
model.load(config["training_params"]["callback_path"] + "checkpoints_" + str(args.initial_epoch) + ".ckpt")
else:
args.initial_epoch = 0
# Load Encoder Only
if args.initial_epoch_encoder is not None:
model.load_encoder(config["training_params"]["callback_path_encoder"] + "checkpoints_" + str(args.initial_epoch_encoder) + ".ckpt")
# Load LM
if args.initial_epoch_lm:
# Load LM Config
with open(config["decoding_params"]["lm_config"]) as json_config:
config_lm = json.load(json_config)
# Create LM
model.lm = create_model(config_lm).to(device)
# Load LM
model.lm.load(config_lm["training_params"]["callback_path"] + "checkpoints_" + str(args.initial_epoch_lm) + ".ckpt")
# Model Summary
if args.rank == 0:
model.summary(show_dict=args.show_dict)
# Distribute Strategy
if args.distributed:
if args.rank == 0:
print("Parallelize model on", args.world_size, "GPUs")
model.distribute_strategy(args.rank)
# Parallel Strategy
if args.parallel and not args.distributed:
print("Parallelize model on", torch.cuda.device_count(), "GPUs")
model.parallel_strategy()
# Prepare Dataset
if args.prepare_dataset:
if args.rank == 0:
print("Preparing dataset")
prepare_dataset(config["training_params"], config["tokenizer_params"], model.tokenizer)
if args.distributed:
torch.distributed.barrier()
# Load Dataset
dataset_train, dataset_val = load_datasets(config["training_params"], config["tokenizer_params"], args)
###############################################################################
# Modes
###############################################################################
# Stochastic Weight Averaging
if args.swa:
model.swa(dataset_train, callback_path=config["training_params"]["callback_path"], start_epoch=args.swa_epochs[0] if args.swa_epochs else None, end_epoch=args.swa_epochs[1] if args.swa_epochs else None, epochs_list=args.swa_epochs_list, update_steps=args.steps_per_epoch, swa_type=args.swa_type)
# Training
elif args.mode.split("-")[0] == "training":
model.fit(dataset_train,
config["training_params"]["epochs"],
dataset_val=dataset_val,
val_steps=args.val_steps,
verbose_val=args.verbose_val,
initial_epoch=int(args.initial_epoch),
callback_path=config["training_params"]["callback_path"],
steps_per_epoch=args.steps_per_epoch,
mixed_precision=config["training_params"]["mixed_precision"],
accumulated_steps=config["training_params"]["accumulated_steps"],
saving_period=args.saving_period,
val_period=args.val_period)
# Evaluation
elif args.mode.split("-")[0] == "validation" or args.mode.split("-")[0] == "test":
# Greedy Search Evaluation
if args.gready or model.beam_size is None:
if args.rank == 0:
print("Gready Search Evaluation")
wer, _, _, _ = model.evaluate(dataset_val, eval_steps=args.val_steps, verbose=args.verbose_val, beam_size=1, eval_loss=args.eval_loss)
if args.rank == 0:
print("Geady Search WER : {:.2f}%".format(100 * wer))
# Beam Search Evaluation
else:
if args.rank == 0:
print("Beam Search Evaluation")
wer, _, _, _ = model.evaluate(dataset_val, eval_steps=args.val_steps, verbose=args.verbose_val, beam_size=model.beam_size, eval_loss=False)
if args.rank == 0:
print("Beam Search WER : {:.2f}%".format(100 * wer))
# Eval Time
elif args.mode.split("-")[0] == "eval_time":
print("Model Eval Time")
inf_time = model.eval_time(dataset_val, eval_steps=args.val_steps, beam_size=1, rnnt_max_consec_dec_steps=args.rnnt_max_consec_dec_steps, profiler=args.profiler)
print("eval time : {:.2f}s".format(inf_time))
elif args.mode.split("-")[0] == "eval_time_encoder":
print("Encoder Eval Time")
enc_time = model.eval_time_encoder(dataset_val, eval_steps=args.val_steps, profiler=args.profiler)
print("eval time : {:.2f}s".format(enc_time))
elif args.mode.split("-")[0] == "eval_time_decoder":
print("Decoder Eval Time")
dec_time = model.eval_time_decoder(dataset_val, eval_steps=args.val_steps, profiler=args.profiler)
print("eval time : {:.2f}s".format(dec_time))
# Destroy Process Group
if args.distributed:
torch.distributed.destroy_process_group()
if __name__ == "__main__":
# Args
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config_file", type=str, default="configs/EfficientConformerCTCSmall.json", help="Json configuration file containing model hyperparameters")
parser.add_argument("-m", "--mode", type=str, default="training", help="Mode : training, validation-clean, test-clean, eval_time-dev-clean, ...")
parser.add_argument("-d", "--distributed", action="store_true", help="Distributed data parallelization")
parser.add_argument("-i", "--initial_epoch", type=str, default=None, help="Load model from checkpoint")
parser.add_argument("--initial_epoch_lm", type=str, default=None, help="Load language model from checkpoint")
parser.add_argument("--initial_epoch_encoder", type=str, default=None, help="Load model encoder from encoder checkpoint")
parser.add_argument("-p", "--prepare_dataset", action="store_true", help="Prepare dataset for training")
parser.add_argument("-j", "--num_workers", type=int, default=8, help="Number of data loading workers")
parser.add_argument("--create_tokenizer", action="store_true", help="Create model tokenizer")
parser.add_argument("--batch_size_eval", type=int, default=8, help="Evaluation batch size")
parser.add_argument("--verbose_val", action="store_true", help="Evaluation verbose")
parser.add_argument("--val_steps", type=int, default=None, help="Number of validation steps")
parser.add_argument("--steps_per_epoch", type=int, default=None, help="Number of steps per epoch")
parser.add_argument("--world_size", type=int, default=torch.cuda.device_count(), help="Number of available GPUs")
parser.add_argument("--cpu", action="store_true", help="Load model on cpu")
parser.add_argument("--show_dict", action="store_true", help="Show model dict summary")
parser.add_argument("--swa", action="store_true", help="Stochastic weight averaging")
parser.add_argument("--swa_epochs", nargs="+", default=None, help="Start epoch / end epoch for swa")
parser.add_argument("--swa_epochs_list", nargs="+", default=None, help="List of checkpoints epochs for swa")
parser.add_argument("--swa_type", type=str, default="equal", help="Stochastic weight averaging type (equal/exp)")
parser.add_argument("--parallel", action="store_true", help="Parallelize model using data parallelization")
parser.add_argument("--rnnt_max_consec_dec_steps", type=int, default=None, help="Number of maximum consecutive transducer decoder steps during inference")
parser.add_argument("--eval_loss", action="store_true", help="Compute evaluation loss during evaluation")
parser.add_argument("--gready", action="store_true", help="Proceed to a gready search evaluation")
parser.add_argument("--saving_period", type=int, default=1, help="Model saving every 'n' epochs")
parser.add_argument("--val_period", type=int, default=1, help="Model validation every 'n' epochs")
parser.add_argument("--profiler", action="store_true", help="Enable eval time profiler")
# Parse Args
args = parser.parse_args()
# Run main
if args.distributed:
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '8888'
torch.multiprocessing.spawn(main, nprocs=args.world_size, args=(args,))
else:
main(0, args)
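# Example commands (illustrative, assuming this script is saved as main.py):
#
#   python main.py -c configs/EfficientConformerCTCSmall.json -m training
#   python main.py -c configs/EfficientConformerCTCSmall.json -m test-clean -i 100 --gready
#   python main.py -c configs/EfficientConformerCTCSmall.json -m training --distributed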
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
###############################################################################
# Activation Functions
###############################################################################
class Swish(nn.Module):
def __init__(self):
super(Swish, self).__init__()
def forward(self, x):
return x * x.sigmoid()
class Glu(nn.Module):
def __init__(self, dim):
super(Glu, self).__init__()
self.dim = dim
def forward(self, x):
x_in, x_gate = x.chunk(2, dim=self.dim)
return x_in * x_gate.sigmoid()
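# Note (illustrative): Swish here is x * sigmoid(x), i.e. the SiLU activation, and Glu
# gates one half of the input with the other, halving the chunked dimension:
#
#   Swish()(torch.randn(2, 16)).shape      # torch.Size([2, 16])
#   Glu(dim=-1)(torch.randn(2, 16)).shape  # torch.Size([2, 8])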
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
# Layers
from models.layers import (
Linear
)
###############################################################################
# Multi-Head Attention Layers
###############################################################################
class MultiHeadAttention(nn.Module):
"""Mutli-Head Attention Layer
Args:
dim_model: model feature dimension
num_heads: number of attention heads
References:
Attention Is All You Need, Vaswani et al.
https://arxiv.org/abs/1706.03762
"""
def __init__(self, dim_model, num_heads):
super(MultiHeadAttention, self).__init__()
# Attention Params
self.num_heads = num_heads # H
self.dim_model = dim_model # D
self.dim_head = dim_model // num_heads # d
# Linear Layers
self.query_layer = Linear(self.dim_model, self.dim_model)
self.key_layer = Linear(self.dim_model, self.dim_model)
self.value_layer = Linear(self.dim_model, self.dim_model)
self.output_layer = Linear(self.dim_model, self.dim_model)
def forward(self, Q, K, V, mask=None):
"""Scaled Dot-Product Multi-Head Attention
Args:
Q: Query of shape (B, T, D)
K: Key of shape (B, T, D)
V: Value of shape (B, T, D)
mask: Optional position mask of shape (1 or B, 1 or H, 1 or T, 1 or T)
Return:
O: Attention output of shape (B, T, D)
att_w: Attention weights of shape (B, H, T, T)
"""
# Batch size B
batch_size = Q.size(0)
# Linear Layers
Q = self.query_layer(Q)
K = self.key_layer(K)
V = self.value_layer(V)
# Reshape and Transpose (B, T, D) -> (B, H, T, d)
Q = Q.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
K = K.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
V = V.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Att scores (B, H, T, T)
att_scores = Q.matmul(K.transpose(2, 3)) / K.shape[-1]**0.5
# Apply mask
if mask is not None:
att_scores += (mask * -1e9)
# Att weights (B, H, T, T)
att_w = att_scores.softmax(dim=-1)
# Att output (B, H, T, d)
O = att_w.matmul(V)
# Transpose and Reshape (B, H, T, d) -> (B, T, D)
O = O.transpose(1, 2).reshape(batch_size, -1, self.dim_model)
# Output linear layer
O = self.output_layer(O)
return O, att_w
def pad(self, Q, K, V, mask, chunk_size):
# Compute Overflows
overflow_Q = Q.size(1) % chunk_size
overflow_KV = K.size(1) % chunk_size
padding_Q = chunk_size - overflow_Q if overflow_Q else 0
padding_KV = chunk_size - overflow_KV if overflow_KV else 0
batch_size, seq_len_KV, _ = K.size()
# Input Padding (B, T, D) -> (B, T + P, D)
Q = F.pad(Q, (0, 0, 0, padding_Q), value=0)
K = F.pad(K, (0, 0, 0, padding_KV), value=0)
V = F.pad(V, (0, 0, 0, padding_KV), value=0)
# Update Padding Mask
if mask is not None:
# (B, 1, 1, T) -> (B, 1, 1, T + P)
if mask.size(2) == 1:
mask = F.pad(mask, pad=(0, padding_KV), value=1)
# (B, 1, T, T) -> (B, 1, T + P, T + P)
else:
mask = F.pad(mask, pad=(0, padding_Q, 0, padding_KV), value=1)
elif padding_KV:
# None -> (B, 1, 1, T + P)
mask = F.pad(Q.new_zeros(batch_size, 1, 1, seq_len_KV), pad=(0, padding_KV), value=1)
return Q, K, V, mask, padding_Q
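# Illustrative self-attention usage (not part of the original commit):
#
#   mha = MultiHeadAttention(dim_model=256, num_heads=4)
#   x = torch.randn(8, 100, 256)    # (B, T, D)
#   out, att_w = mha(x, x, x)       # out: (8, 100, 256), att_w: (8, 4, 100, 100)
#
# pad() is shared by the grouped / local / strided variants below: it right-pads the
# time dimension to a multiple of the chunk size and extends the mask accordingly.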
class GroupedMultiHeadAttention(MultiHeadAttention):
"""Grouped Mutli-Head Attention Layer
Grouped multi-head attention reduces attention complexity from O(T2·D) to O(T2·D/G)
by grouping neighbouring time elements along the feature dimension before applying
scaled dot-product attention.
Args:
dim_model: model feature dimension
num_heads: number of attention heads
group_size: attention group size
"""
def __init__(self, dim_model, num_heads, group_size):
super(GroupedMultiHeadAttention, self).__init__(dim_model, num_heads)
# Attention Params
self.group_size = group_size # G
self.dim_head = (self.group_size * dim_model) // self.num_heads # d
def forward(self, Q, K, V, mask=None):
# Batch size B
batch_size = Q.size(0)
# Linear Layers
Q = self.query_layer(Q)
K = self.key_layer(K)
V = self.value_layer(V)
# Chunk Padding
Q, K, V, mask, padding = self.pad(Q, K, V, mask, chunk_size=self.group_size)
# Reshape and Transpose (B, T, D) -> (B, H, T//G, d)
Q = Q.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
K = K.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
V = V.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Att scores (B, H, T//G, T//G)
att_scores = Q.matmul(K.transpose(2, 3)) / K.shape[-1]**0.5
# Apply mask
if mask is not None:
# Slice Mask (B, 1, T, T) -> (B, 1, T//G, T//G)
mask = mask[:, :, ::self.group_size, ::self.group_size]
# Apply mask
att_scores += (mask * -1e9)
# Att weights (B, H, T//G, T//G)
att_w = att_scores.softmax(dim=-1)
# Att output (B, H, T//G, d)
O = att_w.matmul(V)
# Transpose and Reshape (B, H, T//G, d) -> (B, T, D)
O = O.transpose(1, 2).reshape(batch_size, -1, self.dim_model)
# Slice Padding
O = O[:, :O.size(1) - padding]
# Output linear layer
O = self.output_layer(O)
return O, att_w
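# Shape sketch (illustrative): with group size G, neighbouring frames are folded into
# the head dimension (dim_head = G * D / H), so attention runs over T//G positions:
#
#   gmha = GroupedMultiHeadAttention(dim_model=256, num_heads=4, group_size=3)
#   out, att_w = gmha(x, x, x)      # x: (B, T, 256) -> out: (B, T, 256),
#                                   # att_w: (B, 4, ceil(T/3), ceil(T/3))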
class LocalMultiHeadAttention(MultiHeadAttention):
"""Local Multi-Head Attention Layer
Local multi-head attention restricts the attended positions to a local neighborhood
around the query position. This is achieved by segmenting the hidden sequence into
non-overlapping blocks of size K and performing scaled dot-product attention in
parallel for each of these blocks.
Args:
dim_model: model feature dimension
num_heads: number of attention heads
kernel_size: attention kernel size / window
References:
Image Transformer, Parmar et al.
https://arxiv.org/abs/1802.05751
"""
def __init__(self, dim_model, num_heads, kernel_size):
super(LocalMultiHeadAttention, self).__init__(dim_model, num_heads)
# Attention Params
self.kernel_size = kernel_size # K
def forward(self, Q, K, V, mask=None):
# Batch size B
batch_size = Q.size(0)
# Linear Layers
Q = self.query_layer(Q)
K = self.key_layer(K)
V = self.value_layer(V)
# Chunk Padding
Q, K, V, mask, padding = self.pad(Q, K, V, mask, chunk_size=self.kernel_size)
# Reshape and Transpose (B, T, D) -> (B, T//K, H, K, d)
Q = Q.reshape(batch_size, -1, self.kernel_size, self.num_heads, self.dim_head).transpose(2, 3)
K = K.reshape(batch_size, -1, self.kernel_size, self.num_heads, self.dim_head).transpose(2, 3)
V = V.reshape(batch_size, -1, self.kernel_size, self.num_heads, self.dim_head).transpose(2, 3)
# Att scores (B, T//K, H, K, K)
att_scores = Q.matmul(K.transpose(3, 4)) / K.shape[-1]**0.5
# Apply mask
if mask is not None:
# Slice mask (B, 1, T, T) -> (B, T//K, 1, K, K)
masks = []
for m in range(mask.size(-1) // self.kernel_size):
masks.append(mask[:, :, m * self.kernel_size : (m + 1) * self.kernel_size, m * self.kernel_size : (m + 1) * self.kernel_size])
mask = torch.stack(masks, dim=1)
# Apply mask
att_scores = att_scores.float() - mask.float() * 1e9
# Att weights (B, T//K, H, K, K)
att_w = att_scores.softmax(dim=-1)
# Att output (B, T//K, H, K, d)
O = att_w.matmul(V)
# Transpose and Reshape (B, T//K, H, K, d) -> (B, T, D)
O = O.transpose(2, 3).reshape(batch_size, -1, self.dim_model)
# Slice Padding
O = O[:, :O.size(1) - padding]
# Output linear layer
O = self.output_layer(O)
return O, att_w
class StridedMultiHeadAttention(MultiHeadAttention):
"""Strided Mutli-Head Attention Layer
Strided multi-head attention performs global sequence downsampling by striding
the attention query bedore aplying scaled dot-product attention. This results in
strided attention maps where query positions can attend to the entire sequence
context to perform downsampling.
Args:
dim_model: model feature dimension
num_heads: number of attention heads
stride: query stride
"""
def __init__(self, dim_model, num_heads, stride):
super(StridedMultiHeadAttention, self).__init__(dim_model, num_heads)
# Attention Params
self.stride = stride # S
def forward(self, Q, K, V, mask=None):
# Query Subsampling (B, T, D) -> (B, T//S, D)
Q = Q[:, ::self.stride]
# Mask Subsampling (B, 1, T, T) -> (B, 1, T//S, T)
if mask is not None:
mask = mask[:, :, ::self.stride]
# Multi-Head Attention
return super(StridedMultiHeadAttention, self).forward(Q, K, V, mask)
class StridedLocalMultiHeadAttention(MultiHeadAttention):
"""Strided Local Multi-Head Attention Layer
Args:
dim_model: model feature dimension
num_heads: number of attention heads
kernel_size: attention kernel size / window
stride: query stride
"""
def __init__(self, dim_model, num_heads, kernel_size, stride):
super(StridedLocalMultiHeadAttention, self).__init__(dim_model, num_heads)
# Assert
assert kernel_size % stride == 0, "Attention kernel size has to be a multiple of attention stride"
# Attention Params
self.kernel_size = kernel_size # K
self.stride = stride # S
def forward(self, Q, K, V, mask=None):
# Batch size B
batch_size = Q.size(0)
# Query Subsampling (B, T, D) -> (B, T//S, D)
Q = Q[:, ::self.stride]
# Linear Layers
Q = self.query_layer(Q)
K = self.key_layer(K)
V = self.value_layer(V)
# Chunk Padding
Q, K, V, mask, padding = self.pad(Q, K, V, mask, chunk_size=self.kernel_size)
# Reshape and Transpose (B, T//S, D) -> (B, T//K, H, K//S, d)
Q = Q.reshape(batch_size, -1, self.kernel_size//self.stride, self.num_heads, self.dim_head).transpose(2, 3)
# Reshape and Transpose (B, T, D) -> (B, T//K, H, K, d)
K = K.reshape(batch_size, -1, self.kernel_size, self.num_heads, self.dim_head).transpose(2, 3)
V = V.reshape(batch_size, -1, self.kernel_size, self.num_heads, self.dim_head).transpose(2, 3)
# Att scores (B, T//K, H, K//S, K)
att_scores = Q.matmul(K.transpose(3, 4)) / K.shape[-1]**0.5
# Apply mask
if mask is not None:
# Slice mask (B, 1, T, T) -> (B, T//K, 1, K, K)
masks = []
for m in range(mask.size(-1) // self.kernel_size):
masks.append(mask[:, :, m * self.kernel_size : (m + 1) * self.kernel_size, m * self.kernel_size : (m + 1) * self.kernel_size])
mask = torch.stack(masks, dim=1)
# Subsample mask (B, T//K, 1, K, K) -> (B, T//K, 1, K//S, K)
mask = mask[:, :, :, ::self.stride]
# Apply mask
att_scores = att_scores.float() - mask.float() * 1e9
# Att weights (B, T//K, H, K//S, K)
att_w = att_scores.softmax(dim=-1)
# Att output (B, T//K, H, K//S, d)
O = att_w.matmul(V)
# Transpose and Reshape (B, T//K, H, K//S, d) -> (B, T//S, D)
O = O.transpose(2, 3).reshape(batch_size, -1, self.dim_model)
# Slice Padding
O = O[:, :(O.size(1) - padding - 1)//self.stride + 1]
# Output linear layer
O = self.output_layer(O)
return O, att_w
class MultiHeadLinearAttention(MultiHeadAttention):
"""Multi-Head Linear Attention
Args:
dim_model: model feature dimension
num_heads: number of attention heads
References:
Efficient Attention: Attention with Linear Complexities, Shen et al.
https://arxiv.org/abs/1812.01243
Efficient conformer-based speech recognition with linear attention, Li et al.
https://arxiv.org/abs/2104.06865
"""
def __init__(self, dim_model, num_heads):
super(MultiHeadLinearAttention, self).__init__(dim_model, num_heads)
def forward(self, Q, K, V):
# Batch size B
batch_size = Q.size(0)
# Linear Layers
Q = self.query_layer(Q)
K = self.key_layer(K)
V = self.value_layer(V)
# Reshape and Transpose (B, T, D) -> (B, N, T, d)
Q = Q.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
K = K.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
V = V.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Global Context Vector (B, N, d, d)
KV = (K / K.shape[-1]**(1.0/4.0)).softmax(dim=-2).transpose(2, 3).matmul(V)
# Attention Output (B, N, T, d)
O = (Q / Q.shape[-1]**(1.0/4.0)).softmax(dim=-1).matmul(KV)
# Transpose and Reshape (B, N, T, d) -> (B, T, D)
O = O.transpose(1, 2).reshape(batch_size, -1, self.dim_model)
# Output linear layer
O = self.output_layer(O)
return O, KV
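# Complexity note (illustrative): the softmax is applied separately to Q (over the head
# dimension) and K (over time), so the K^T V context matrix KV is only (d x d) per head
# and the overall cost grows linearly with sequence length:
#
#   lmha = MultiHeadLinearAttention(dim_model=256, num_heads=4)
#   out, KV = lmha(x, x, x)         # x: (B, T, 256) -> out: (B, T, 256),
#                                   # KV: (B, 4, 64, 64), independent of T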
###############################################################################
# Multi-Head Self-Attention Layers with Relative Sinusoidal Positional Encodings
###############################################################################
class RelPosMultiHeadSelfAttention(MultiHeadAttention):
"""Multi-Head Self-Attention Layer with Relative Sinusoidal Positional Encodings
Args:
dim_model: model feature dimension
num_heads: number of attention heads
causal: whether the attention is causal or unmasked
max_pos_encoding: maximum relative distance between elements
References:
Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context, Dai et al.
https://arxiv.org/abs/1901.02860
"""
def __init__(self, dim_model, num_heads, causal, max_pos_encoding):
super(RelPosMultiHeadSelfAttention, self).__init__(dim_model, num_heads)
# Position Embedding Layer
self.pos_layer = nn.Linear(self.dim_model, self.dim_model)
self.causal = causal
# Global content and positional bias
self.u = nn.Parameter(torch.Tensor(self.dim_model)) # Content bias
self.v = nn.Parameter(torch.Tensor(self.dim_model)) # Pos bias
torch.nn.init.xavier_uniform_(self.u.reshape(self.num_heads, self.dim_head)) # glorot uniform
torch.nn.init.xavier_uniform_(self.v.reshape(self.num_heads, self.dim_head)) # glorot uniform
# Relative Sinusoidal Positional Encodings
self.rel_pos_enc = RelativeSinusoidalPositionalEncoding(max_pos_encoding, self.dim_model, self.causal)
def rel_to_abs(self, att_scores):
"""Relative to absolute position indexing
Args:
att_scores: absolute-by-relative indexed attention scores of shape
(B, H, T, Th + 2*T-1) for full context and (B, H, T, Th + T) for causal context
Return:
att_scores: absolute-by-absolute indexed attention scores of shape (B, H, T, Th + T)
References:
causal context:
Music Transformer, Huang et al.
https://arxiv.org/abs/1809.04281
full context:
Attention Augmented Convolutional Networks, Bello et al.
https://arxiv.org/abs/1904.09925
"""
# Causal Context
if self.causal:
# Att Scores (B, H, T, Th + T)
batch_size, num_heads, seq_length1, seq_length2 = att_scores.size()
# Column Padding (B, H, T, 1 + Th + T)
att_scores = F.pad(att_scores, pad=(1, 0), value=0)
# Flatten (B, H, T + TTh + TT)
att_scores = att_scores.reshape(batch_size, num_heads, -1)
# Start Padding (B, H, Th + T + TTh + TT)
att_scores = F.pad(att_scores, pad=(seq_length2 - seq_length1, 0), value=0)
# Reshape (B, H, 1 + T, Th + T)
att_scores = att_scores.reshape(batch_size, num_heads, 1 + seq_length1, seq_length2)
# Slice (B, H, T, Th + T)
att_scores = att_scores[:, :, 1:]
# Full Context
else:
# Att Scores (B, H, T, Th + 2*T-1)
batch_size, num_heads, seq_length1, seq_length2 = att_scores.size()
# Column Padding (B, H, T, Th + 2*T)
att_scores = F.pad(att_scores, pad=(0, 1), value=0)
# Flatten (B, H, TTh + 2*TT)
att_scores = att_scores.reshape(batch_size, num_heads, -1)
# End Padding (B, H, TTh + 2*TT + Th + T - 1)
att_scores = F.pad(att_scores, pad=(0, seq_length2 - seq_length1), value=0)
# Reshape (B, H, T + 1, Th + 2*T-1)
att_scores = att_scores.reshape(batch_size, num_heads, 1 + seq_length1, seq_length2)
# Slice (B, H, T, Th + T)
att_scores = att_scores[:, :, :seq_length1, seq_length1-1:]
return att_scores
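# Worked example (illustrative), full context with T = 3 and no hidden state (Th = 0):
# each query row initially holds scores for relative positions (-2, -1, 0, +1, +2).
# Padding one column, flattening, appending T-1 zeros and reshaping to (T+1, 2T-1)
# shifts row i left by i, so the final slice [:T, T-1:] keeps, for query i, the scores
# of absolute key positions 0..T-1 (relative positions -i .. T-1-i).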
def forward(self, Q, K, V, mask=None, hidden=None):
"""Scaled Dot-Product Self-Attention with relative sinusoidal position encodings
Args:
Q: Query of shape (B, T, D)
K: Key of shape (B, T, D)
V: Value of shape (B, T, D)
mask: Optional position mask of shape (1 or B, 1 or H, 1 or T, 1 or T)
hidden: Optional Key and Value hidden states for decoding
Return:
O: Attention output of shape (B, T, D)
att_w: Attention weights of shape (B, H, T, Th + T)
hidden: Key and value hidden states
"""
# Batch size B
batch_size = Q.size(0)
# Linear Layers
Q = self.query_layer(Q)
K = self.key_layer(K)
V = self.value_layer(V)
# Hidden State Provided
if hidden:
K = torch.cat([hidden["K"], K], dim=1)
V = torch.cat([hidden["V"], V], dim=1)
# Update Hidden State
hidden = {"K": K, "V": V}
# Add Bias
Qu = Q + self.u
Qv = Q + self.v
# Relative Positional Embeddings (B, Th + 2*T-1, D) / (B, Th + T, D)
E = self.pos_layer(self.rel_pos_enc(batch_size, Q.size(1), K.size(1) - Q.size(1)))
# Reshape and Transpose (B, T, D) -> (B, H, T, d)
Qu = Qu.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
Qv = Qv.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Reshape and Transpose (B, Th + T, D) -> (B, H, Th + T, d)
K = K.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
V = V.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Reshape and Transpose (B, Th + 2*T-1, D) -> (B, H, Th + 2*T-1, d) / (B, Th + T, D) -> (B, H, Th + T, d)
E = E.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# att_scores (B, H, T, Th + T)
att_scores_K = Qu.matmul(K.transpose(2, 3))
att_scores_E = self.rel_to_abs(Qv.matmul(E.transpose(2, 3)))
att_scores = (att_scores_K + att_scores_E) / K.shape[-1]**0.5
# Apply mask
if mask is not None:
att_scores += (mask * -1e9)
# Att weights (B, H, T, Th + T)
att_w = att_scores.softmax(dim=-1)
# Att output (B, H, T, d)
O = att_w.matmul(V)
# Transpose and Reshape (B, H, T, d) -> (B, T, D)
O = O.transpose(1, 2).reshape(batch_size, -1, self.dim_model)
# Output linear layer
O = self.output_layer(O)
return O, att_w, hidden
class GroupedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
"""Grouped Multi-Head Self-Attention Layer with Relative Sinusoidal Positional Encodings
Args:
dim_model: model feature dimension
num_heads: number of attention heads
causal: whether the attention is causal or unmasked
max_pos_encoding: maximum relative distance between elements
group_size: attention group size
"""
def __init__(self, dim_model, num_heads, causal, max_pos_encoding, group_size):
super(GroupedRelPosMultiHeadSelfAttention, self).__init__(dim_model, num_heads, causal, max_pos_encoding)
# Attention Params
self.group_size = group_size # G
self.dim_head = (self.group_size * dim_model) // self.num_heads # d
# Grouped Relative Sinusoidal Positional Encodings
self.rel_pos_enc = GroupedRelativeSinusoidalPositionalEncoding(max_pos_encoding, self.dim_model, self.group_size, self.causal)
def forward(self, Q, K, V, mask=None, hidden=None):
# Batch size B
batch_size = Q.size(0)
# Linear Layers
Q = self.query_layer(Q)
K = self.key_layer(K)
V = self.value_layer(V)
# Hidden State Provided
if hidden:
Kh = torch.cat([hidden["K"], K], dim=1)
Vh = torch.cat([hidden["V"], V], dim=1)
K = torch.cat([hidden["K"][:, hidden["K"].size(1)%self.group_size:], K], dim=1)
V = torch.cat([hidden["V"][:, hidden["V"].size(1)%self.group_size:], V], dim=1)
# Update Hidden State
hidden = {"K": Kh, "V": Vh}
else:
# Update Hidden State
hidden = {"K": K, "V": V}
# Chunk Padding
Q, K, V, mask, padding = self.pad(Q, K, V, mask, chunk_size=self.group_size)
# Add Bias
Qu = Q + self.u
Qv = Q + self.v
# Relative Positional Embeddings (B, Th + 2*T-G, D) / (B, Th + T, D)
E = self.pos_layer(self.rel_pos_enc(batch_size, Q.size(1), K.size(1) - Q.size(1)))
# Reshape and Transpose (B, T, D) -> (B, H, T//G, d)
Qu = Qu.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
Qv = Qv.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Reshape and Transpose (B, Th + T, D) -> (B, H, Th//G + T//G, d)
K = K.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
V = V.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Reshape and Transpose (B, Th + 2*T-G, D) -> (B, H, Th//G + 2*T//G-1, d) / (B, Th + T, D) -> (B, H, Th//G + T//G, d)
E = E.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# att_scores (B, H, T//G, Th//G + T//G)
att_scores_K = Qu.matmul(K.transpose(2, 3))
att_scores_E = self.rel_to_abs(Qv.matmul(E.transpose(2, 3)))
att_scores = (att_scores_K + att_scores_E) / K.shape[-1]**0.5
# Apply mask
if mask is not None:
# Slice Mask (B, 1, T, T) -> (B, 1, T//G, T//G)
mask = mask[:, :, ::self.group_size, ::self.group_size]
# Apply mask
att_scores += (mask * -1e9)
# Att weights (B, H, T//G, Th//G + T//G)
att_w = att_scores.softmax(dim=-1)
# Att output (B, H, T//G, d)
O = att_w.matmul(V)
# Transpose and Reshape (B, H, T//G, d) -> (B, T, D)
O = O.transpose(1, 2).reshape(batch_size, -1, self.dim_model)
# Slice Padding
O = O[:, :O.size(1) - padding]
# Output linear layer
O = self.output_layer(O)
return O, att_w, hidden
class LocalRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
"""Local Multi-Head Self-Attention with Relative Sinusoidal Positional Encodings
Args:
dim_model: model feature dimension
num_heads: number of attention heads
causal: whether the attention is causal or unmasked
kernel_size: attention kernel size / window
References:
Music Transformer, Huang et al.
https://arxiv.org/abs/1809.04281
"""
def __init__(self, dim_model, num_heads, causal, kernel_size):
super(LocalRelPosMultiHeadSelfAttention, self).__init__(dim_model, num_heads, causal, kernel_size)
# Attention Params
self.kernel_size = kernel_size # K
def rel_to_abs(self, att_scores):
"""Relative to absolute position indexing
Args:
att_scores: absolute-by-relative indexed attention scores of shape
(B, H, T, 2 * K - 1) for full context and (B, H, T, K) for causal context
Return:
att_scores: absolute-by-absolute indexed attention scores of shape (B, T//K, H, K, K)
References:
Causal context:
Music Transformer, Huang et al.
https://arxiv.org/abs/1809.04281
"""
# Causal Context
if self.causal:
# Att Scores (B, H, T, K)
batch_size, num_heads, seq_length1, seq_length2 = att_scores.size()
# Reshape (B, T//K, H, K, K)
att_scores = att_scores.reshape(batch_size, -1, self.num_heads, self.kernel_size, self.kernel_size)
# Column Padding (B, T//K, H, K, 1 + K)
att_scores = F.pad(att_scores, pad=(1, 0), value=0)
# Reshape (B, T//K, H, 1 + K, K)
att_scores = att_scores.reshape(batch_size, -1, self.num_heads, self.kernel_size + 1, self.kernel_size)
# Slice (B, T//K, H, K, K)
att_scores = att_scores[:, :, :, 1:]
# Full Context
else:
# Att Scores (B, H, T, 2 * K - 1)
batch_size, num_heads, seq_length1, seq_length2 = att_scores.size()
# Reshape (B, T//K, H, K, 2 * K - 1)
att_scores = att_scores.reshape(batch_size, -1, self.num_heads, self.kernel_size, seq_length2)
# Column Padding (B, T//K, H, K, 2 * K)
att_scores = F.pad(att_scores, pad=(0, 1), value=0)
# Flatten (B, T//K, H, K * 2 * K)
att_scores = att_scores.reshape(batch_size, -1, self.num_heads, 2 * self.kernel_size**2)
# End Padding (B, T//K, H, K * 2 * K + K - 1)
att_scores = F.pad(att_scores, pad=(0, self.kernel_size - 1), value=0)
# Reshape (B, T//K, H, K + 1, 2 * K - 1)
att_scores = att_scores.reshape(batch_size, -1, self.num_heads, self.kernel_size + 1, seq_length2)
# Slice (B, T//K, H, K, K)
att_scores = att_scores[:, :, :, :self.kernel_size, self.kernel_size - 1:]
return att_scores
def forward(self, Q, K, V, mask=None, hidden=None):
# Batch size B
batch_size = Q.size(0)
# Linear Layers
Q = self.query_layer(Q)
K = self.key_layer(K)
V = self.value_layer(V)
# Chunk Padding
Q, K, V, mask, padding = self.pad(Q, K, V, mask, chunk_size=self.kernel_size)
# Add Bias
Qu = Q + self.u
Qv = Q + self.v
# Relative Positional Embeddings (B, 2*K-1, D) / (B, K, D)
E = self.pos_layer(self.rel_pos_enc(batch_size))
# Reshape and Transpose (B, T, D) -> (B, H, T, d)
Qv = Qv.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Reshape and Transpose (B, T, D) -> (B, T//K, H, K, d)
Qu = Qu.reshape(batch_size, -1, self.kernel_size, self.num_heads, self.dim_head).transpose(2, 3)
K = K.reshape(batch_size, -1, self.kernel_size, self.num_heads, self.dim_head).transpose(2, 3)
V = V.reshape(batch_size, -1, self.kernel_size, self.num_heads, self.dim_head).transpose(2, 3)
# Reshape and Transpose (B, 2*K-1, D) -> (B, H, 2*K-1, d) / (B, K, D) -> (B, H, K, d)
E = E.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# att_scores (B, T//K, H, K, K)
att_scores_K = Qu.matmul(K.transpose(3, 4))
att_scores_E = self.rel_to_abs(Qv.matmul(E.transpose(2, 3)))
att_scores = (att_scores_K + att_scores_E) / K.shape[-1]**0.5
# Mask scores
if mask is not None:
# Diagonal Mask (B, 1, T, T) -> (B, T//K, 1, K, K)
masks = []
for m in range(mask.size(-1) // self.kernel_size):
masks.append(mask[:, :, m * self.kernel_size : (m + 1) * self.kernel_size, m * self.kernel_size : (m + 1) * self.kernel_size])
mask = torch.stack(masks, dim=1)
# Apply Mask
att_scores = att_scores.float() - mask.float() * 1e9
# Attention weights (B, T//K, H, K, K)
att_w = att_scores.softmax(dim=-1)
# Attention output (B, T//K, H, K, d)
O = att_w.matmul(V)
# Transpose and Reshape (B, T//K, H, K, d) -> (B, T, D)
O = O.transpose(2, 3).reshape(batch_size, -1, self.dim_model)
# Slice Padding
O = O[:, :O.size(1) - padding]
# Output linear layer
O = self.output_layer(O)
return O, att_w, hidden
class StridedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
"""Strided Multi-Head Self-Attention with Relative Sinusoidal Positional Encodings
Args:
dim_model: model feature dimension
num_heads: number of attention heads
causal: whether the attention is causal or unmasked
max_pos_encoding: maximum relative distance between elements
stride: query stride
"""
def __init__(self, dim_model, num_heads, causal, max_pos_encoding, stride):
super(StridedRelPosMultiHeadSelfAttention, self).__init__(dim_model, num_heads, causal, max_pos_encoding)
# Attention Params
self.stride = stride # S
def rel_to_abs(self, att_scores):
"""Relative to absolute position indexing
Args:
att_scores: absolute-by-relative indexed attention scores of shape
(B, H, T//S, Th + 2 * T - 1) for full context and (B, H, T//S, Th + T) for causal context
Return:
att_scores: absolute-by-absolute indexed attention scores of shape (B, H, T//S, Th + T)
"""
# Causal Context
if self.causal:
# Att Scores (B, H, T // S, Th + T)
batch_size, num_heads, seq_length1, seq_length2 = att_scores.size()
# Column Padding (B, H, T // S, Th + T + S)
att_scores = F.pad(att_scores, pad=(1, self.stride-1), value=0)
# Flatten (B, H, TTh//S + TT//S + T)
att_scores = att_scores.reshape(batch_size, num_heads, -1)
# Start Padding (B, H, TTh//S + TT//S + T + Th)
att_scores = F.pad(att_scores, pad=(seq_length2 - self.stride*seq_length1, 0), value=0)
# Reshape (B, H, 1 + T // S, Th + T)
att_scores = att_scores.reshape(batch_size, num_heads, seq_length1 + 1, seq_length2)
# Slice (B, H, T // S, Th + T)
att_scores = att_scores[:, :, 1:]
# Full Context
else:
# Att Scores (B, H, T // S, Th + 2*T-1)
batch_size, num_heads, seq_length1, seq_length2 = att_scores.size()
# Column Padding (B, H, T // S, Th + 2*T-1 + S)
att_scores = F.pad(att_scores, pad=(0, self.stride), value=0)
# Flatten (B, H, TTh//S + 2*TT//S - T//S + T)
att_scores = att_scores.reshape(batch_size, num_heads, -1)
# End Padding (B, H, TTh//S + 2*TT//S - T//S + Th + 2T-1)
att_scores = F.pad(att_scores, pad=(0, seq_length2 - seq_length1 * self.stride), value=0)
# Reshape (B, H, T//S + 1, Th + 2*T-1)
att_scores = att_scores.reshape(batch_size, num_heads, seq_length1 + 1, seq_length2)
# Slice (B, H, T // S, Th + T)
att_scores = att_scores[:, :, :seq_length1, seq_length1 * self.stride - 1:]
return att_scores
def forward(self, Q, K, V, mask=None, hidden=None):
# Batch size B
batch_size = Q.size(0)
# Linear Layers
Q = self.query_layer(Q)
K = self.key_layer(K)
V = self.value_layer(V)
# Hidden State Provided
if hidden:
K = torch.cat([hidden["K"], K], dim=1)
V = torch.cat([hidden["V"], V], dim=1)
# Update Hidden State
hidden = {"K": K, "V": V}
# Chunk Padding
Q, K, V, mask, _ = self.pad(Q, K, V, mask, chunk_size=self.stride)
# Add Bias
Qu = Q + self.u
Qv = Q + self.v
# Query Subsampling (B, T, D) -> (B, T//S, D)
Q = Q[:, ::self.stride]
# Relative Positional Embeddings (B, Th + 2*T-1, D) / (B, Th + T, D)
E = self.pos_layer(self.rel_pos_enc(batch_size, self.stride * Q.size(1), K.size(1) - self.stride * Q.size(1)))
# Reshape and Transpose (B, T//S, D) -> (B, H, T//S, d)
Qu = Qu.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
Qv = Qv.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Reshape and Transpose (B, Th + T, D) -> (B, H, Th + T, d)
K = K.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
V = V.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Reshape and Transpose (B, Th + 2*T-1, D) -> (B, H, Th + 2*T-1, d) / (B, Th + T, D) -> (B, H, Th + T, d)
E = E.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# att_scores (B, H, T//S, Th + T)
att_scores_K = Qu.matmul(K.transpose(2, 3))
att_scores_E = self.rel_to_abs(Qv.matmul(E.transpose(2, 3)))
att_scores = (att_scores_K + att_scores_E) / K.shape[-1]**0.5
# Apply mask
if mask is not None:
# Mask Subsampling (B, 1, T, T) -> (B, 1, T//S, T)
mask = mask[:, :, ::self.stride]
# Apply mask
att_scores += (mask * -1e9)
# Att weights (B, H, T//S, Th + T)
att_w = att_scores.softmax(dim=-1)
# Att output (B, H, T//S, d)
O = att_w.matmul(V)
# Transpose and Reshape (B, H, T//S, d) -> (B, T//S, D)
O = O.transpose(1, 2).reshape(batch_size, -1, self.dim_model)
# Output linear layer
O = self.output_layer(O)
return O, att_w, hidden
class StridedLocalRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
"""Strided Local Multi-Head Self-Attention with Relative Sinusoidal Positional Encodings
Args:
dim_model: model feature dimension
num_heads: number of attention heads
causal: whether the attention is causal or unmasked
kernel_size: attention kernel size / window
stride: query stride
"""
def __init__(self, dim_model, num_heads, causal, kernel_size, stride):
super(StridedLocalRelPosMultiHeadSelfAttention, self).__init__(dim_model, num_heads, causal, kernel_size)
# Assert
assert kernel_size % stride == 0, "Attention kernel size has to be a multiple of attention stride"
# Attention Params
self.kernel_size = kernel_size # K
self.stride = stride # S
def rel_to_abs(self, att_scores):
"""Relative to absolute position indexing
Args:
att_scores: absolute-by-relative indexed attention scores of shape
(B, H, T//S, 2 * K - 1) for full context and (B, H, T//S, K) for causal context
Return:
att_scores: absolute-by-absolute indexed attention scores of shape (B, T//K, H, K//S, K)
"""
# Causal Context
if self.causal:
# Att Scores (B, H, T//S, K)
batch_size, num_heads, seq_length1, seq_length2 = att_scores.size()
# Reshape (B, T//K, H, K//S, K)
att_scores = att_scores.reshape(batch_size, -1, self.num_heads, self.kernel_size//self.stride, self.kernel_size)
# Column Padding (B, T//K, H, K//S, K + S)
att_scores = F.pad(att_scores, pad=(1, self.stride - 1), value=0)
# Reshape (B, T//K, H, 1 + K//S, K)
att_scores = att_scores.reshape(batch_size, -1, self.num_heads, self.kernel_size//self.stride + 1, self.kernel_size)
# Slice (B, T//K, H, K//S, K)
att_scores = att_scores[:, :, :, 1:]
# Full Context
else:
# Att Scores (B, H, T//S, 2*K-1)
batch_size, num_heads, seq_length1, seq_length2 = att_scores.size()
# Reshape (B, T//K, H, K//S, 2*K-1)
att_scores = att_scores.reshape(batch_size, -1, self.num_heads, self.kernel_size//self.stride, seq_length2)
# Column Padding (B, T//K, H, K//S, 2*K-1 + S)
att_scores = F.pad(att_scores, pad=(0, self.stride), value=0)
# Flatten (B, T//K, H, 2KK//S - K//S + K)
att_scores = att_scores.reshape(batch_size, -1, self.num_heads, self.kernel_size//self.stride * (2 * self.kernel_size - 1 + self.stride))
# End Padding (B, T//K, H, 2KK//S - K//S + 2K-1)
att_scores = F.pad(att_scores, pad=(0, self.kernel_size - 1), value=0)
# Reshape (B, T//K, H, K//S + 1, 2*K-1)
att_scores = att_scores.reshape(batch_size, -1, self.num_heads, self.kernel_size//self.stride + 1, seq_length2)
# Slice (B, T//K, H, K//S, K)
att_scores = att_scores[:, :, :, :self.kernel_size//self.stride, self.kernel_size - 1:]
return att_scores
def forward(self, Q, K, V, mask=None, hidden=None):
# Batch size B
batch_size = Q.size(0)
# Chunk Padding
Q, K, V, mask, padding = self.pad(Q, K, V, mask, chunk_size=self.kernel_size)
# Query Subsampling (B, T, D) -> (B, T//S, D)
Q = Q[:, ::self.stride]
# Linear Layers
Q = self.query_layer(Q)
K = self.key_layer(K)
V = self.value_layer(V)
# Add Bias
Qu = Q + self.u
Qv = Q + self.v
# Relative Positional Embeddings (B, 2*K-1, D) / (B, K, D)
E = self.pos_layer(self.rel_pos_enc(batch_size))
# Reshape and Transpose (B, T//S, D) -> (B, H, T//S, d)
Qv = Qv.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# Reshape and Transpose (B, T//S, D) -> (B, T//K, H, K//S, d)
Qu = Qu.reshape(batch_size, -1, self.kernel_size//self.stride, self.num_heads, self.dim_head).transpose(2, 3)
# Reshape and Transpose (B, T, D) -> (B, T//K, H, K, d)
K = K.reshape(batch_size, -1, self.kernel_size, self.num_heads, self.dim_head).transpose(2, 3)
V = V.reshape(batch_size, -1, self.kernel_size, self.num_heads, self.dim_head).transpose(2, 3)
# Reshape and Transpose (B, 2*K-1, D) -> (B, H, 2*K-1, d) / (B, K, D) -> (B, H, K, d)
E = E.reshape(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)
# att_scores (B, T//K, H, K//S, K)
att_scores_K = Qu.matmul(K.transpose(3, 4))
att_scores_E = self.rel_to_abs(Qv.matmul(E.transpose(2, 3)))
att_scores = (att_scores_K + att_scores_E) / K.shape[-1]**0.5
# Mask scores
if mask is not None:
# Diagonal Mask (B, 1, T, T) -> (B, T//K, 1, K, K)
masks = []
for m in range(mask.size(-1) // self.kernel_size):
masks.append(mask[:, :, m * self.kernel_size : (m + 1) * self.kernel_size, m * self.kernel_size : (m + 1) * self.kernel_size])
mask = torch.stack(masks, dim=1)
# Stride Mask (B, T//K, 1, K, K) -> (B, T//K, 1, K//S, K)
mask = mask[:, :, :, ::self.stride]
# Apply Mask
att_scores = att_scores.float() - mask.float() * 1e9
# Attention weights (B, T//K, H, K//S, K)
att_w = att_scores.softmax(dim=-1)
# Attention output (B, T//K, H, K//S, d)
O = att_w.matmul(V)
# Transpose and Reshape (B, T//K, H, K//S, d) -> (B, T//S, D)
O = O.transpose(2, 3).reshape(batch_size, -1, self.dim_model)
# Slice Padding
O = O[:, :(self.stride*O.size(1) - padding - 1)//self.stride + 1]
# Output linear layer
O = self.output_layer(O)
return O, att_w, hidden
###############################################################################
# Positional Encodings
###############################################################################
class SinusoidalPositionalEncoding(nn.Module):
"""
Sinusoidal Positional Encoding
Reference: "Attention Is All You Need" by Vaswani et al.
https://arxiv.org/abs/1706.03762
"""
def __init__(self, max_len, dim_model):
super(SinusoidalPositionalEncoding, self).__init__()
pos_encoding = torch.zeros(max_len, dim_model)
pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
i = torch.arange(0, dim_model // 2, dtype=torch.float).unsqueeze(0)
angles = pos / 10000**(2 * i / dim_model)
pos_encoding[:, 0::2] = angles.sin()
pos_encoding[:, 1::2] = angles.cos()
pos_encoding = pos_encoding.unsqueeze(0)
self.register_buffer('pos_encoding', pos_encoding, persistent=False)
def forward(self, batch_size=1, seq_len=None):
# (B, T, D)
if seq_len is not None:
P = self.pos_encoding[:, :seq_len]
# (B, Tmax, D)
else:
P = self.pos_encoding
return P.repeat(batch_size, 1, 1)
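# Reference formula for the buffer built above (Vaswani et al., 2017):
#   PE(pos, 2i)   = sin(pos / 10000^(2i / dim_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / dim_model))
# forward(batch_size, seq_len) returns the first seq_len positions repeated over the batch.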
class RelativeSinusoidalPositionalEncoding(nn.Module):
"""
Relative Sinusoidal Positional Encoding
Positional encoding for left context (sin) and right context (cos)
Total context = 2 * max_len - 1
"""
def __init__(self, max_len, dim_model, causal=False):
super(RelativeSinusoidalPositionalEncoding, self).__init__()
# PE
pos_encoding = torch.zeros(2 * max_len - 1, dim_model)
# Positions (max_len - 1, ..., 1, 0, -1, ..., -max_len + 1)
pos_left = torch.arange(start=max_len-1, end=0, step=-1, dtype=torch.float)
pos_right = torch.arange(start=0, end=-max_len, step=-1, dtype=torch.float)
pos = torch.cat([pos_left, pos_right], dim=0).unsqueeze(1)
# Angles
angles = pos / 10000**(2 * torch.arange(0, dim_model // 2, dtype=torch.float).unsqueeze(0) / dim_model)
# Rel Sinusoidal PE
pos_encoding[:, 0::2] = angles.sin()
pos_encoding[:, 1::2] = angles.cos()
pos_encoding = pos_encoding.unsqueeze(0)
self.register_buffer('pos_encoding', pos_encoding, persistent=False)
self.max_len = max_len
self.causal = causal
def forward(self, batch_size=1, seq_len=None, hidden_len=0):
# Causal Context
if self.causal:
# (B, Th + T, D)
if seq_len is not None:
R = self.pos_encoding[:, self.max_len - seq_len - hidden_len : self.max_len]
# (B, Tmax, D)
else:
R = self.pos_encoding[:,:self.max_len]
# Full Context
else:
# (B, Th + 2*T-1, D)
if seq_len is not None:
R = self.pos_encoding[:, self.max_len - seq_len - hidden_len : self.max_len - 1 + seq_len]
# (B, 2*Tmax-1, D)
else:
R = self.pos_encoding
return R.repeat(batch_size, 1, 1)
class GroupedRelativeSinusoidalPositionalEncoding(nn.Module):
"""
Relative Sinusoidal Positional Encoding for grouped multi-head attention
Positional encoding for left context (sin) and right context (cos)
Total context = 2 * max_len - group_size
"""
def __init__(self, max_len, dim_model, group_size=1, causal=False):
super(GroupedRelativeSinusoidalPositionalEncoding, self).__init__()
# PE
pos_encoding = torch.zeros(2 * max_len - group_size % 2, dim_model)
# Positions (max_len - 1, ..., group_size % 2, 0, -1, ..., -max_len + 1)
pos_left = torch.arange(start=max_len-1, end=group_size % 2 - 1, step=-1, dtype=torch.float)
pos_right = torch.arange(start=0, end=-max_len, step=-1, dtype=torch.float)
pos = torch.cat([pos_left, pos_right], dim=0).unsqueeze(1)
# Angles
angles = pos / 10000**(2 * torch.arange(0, dim_model // 2, dtype=torch.float).unsqueeze(0) / dim_model)
# Rel Sinusoidal PE
pos_encoding[:, 0::2] = angles.sin()
pos_encoding[:, 1::2] = angles.cos()
pos_encoding = pos_encoding.unsqueeze(0)
self.register_buffer('pos_encoding', pos_encoding, persistent=False)
self.max_len = max_len
self.causal = causal
self.group_size = group_size
def forward(self, batch_size=1, seq_len=None, hidden_len=0):
# Causal Context
if self.causal:
# (B, Th + T, D)
if seq_len is not None:
R = self.pos_encoding[:, self.max_len - seq_len - hidden_len : self.max_len]
# (B, Tmax, D)
else:
R = self.pos_encoding[:,:self.max_len]
else:
# (B, Th + 2*T-G, D)
if seq_len is not None:
R = self.pos_encoding[:, self.max_len - seq_len + self.group_size // 2 - hidden_len : self.max_len - self.group_size % 2 + seq_len - self.group_size // 2 ]
# (B, 2*Tmax-G, D)
else:
R = self.pos_encoding
return R.repeat(batch_size, 1, 1)
###############################################################################
# Attention Masks
###############################################################################
class PaddingMask(nn.Module):
def __init__(self):
super(PaddingMask, self).__init__()
def forward(self, seq_len, x_len):
if x_len is not None:
mask = x_len.new_ones(x_len.size(0), seq_len)
for b in range(x_len.size(0)):
mask[b, :x_len[b]] = x_len.new_zeros(x_len[b])
# Padding Mask (B, 1, 1, T)
return mask[:, None, None, :]
else:
return None
class LookAheadMask(nn.Module):
def __init__(self):
super(LookAheadMask, self).__init__()
self.padding_mask = PaddingMask()
def forward(self, x, x_len):
# Seq Length T
seq_len = x.size(-1)
# Look Ahead Mask (T, T)
look_ahead_mask = x.new_ones(seq_len, seq_len).triu(diagonal=1)
if x_len is not None:
# Padding Mask (B, 1, 1, T)
padding_mask = self.padding_mask(seq_len, x_len)
# Look Ahead Mask + Padding Mask (B, 1, T, T)
return look_ahead_mask.maximum(padding_mask)
else:
# Look Ahead Mask + Padding Mask (1, 1, T, T)
return look_ahead_mask[None, None, :, :]
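# Illustrative mask values for T = 4 (1 = masked, 0 = attended): the look-ahead part is
# strictly upper triangular, so position t can only attend to positions <= t:
#
#   [[0, 1, 1, 1],
#    [0, 0, 1, 1],
#    [0, 0, 0, 1],
#    [0, 0, 0, 0]]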
class StreamingMask(nn.Module):
def __init__(self, left_context, right_context):
super(StreamingMask, self).__init__()
self.padding_mask = PaddingMask()
self.left_context = left_context
self.right_context = right_context
def forward(self, x, x_len):
# Seq Length T
seq_len = x.size(-1)
# Right Context Mask (T, T)
right_context_mask = x.new_ones(seq_len, seq_len).triu(diagonal=1+self.right_context)
# Left Context Mask (T, T)
left_context_mask = 1 - x.new_ones(seq_len, seq_len).triu(diagonal=-self.left_context)
# Streaming Mask (T, T)
streaming_mask = right_context_mask.maximum(left_context_mask)
# Padding Mask
if x_len is not None:
# Padding Mask (B, 1, 1, T)
padding_mask = self.padding_mask(seq_len, x_len)
# Streaming Mask + Padding Mask (B, 1, T, T)
return streaming_mask.maximum(padding_mask)
else:
# Streaming Mask (1, 1, T, T)
return streaming_mask[None, None, :, :]
\ No newline at end of file
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
# Modules
from models.modules import (
FeedForwardModule,
MultiHeadSelfAttentionModule,
ConvolutionModule
)
# Layers
from models.layers import (
Conv1d,
Transpose
)
class ConformerBlock(nn.Module):
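"""
Conformer Block
Half-step Feed Forward, Multi-Head Self-Attention, Convolution and second half-step Feed Forward
modules with residual connections, followed by a block LayerNorm
dim_expand, conv_stride and att_stride allow channel expansion and time downsampling inside the block,
with pooled / projected residuals to match the module outputs
"""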
def __init__(
self,
dim_model,
dim_expand,
ff_ratio,
num_heads,
kernel_size,
att_group_size,
att_kernel_size,
linear_att,
Pdrop,
relative_pos_enc,
max_pos_encoding,
conv_stride,
att_stride,
causal
):
super(ConformerBlock, self).__init__()
# Feed Forward Module 1
self.feed_forward_module1 = FeedForwardModule(
dim_model=dim_model,
dim_ffn=dim_model * ff_ratio,
Pdrop=Pdrop,
act="swish",
inner_dropout=True
)
# Multi-Head Self-Attention Module
self.multi_head_self_attention_module = MultiHeadSelfAttentionModule(
dim_model=dim_model,
num_heads=num_heads,
Pdrop=Pdrop,
max_pos_encoding=max_pos_encoding,
relative_pos_enc=relative_pos_enc,
causal=causal,
group_size=att_group_size,
kernel_size=att_kernel_size,
stride=att_stride,
linear_att=linear_att
)
# Convolution Module
self.convolution_module = ConvolutionModule(
dim_model=dim_model,
dim_expand=dim_expand,
kernel_size=kernel_size,
Pdrop=Pdrop,
stride=conv_stride,
padding="causal" if causal else "same"
)
# Feed Forward Module 2
self.feed_forward_module2 = FeedForwardModule(
dim_model=dim_expand,
dim_ffn=dim_expand * ff_ratio,
Pdrop=Pdrop,
act="swish",
inner_dropout=True
)
# Block Norm
self.norm = nn.LayerNorm(dim_expand, eps=1e-6)
# Attention Residual
self.att_res = nn.Sequential(
Transpose(1, 2),
nn.MaxPool1d(kernel_size=1, stride=att_stride),
Transpose(1, 2)
) if att_stride > 1 else nn.Identity()
# Convolution Residual
self.conv_res = nn.Sequential(
Transpose(1, 2),
Conv1d(dim_model, dim_expand, kernel_size=1, stride=conv_stride),
Transpose(1, 2)
) if dim_model != dim_expand else nn.Sequential(
Transpose(1, 2),
nn.MaxPool1d(kernel_size=1, stride=conv_stride),
Transpose(1, 2)
) if conv_stride > 1 else nn.Identity()
# Block Stride
self.stride = conv_stride * att_stride
def forward(self, x, mask=None, hidden=None):
# FFN Module 1
x = x + 1/2 * self.feed_forward_module1(x)
# MHSA Module
x_att, attention, hidden = self.multi_head_self_attention_module(x, mask, hidden)
x = self.att_res(x) + x_att
# Conv Module
x = self.conv_res(x) + self.convolution_module(x)
# FFN Module 2
x = x + 1/2 * self.feed_forward_module2(x)
# Block Norm
x = self.norm(x)
return x, attention, hidden
class TransformerBlock(nn.Module):
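"""
Transformer Block
Multi-Head Self-Attention module followed by a Feed Forward module, both with residual connections
"""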
def __init__(self, dim_model, ff_ratio, num_heads, Pdrop, max_pos_encoding, relative_pos_enc, causal):
super(TransformerBlock, self).__init__()
# Multi-Head Self-Attention Module
self.multi_head_self_attention_module = MultiHeadSelfAttentionModule(
dim_model=dim_model,
num_heads=num_heads,
Pdrop=Pdrop,
max_pos_encoding=max_pos_encoding,
relative_pos_enc=relative_pos_enc,
causal=causal,
group_size=1,
kernel_size=1,
stride=1,
linear_att=False
)
# Feed Forward Module
self.feed_forward_module = FeedForwardModule(
dim_model=dim_model,
dim_ffn=dim_model * ff_ratio,
Pdrop=Pdrop,
act="relu",
inner_dropout=False
)
def forward(self, x, mask=None, hidden=None):
# Multi-Head Self-Attention Module
x_att, attention, hidden = self.multi_head_self_attention_module(x, mask, hidden)
x = x + x_att
# Feed Forward Module
x = x + self.feed_forward_module(x)
return x, attention, hidden
\ No newline at end of file
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
# Positional Encodings and Masks
from models.attentions import (
SinusoidalPositionalEncoding,
StreamingMask
)
# Blocks
from models.blocks import (
TransformerBlock,
ConformerBlock
)
# Layers
from models.layers import (
Embedding,
LSTM
)
###############################################################################
# Decoder Models
###############################################################################
class RnnDecoder(nn.Module):
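"""
RNN Decoder
Token embedding followed by a stack of unidirectional LSTM layers
Padded sequences are packed / unpacked when sequence lengths are provided
"""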
def __init__(self, params):
super(RnnDecoder, self).__init__()
self.embedding = Embedding(params["vocab_size"], params["dim_model"], padding_idx=0)
self.rnn = LSTM(input_size=params["dim_model"], hidden_size=params["dim_model"], num_layers=params["num_layers"], batch_first=True, bidirectional=False)
def forward(self, y, hidden, y_len=None):
# Sequence Embedding (B, U + 1) -> (B, U + 1, D)
y = self.embedding(y)
# Pack padded batch sequences
if y_len is not None:
y = nn.utils.rnn.pack_padded_sequence(y, y_len.cpu(), batch_first=True, enforce_sorted=False)
# Hidden state provided
if hidden is not None:
y, hidden = self.rnn(y, hidden)
# None hidden state
else:
y, hidden = self.rnn(y)
# Pad packed batch sequences
if y_len is not None:
y, _ = nn.utils.rnn.pad_packed_sequence(y, batch_first=True)
# Return last layer outputs for every step and last step hidden state of every layer
return y, hidden
class TransformerDecoder(nn.Module):
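"""
Transformer Decoder
Causal decoder: token embedding, dropout, optional sinusoidal positional encodings
and a stack of causal Transformer blocks returning per block hidden states for decoding
"""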
def __init__(self, params):
super(TransformerDecoder, self).__init__()
# Look Ahead Mask
self.look_ahead_mask = StreamingMask(left_context=params.get("left_context", params["max_pos_encoding"]), right_context=0)
# Embedding Layer
self.embedding = nn.Embedding(params["vocab_size"], params["dim_model"], padding_idx=0)
# Dropout
self.dropout = nn.Dropout(p=params["Pdrop"])
# Sinusoidal Positional Encodings
self.pos_enc = None if params["relative_pos_enc"] else SinusoidalPositionalEncoding(params["max_pos_encoding"], params["dim_model"])
# Transformer Blocks
self.blocks = nn.ModuleList([TransformerBlock(
dim_model=params["dim_model"],
ff_ratio=params["ff_ratio"],
num_heads=params["num_heads"],
Pdrop=params["Pdrop"],
max_pos_encoding=params["max_pos_encoding"],
relative_pos_enc=params["relative_pos_enc"],
causal=True
) for block_id in range(params["num_blocks"])])
def forward(self, y, hidden, y_len=None):
# Look Ahead Mask
if hidden is None:
mask = self.look_ahead_mask(y, y_len)
else:
mask = None
# Linear Proj
y = self.embedding(y)
# Dropout
y = self.dropout(y)
# Sinusoidal Positional Encodings
if self.pos_enc is not None:
y = y + self.pos_enc(y.size(0), y.size(1))
# Transformer Blocks
attentions = []
hidden_new = []
for block_id, block in enumerate(self.blocks):
# Hidden State Provided
if hidden is not None:
y, attention, block_hidden = block(y, mask, hidden[block_id])
else:
y, attention, block_hidden = block(y, mask)
# Update Hidden / Att Maps
if not self.training:
attentions.append(attention)
hidden_new.append(block_hidden)
return y, hidden_new
class ConformerDecoder(nn.Module):
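"""
Conformer Decoder
Causal decoder: token embedding, dropout, optional sinusoidal positional encodings
and a stack of causal Conformer blocks, previous tokens are cached as hidden state during decoding
"""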
def __init__(self, params):
super(ConformerDecoder, self).__init__()
# Look Ahead Mask
self.look_ahead_mask = StreamingMask(left_context=params.get("left_context", params["max_pos_encoding"]), right_context=0)
# Embedding Layer
self.embedding = nn.Embedding(params["vocab_size"], params["dim"], padding_idx=0)
# Dropout
self.dropout = nn.Dropout(p=params["Pdrop"])
# Sinusoidal Positional Encodings
self.pos_enc = None if params["relative_pos_enc"] else SinusoidalPositionalEncoding(params["max_pos_encoding"], params["dim_model"])
# Conformer Layers
self.blocks = nn.ModuleList([ConformerBlock(
dim_model=params["dim_model"],
dim_expand=params["dim_model"],
ff_ratio=params["ff_ratio"],
num_heads=params["num_heads"],
kernel_size=params["kernel_size"],
att_group_size=1,
att_kernel_size=None,
linear_att=False,
Pdrop=params["Pdrop"],
relative_pos_enc=params["relative_pos_enc"],
max_pos_encoding=params["max_pos_encoding"],
conv_stride=1,
att_stride=1,
causal=True
) for block_id in range(params["num_blocks"])])
def forward(self, y, hidden, y_len=None):
# Hidden state provided
if hidden is not None:
y = torch.cat([hidden, y], axis=1)
# Look Ahead Mask
mask = self.look_ahead_mask(y, y_len)
# Update Hidden
hidden_new = y
# Linear Proj
y = self.embedding(y)
# Dropout
y = self.dropout(y)
# Sinusoidal Positional Encodings
if self.pos_enc is not None:
y = y + self.pos_enc(y.size(0), y.size(1))
# Conformer Blocks
attentions = []
for block in self.blocks:
y, attention, _ = block(y, mask)
attentions.append(attention)
if hidden is not None:
y = y[:, -1:]
return y, hidden_new
\ No newline at end of file
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
# Blocks
from models.blocks import (
ConformerBlock
)
# Modules
from models.modules import (
AudioPreprocessing,
SpecAugment,
Conv1dSubsampling,
Conv2dSubsampling,
Conv2dPoolSubsampling,
VGGSubsampling
)
# Positional Encodings and Masks
from models.attentions import (
SinusoidalPositionalEncoding,
StreamingMask
)
###############################################################################
# Encoder Models
###############################################################################
class ConformerEncoder(nn.Module):
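"""
Conformer Encoder
Audio preprocessing (mel spectrogram features), SpecAugment, convolutional subsampling,
linear projection and a stack of Conformer blocks
Blocks listed in strided_blocks downsample the sequence while blocks listed in expand_blocks
expand the model dimension
"""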
def __init__(self, params):
super(ConformerEncoder, self).__init__()
# Audio Preprocessing
self.preprocessing = AudioPreprocessing(params["sample_rate"], params["n_fft"], params["win_length_ms"], params["hop_length_ms"], params["n_mels"], params["normalize"], params["mean"], params["std"])
# Spec Augment
self.augment = SpecAugment(params["spec_augment"], params["mF"], params["F"], params["mT"], params["pS"])
# Subsampling Module
if params["subsampling_module"] == "Conv1d":
self.subsampling_module = Conv1dSubsampling(params["subsampling_layers"], params["n_mels"], params["subsampling_filters"], params["subsampling_kernel_size"], params["subsampling_norm"], params["subsampling_act"])
elif params["subsampling_module"] == "Conv2d":
self.subsampling_module = Conv2dSubsampling(params["subsampling_layers"], params["subsampling_filters"], params["subsampling_kernel_size"], params["subsampling_norm"], params["subsampling_act"])
elif params["subsampling_module"] == "Conv2dPool":
self.subsampling_module = Conv2dPoolSubsampling(params["subsampling_layers"], params["subsampling_filters"], params["subsampling_kernel_size"], params["subsampling_norm"], params["subsampling_act"])
elif params["subsampling_module"] == "VGG":
self.subsampling_module = VGGSubsampling(params["subsampling_layers"], params["subsampling_filters"], params["subsampling_kernel_size"], params["subsampling_norm"], params["subsampling_act"])
else:
raise Exception("Unknown subsampling module:", params["subsampling_module"])
# Padding Mask
self.padding_mask = StreamingMask(left_context=params.get("left_context", params["max_pos_encoding"]), right_context=0 if params.get("causal", False) else params.get("right_context", params["max_pos_encoding"]))
# Linear Proj
self.linear = nn.Linear(params["subsampling_filters"][-1] * params["n_mels"] // 2**params["subsampling_layers"], params["dim_model"][0] if isinstance(params["dim_model"], list) else params["dim_model"])
# Dropout
self.dropout = nn.Dropout(p=params["Pdrop"])
# Sinusoidal Positional Encodings
self.pos_enc = None if params["relative_pos_enc"] else SinusoidalPositionalEncoding(params["max_pos_encoding"], params["dim_model"][0] if isinstance(params["dim_model"], list) else params["dim_model"])
# Conformer Blocks
self.blocks = nn.ModuleList([ConformerBlock(
dim_model=params["dim_model"][(block_id > torch.tensor(params.get("expand_blocks", []))).sum()] if isinstance(params["dim_model"], list) else params["dim_model"],
dim_expand=params["dim_model"][(block_id >= torch.tensor(params.get("expand_blocks", []))).sum()] if isinstance(params["dim_model"], list) else params["dim_model"],
ff_ratio=params["ff_ratio"],
num_heads=params["num_heads"][(block_id > torch.tensor(params.get("expand_blocks", []))).sum()] if isinstance(params["num_heads"], list) else params["num_heads"],
kernel_size=params["kernel_size"][(block_id >= torch.tensor(params.get("expand_blocks", []))).sum()] if isinstance(params["kernel_size"], list) else params["kernel_size"],
att_group_size=params["att_group_size"][(block_id > torch.tensor(params.get("strided_blocks", []))).sum()] if isinstance(params.get("att_group_size", 1), list) else params.get("att_group_size", 1),
att_kernel_size=params["att_kernel_size"][(block_id > torch.tensor(params.get("strided_layers", []))).sum()] if isinstance(params.get("att_kernel_size", None), list) else params.get("att_kernel_size", None),
linear_att=params.get("linear_att", False),
Pdrop=params["Pdrop"],
relative_pos_enc=params["relative_pos_enc"],
max_pos_encoding=params["max_pos_encoding"] // params.get("stride", 2)**int((block_id > torch.tensor(params.get("strided_blocks", []))).sum()),
conv_stride=(params["conv_stride"][(block_id > torch.tensor(params.get("strided_blocks", []))).sum()] if isinstance(params["conv_stride"], list) else params["conv_stride"]) if block_id in params.get("strided_blocks", []) else 1,
att_stride=(params["att_stride"][(block_id > torch.tensor(params.get("strided_blocks", []))).sum()] if isinstance(params["att_stride"], list) else params["att_stride"]) if block_id in params.get("strided_blocks", []) else 1,
causal=params.get("causal", False)
) for block_id in range(params["num_blocks"])])
def forward(self, x, x_len=None):
# Audio Preprocessing
x, x_len = self.preprocessing(x, x_len)
# Spec Augment
if self.training:
x = self.augment(x, x_len)
# Subsampling Module
x, x_len = self.subsampling_module(x, x_len)
# Padding Mask
mask = self.padding_mask(x, x_len)
# Transpose (B, D, T) -> (B, T, D)
x = x.transpose(1, 2)
# Linear Projection
x = self.linear(x)
# Dropout
x = self.dropout(x)
# Sinusoidal Positional Encodings
if self.pos_enc is not None:
x = x + self.pos_enc(x.size(0), x.size(1))
# Conformer Blocks
attentions = []
for block in self.blocks:
x, attention, hidden = block(x, mask)
attentions.append(attention)
# Strided Block
if block.stride > 1:
# Stride Mask (B, 1, T // S, T // S)
if mask is not None:
mask = mask[:, :, ::block.stride, ::block.stride]
# Update Seq Lengths
if x_len is not None:
x_len = (x_len - 1) // block.stride + 1
return x, x_len, attentions
class ConformerEncoderInterCTC(ConformerEncoder):
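"""
Conformer Encoder with intermediate CTC predictions
After each block listed in interctc_blocks, an intermediate softmax prediction is computed,
projected back to the encoder dimension and added to the block output
"""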
def __init__(self, params):
super(ConformerEncoderInterCTC, self).__init__(params)
# Inter CTC blocks
self.interctc_blocks = params["interctc_blocks"]
for block_id in params["interctc_blocks"]:
self.__setattr__(
name="linear_expand_" + str(block_id),
value=nn.Linear(
in_features=params["dim_model"][(block_id >= torch.tensor(params.get("expand_blocks", []))).sum()] if isinstance(params["dim_model"], list) else params["dim_model"],
out_features=params["vocab_size"]))
self.__setattr__(
name="linear_proj_" + str(block_id),
value=nn.Linear(
in_features=params["vocab_size"],
out_features=params["dim_model"][(block_id >= torch.tensor(params.get("expand_blocks", []))).sum()] if isinstance(params["dim_model"], list) else params["dim_model"]))
def forward(self, x, x_len=None):
# Audio Preprocessing
x, x_len = self.preprocessing(x, x_len)
# Spec Augment
if self.training:
x = self.augment(x, x_len)
# Subsampling Module
x, x_len = self.subsampling_module(x, x_len)
# Padding Mask
mask = self.padding_mask(x, x_len)
# Transpose (B, D, T) -> (B, T, D)
x = x.transpose(1, 2)
# Linear Projection
x = self.linear(x)
# Dropout
x = self.dropout(x)
# Sinusoidal Positional Encodings
if self.pos_enc is not None:
x = x + self.pos_enc(x.size(0), x.size(1))
# Conformer Blocks
attentions = []
interctc_probs = []
for block_id, block in enumerate(self.blocks):
x, attention, hidden = block(x, mask)
attentions.append(attention)
# Strided Block
if block.stride > 1:
# Stride Mask (B, 1, T // S, T // S)
if mask is not None:
mask = mask[:, :, ::block.stride, ::block.stride]
# Update Seq Lengths
if x_len is not None:
x_len = (x_len - 1) // block.stride + 1
# Inter CTC Block
if block_id in self.interctc_blocks:
interctc_prob = self.__getattr__("linear_expand_" + str(block_id))(x).softmax(dim=-1)
interctc_probs.append(interctc_prob)
x = x + self.__getattr__("linear_proj_" + str(block_id))(interctc_prob)
return x, x_len, attentions, interctc_probs
\ No newline at end of file
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
# Layers
from models.layers import (
Linear
)
# Activations Functions
from models.activations import (
Swish
)
###############################################################################
# Joint Networks
###############################################################################
class JointNetwork(nn.Module):
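"""
Transducer Joint Network
Projects encoder and decoder outputs to a common dimension (optional), combines them by
concatenation or sum, applies an optional activation and projects to the vocabulary size
Training: (B, T, Denc) and (B, U + 1, Ddec) -> (B, T, U + 1, V) / Decoding: (B, Denc) and (B, Ddec) -> (B, V)
"""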
def __init__(self, dim_encoder, dim_decoder, vocab_size, params):
super(JointNetwork, self).__init__()
assert params["act"] in ["tanh", "relu", "swish", None]
assert params["joint_mode"] in ["concat", "sum"]
# Model layers
if params["dim_model"] is not None:
# Linear Layers
self.linear_encoder = Linear(dim_encoder, params["dim_model"])
self.linear_decoder = Linear(dim_decoder, params["dim_model"])
# Joint Mode
if params["joint_mode"] == "concat":
self.joint_mode = "concat"
self.linear_joint = Linear(2 * params["dim_model"], vocab_size)
elif params["joint_mode"] == "sum":
self.joint_mode = 'sum'
self.linear_joint = Linear(params["dim_model"], vocab_size)
else:
# Linear Layers
self.linear_encoder = nn.Identity()
self.linear_decoder = nn.Identity()
# Joint Mode
if params["joint_mode"] == "concat":
self.joint_mode = "concat"
self.linear_joint = Linear(dim_encoder + dim_decoder, vocab_size)
elif params["joint_mode"] == "sum":
assert dim_encoder == dim_decoder
self.joint_mode = 'sum'
self.linear_joint = Linear(dim_encoder, vocab_size)
# Model Act Function
if params["act"] == "tanh":
self.act = nn.Tanh()
elif params["act"] == "relu":
self.act = nn.ReLU()
elif params["act"] == "swish":
self.act = Swish()
else:
self.act = nn.Identity()
def forward(self, f, g):
f = self.linear_encoder(f)
g = self.linear_decoder(g)
# Training or Eval Loss
if self.training or (len(f.size()) == 3 and len(g.size()) == 3):
f = f.unsqueeze(2) # (B, T, 1, D)
g = g.unsqueeze(1) # (B, 1, U + 1, D)
f = f.repeat([1, 1, g.size(2), 1]) # (B, T, U + 1, D)
g = g.repeat([1, f.size(1), 1, 1]) # (B, T, U + 1, D)
# Joint Encoder and Decoder
if self.joint_mode == "concat":
joint = torch.cat([f, g], dim=-1) # Training : (B, T, U + 1, 2D) / Decoding : (B, 2D)
elif self.joint_mode == "sum":
joint = f + g # Training : (B, T, U + 1, D) / Decoding : (B, D)
# Act Function
joint = self.act(joint)
# Output Linear Projection
outputs = self.linear_joint(joint) # Training : (B, T, U + 1, V) / Decoding : (B, V)
return outputs
\ No newline at end of file
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch._VF as _VF
from torch.nn.modules.utils import _single, _pair
# Activation Functions
from models.activations import (
Swish
)
###############################################################################
# Layers
###############################################################################
class Linear(nn.Linear):
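"""
Linear layer with variational noise
Gaussian noise sampled with sample_synaptic_noise can be added to the weights during training,
scaled by vn_std, as weight noise regularization
"""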
def __init__(self, in_features, out_features, bias = True):
super(Linear, self).__init__(
in_features=in_features,
out_features=out_features,
bias=bias)
# Variational Noise
self.noise = None
self.vn_std = None
def init_vn(self, vn_std):
# Variational Noise
self.vn_std = vn_std
def sample_synaptic_noise(self, distributed):
# Sample Noise
self.noise = torch.normal(mean=0.0, std=1.0, size=self.weight.size(), device=self.weight.device, dtype=self.weight.dtype)
# Broadcast Noise
if distributed:
torch.distributed.broadcast(self.noise, 0)
def forward(self, input):
# Weight
weight = self.weight
# Add Noise
if self.noise is not None and self.training:
weight = weight + self.vn_std * self.noise
# Apply Weight
return F.linear(input, weight, self.bias)
class Conv1d(nn.Conv1d):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride = 1,
padding = "same",
dilation = 1,
groups = 1,
bias = True
):
super(Conv1d, self).__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=0,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode="zeros")
# Assert
assert padding in ["valid", "same", "causal"]
# Padding
if padding == "valid":
self.pre_padding = None
elif padding == "same":
self.pre_padding = nn.ConstantPad1d(padding=((kernel_size - 1) // 2, (kernel_size - 1) // 2), value=0)
elif padding == "causal":
self.pre_padding = nn.ConstantPad1d(padding=(kernel_size - 1, 0), value=0)
# Variational Noise
self.noise = None
self.vn_std = None
def init_vn(self, vn_std):
# Variational Noise
self.vn_std = vn_std
def sample_synaptic_noise(self, distributed):
# Sample Noise
self.noise = torch.normal(mean=0.0, std=1.0, size=self.weight.size(), device=self.weight.device, dtype=self.weight.dtype)
# Broadcast Noise
if distributed:
torch.distributed.broadcast(self.noise, 0)
def forward(self, input):
# Weight
weight = self.weight
# Add Noise
if self.noise is not None and self.training:
weight = weight + self.vn_std * self.noise
# Padding
if self.pre_padding is not None:
input = self.pre_padding(input)
# Apply Weight
return F.conv1d(input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
class Conv2d(nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, groups = 1, bias = True, padding_mode = 'zeros'):
super(Conv2d, self).__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode)
# Variational Noise
self.noise = None
self.vn_std = None
def init_vn(self, vn_std):
# Variational Noise
self.vn_std = vn_std
def sample_synaptic_noise(self, distributed):
# Sample Noise
self.noise = torch.normal(mean=0.0, std=1.0, size=self.weight.size(), device=self.weight.device, dtype=self.weight.dtype)
# Broadcast Noise
if distributed:
torch.distributed.broadcast(self.noise, 0)
def forward(self, input):
# Weight
weight = self.weight
# Add Noise
if self.noise is not None and self.training:
weight = weight + self.vn_std * self.noise
# Apply Weight
if self.padding_mode != 'zeros':
return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), weight, self.bias, self.stride, _pair(0), self.dilation, self.groups)
return F.conv2d(input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
class LSTM(nn.LSTM):
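"""
LSTM with variational noise
Gaussian noise can be added to the input-hidden and hidden-hidden weights of each layer during training,
scaled by vn_std
"""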
def __init__(self, input_size, hidden_size, num_layers, batch_first, bidirectional):
super(LSTM, self).__init__(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=batch_first,
bidirectional=bidirectional)
# Variational Noise
self.noises = None
self.vn_std = None
def init_vn(self, vn_std):
# Variational Noise
self.vn_std = vn_std
def sample_synaptic_noise(self, distributed):
# Sample Noise
self.noises = []
for i in range(0, len(self._flat_weights), 4):
self.noises.append(torch.normal(mean=0.0, std=1.0, size=self._flat_weights[i].size(), device=self._flat_weights[i].device, dtype=self._flat_weights[i].dtype))
self.noises.append(torch.normal(mean=0.0, std=1.0, size=self._flat_weights[i+1].size(), device=self._flat_weights[i+1].device, dtype=self._flat_weights[i+1].dtype))
# Broadcast Noise
if distributed:
for noise in self.noises:
torch.distributed.broadcast(noise, 0)
def forward(self, input, hx=None): # noqa: F811
orig_input = input
# xxx: isinstance check needs to be in conditional for TorchScript to compile
if isinstance(orig_input, nn.utils.rnn.PackedSequence):
input, batch_sizes, sorted_indices, unsorted_indices = input
max_batch_size = batch_sizes[0]
max_batch_size = int(max_batch_size)
else:
batch_sizes = None
max_batch_size = input.size(0) if self.batch_first else input.size(1)
sorted_indices = None
unsorted_indices = None
if hx is None:
num_directions = 2 if self.bidirectional else 1
zeros = torch.zeros(self.num_layers * num_directions,
max_batch_size, self.hidden_size,
dtype=input.dtype, device=input.device)
hx = (zeros, zeros)
else:
# Each batch of the hidden state should match the input sequence that
# the user believes he/she is passing in.
hx = self.permute_hidden(hx, sorted_indices)
# Add Noise
if self.noises is not None and self.training:
weight = []
for i in range(0, len(self.noises), 2):
weight.append(self._flat_weights[2*i] + self.vn_std * self.noises[i])
weight.append(self._flat_weights[2*i+1] + self.vn_std * self.noises[i+1])
weight.append(self._flat_weights[2*i+2])
weight.append(self._flat_weights[2*i+3])
else:
weight = self._flat_weights
self.check_forward_args(input, hx, batch_sizes)
if batch_sizes is None:
result = _VF.lstm(input, hx, weight, self.bias, self.num_layers,
self.dropout, self.training, self.bidirectional, self.batch_first)
else:
result = _VF.lstm(input, batch_sizes, hx, weight, self.bias,
self.num_layers, self.dropout, self.training, self.bidirectional)
output = result[0]
hidden = result[1:]
# xxx: isinstance check needs to be in conditional for TorchScript to compile
if isinstance(orig_input, nn.utils.rnn.PackedSequence):
output_packed = nn.utils.rnn.PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:
return output, self.permute_hidden(hidden, unsorted_indices)
class Embedding(nn.Embedding):
def __init__(self, num_embeddings, embedding_dim, padding_idx = None):
super(Embedding, self).__init__(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
padding_idx=padding_idx)
# Variational Noise
self.noise = None
self.vn_std = None
def init_vn(self, vn_std):
# Variational Noise
self.vn_std = vn_std
def sample_synaptic_noise(self, distributed):
# Sample Noise
self.noise = torch.normal(mean=0.0, std=1.0, size=self.weight.size(), device=self.weight.device, dtype=self.weight.dtype)
# Broadcast Noise
if distributed:
torch.distributed.broadcast(self.noise, 0)
def forward(self, input):
# Weight
weight = self.weight
# Add Noise
if self.noise is not None and self.training:
weight = weight + self.vn_std * self.noise
# Apply Weight
return F.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
class IdentityProjection(nn.Module):
def __init__(self, input_dim, output_dim):
super(IdentityProjection, self).__init__()
assert output_dim > input_dim
self.linear = Linear(input_dim, output_dim - input_dim)
def forward(self, x):
# (B, T, Dout - Din)
proj = self.linear(x)
# (B, T, Dout)
x = torch.cat([x, proj], dim=-1)
return x
class DepthwiseSeparableConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
super(DepthwiseSeparableConv1d, self).__init__()
# Layers
self.layers = nn.Sequential(
Conv1d(in_channels, in_channels, kernel_size, padding=padding, groups=in_channels, stride=stride),
Conv1d(in_channels, out_channels, kernel_size=1),
nn.BatchNorm1d(out_channels),
Swish()
)
def forward(self, x):
return self.layers(x)
class Transpose(nn.Module):
def __init__(self, dim0, dim1):
super(Transpose, self).__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x):
return x.transpose(self.dim0, self.dim1)
\ No newline at end of file
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
# Base Model
from models.model import Model
# Decoder
from models.decoders import (
RnnDecoder,
TransformerDecoder
)
# Losses
from models.losses import (
LossCE
)
class LanguageModel(Model):
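"""
Language Model
RNN or Transformer decoder followed by a linear vocabulary projection,
trained with cross entropy to predict the next token
"""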
def __init__(self, lm_params, tokenizer_params, training_params, decoding_params, name):
super(LanguageModel, self).__init__(tokenizer_params, training_params, decoding_params, name)
# Language Model
if lm_params["arch"] == "RNN":
self.decoder = RnnDecoder(lm_params)
elif lm_params["arch"] == "Transformer":
self.decoder = TransformerDecoder(lm_params)
else:
raise Exception("Unknown model architecture:", lm_params["arch"])
# FC Layer
self.fc = nn.Linear(lm_params["dim_model"], tokenizer_params["vocab_size"])
# Criterion
self.criterion = LossCE()
# Compile
self.compile(training_params)
def decode(self, x, hidden):
# Text Decoder (1, 1) -> (1, 1, Dlm)
logits, hidden = self.decoder(x, hidden)
# FC Layer (1, 1, Dlm) -> (1, 1, V)
logits = self.fc(logits)
return logits, hidden
def forward(self, batch):
# Unpack Batch
x, x_len, y = batch
# Add blank token
x = torch.nn.functional.pad(x, pad=(1, 0, 0, 0), value=0)
if x_len is not None:
x_len = x_len + 1
# Text Decoder (B, U + 1) -> (B, U + 1, Dlm)
logits, _ = self.decoder(x, None, x_len)
# FC Layer (B, U + 1, Dlm) -> (B, U + 1, V)
logits = self.fc(logits)
return logits
def gready_search_decoding(self, x, x_len):
return [""]
\ No newline at end of file
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
# RNN-T Loss
import warp_rnnt
class LossRNNT(nn.Module):
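"""
RNN-T loss computed with warp_rnnt (blank index 0, mean reduction over the batch)
"""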
def __init__(self):
super(LossRNNT, self).__init__()
def forward(self, batch, pred):
# Unpack Batch
x, y, x_len, y_len = batch
# Unpack Predictions
outputs_pred, f_len, _ = pred
# Compute Loss
loss = warp_rnnt.rnnt_loss(
log_probs=torch.nn.functional.log_softmax(outputs_pred, dim=-1),
labels=y.int(),
frames_lengths=f_len.int(),
labels_lengths=y_len.int(),
average_frames=False,
reduction='mean',
blank=0,
gather=True)
return loss
class LossCTC(nn.Module):
def __init__(self):
super(LossCTC, self).__init__()
# CTC Loss
self.loss = nn.CTCLoss(blank=0, reduction="none", zero_infinity=False)
def forward(self, batch, pred):
# Unpack Batch
x, y, x_len, y_len = batch
# Unpack Predictions
outputs_pred, f_len, _ = pred
# Compute Loss
loss = self.loss(
log_probs=torch.nn.functional.log_softmax(outputs_pred, dim=-1).transpose(0, 1),
targets=y,
input_lengths=f_len,
target_lengths=y_len).mean()
return loss
class LossInterCTC(nn.Module):
def __init__(self, interctc_lambda):
super(LossInterCTC, self).__init__()
# CTC Loss
self.loss = nn.CTCLoss(blank=0, reduction="none", zero_infinity=False)
# InterCTC Lambda
self.interctc_lambda = interctc_lambda
def forward(self, batch, pred):
# Unpack Batch
x, y, x_len, y_len = batch
# Unpack Predictions
outputs_pred, f_len, _, interctc_probs = pred
# Compute CTC Loss
loss_ctc = self.loss(
log_probs=torch.nn.functional.log_softmax(outputs_pred, dim=-1).transpose(0, 1),
targets=y,
input_lengths=f_len,
target_lengths=y_len)
# Compute Inter Loss
loss_inter = sum(self.loss(
log_probs=interctc_prob.log().transpose(0, 1),
targets=y,
input_lengths=f_len,
target_lengths=y_len) for interctc_prob in interctc_probs) / len(interctc_probs)
# Compute total Loss
loss = (1 - self.interctc_lambda) * loss_ctc + self.interctc_lambda * loss_inter
loss = loss.mean()
return loss
class LossCE(nn.Module):
def __init__(self):
super(LossCE, self).__init__()
# CE Loss
self.loss = nn.CrossEntropyLoss(ignore_index=-1, reduction='mean')
def forward(self, batch, pred):
# Unpack Batch
x, x_len, y = batch
# Unpack Predictions
outputs_pred = pred
# Compute Loss
loss = self.loss(
input=outputs_pred.transpose(1, 2),
target=y)
return loss
\ No newline at end of file
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
# Sentencepiece
import sentencepiece as spm
# Schedulers
from models.schedules import *
# Other
from tqdm import tqdm
import jiwer
import os
import time
def sample_synaptic_noise(m, distributed):
if hasattr(m, "sample_synaptic_noise"):
m.sample_synaptic_noise(distributed)
def init_vn(m, vn_std):
if hasattr(m, "init_vn"):
m.init_vn(vn_std)
class Model(nn.Module):
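"""
Base Model
Wraps the SentencePiece tokenizer and provides optimizer / scheduler creation (compile),
training (fit), evaluation (evaluate), checkpointing (save / load),
stochastic weight averaging (swa) and inference timing utilities
"""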
def __init__(self, tokenizer_params, training_params, decoding_params, name):
super(Model, self).__init__()
# Tokenizer
try:
self.tokenizer = spm.SentencePieceProcessor(tokenizer_params["tokenizer_path"])
except:
self.tokenizer = None
print("Tokenizer not found...")
# Training Params
self.encoder_frozen_steps = training_params.get("encoder_frozen_steps", None)
self.vn_start_step = training_params.get("vn_start_step", None)
# Decoding Params
self.beam_size = decoding_params.get("beam_size", 1)
self.tmp = decoding_params.get("tmp", 1)
# Ngram
self.ngram_path = decoding_params.get("ngram_path", None)
self.ngram_alpha = decoding_params.get("ngram_alpha", 0)
self.ngram_beta = decoding_params.get("ngram_beta", 0)
self.ngram_offset = decoding_params.get("ngram_offset", 100)
# LM
self.lm = None
self.lm_weight = decoding_params.get("lm_weight", 0)
self.lm_tmp = decoding_params.get("lm_tmp", 1)
# Distributed Computing
self.is_distributed = False
self.rank = 0
self.is_parallel = False
# Model Name
self.name = name
def compile(self, training_params):
# Optimizers
if training_params["optimizer"] == "Adam":
# Adam
self.optimizer = optim.Adam(
params=self.parameters(),
lr=0,
betas=(training_params["beta1"], training_params["beta2"]),
eps=training_params["eps"],
weight_decay=training_params["weight_decay"])
elif training_params["optimizer"] == "SGD":
# SGD
self.optimizer = optim.SGD(
params=self.parameters(),
lr=0,
momentum=training_params["momentum"],
weight_decay=training_params["weight_decay"])
# LR Schedulers
if training_params["lr_schedule"] == "Constant":
# Constant LR
self.scheduler = constant_learning_rate_scheduler(
optimizer=self.optimizer,
lr_value=training_params["lr_value"])
elif training_params["lr_schedule"] == "ConstantWithDecay":
# Constant With Decay LR
self.scheduler = constant_with_decay_learning_rate_scheduler(
optimizer=self.optimizer,
lr_values=training_params["lr_values"],
decay_steps=training_params["decay_steps"])
elif training_params["lr_schedule"] == "Transformer":
# Transformer LR
self.scheduler = transformer_learning_rate_scheduler(
optimizer=self.optimizer,
dim_model=training_params["schedule_dim"],
warmup_steps=training_params["warmup_steps"],
K=training_params["K"])
elif training_params["lr_schedule"] == "ExpDecayTransformer":
# Exp Decay Transformer LR
self.scheduler = exponential_decay_transformer_learning_rate_scheduler(
optimizer=self.optimizer,
warmup_steps=training_params["warmup_steps"],
lr_max=training_params["lr_max"] if training_params.get("lr_max", None) else training_params["K"] * training_params["schedule_dim"]**-0.5 * training_params["warmup_steps"]**-0.5,
alpha=training_params["alpha"],
end_step=training_params["end_step"])
elif training_params["lr_schedule"] == "Cosine":
# Cosine Annealing LR
self.scheduler = cosine_annealing_learning_rate_scheduler(
optimizer=self.optimizer,
warmup_steps=training_params["warmup_steps"],
lr_max=training_params["lr_max"] if training_params.get("lr_max", None) else training_params["K"] * training_params["schedule_dim"]**-0.5 * training_params["warmup_steps"]**-0.5,
lr_min= training_params["lr_min"],
end_step=training_params["end_step"])
# Init LR
self.scheduler.step()
def num_params(self):
return sum([p.numel() for p in self.parameters()])
def summary(self, show_dict=False):
print(self.name)
print("Model Parameters :", self.num_params())
if show_dict:
for key, value in self.state_dict().items():
print("{:<64} {:<16} mean {:<16.4f} std {:<16.4f}".format(key, str(tuple(value.size())), value.float().mean(), value.float().std()))
def distribute_strategy(self, rank):
self.rank = rank
self.is_distributed = True
def parallel_strategy(self):
self.is_parallel = True
def fit(self, dataset_train, epochs, dataset_val=None, val_steps=None, verbose_val=False, initial_epoch=0, callback_path=None, steps_per_epoch=None, mixed_precision=False, accumulated_steps=1, saving_period=1, val_period=1):
# Model Device
device = next(self.parameters()).device
# Mixed Precision Gradient Scaler
scaler = torch.cuda.amp.GradScaler(enabled=mixed_precision)
# Init Training
acc_step = 0
self.optimizer.zero_grad()
# Callbacks
if self.rank == 0 and callback_path is not None:
# Create Callbacks
if not os.path.isdir(callback_path):
os.makedirs(callback_path)
# Create Writer
writer = SummaryWriter(callback_path + "logs")
else:
writer = None
# Sample Synaptic Noise
if self.vn_start_step is not None:
if self.scheduler.model_step >= self.vn_start_step:
self.decoder.apply(lambda m: sample_synaptic_noise(m, self.is_distributed))
# Try Catch
try:
# Training Loop
for epoch in range(initial_epoch, epochs):
# Sync sampler if distributed
if self.is_distributed:
dataset_train.sampler.set_epoch(epoch)
# Epoch Init
if self.rank == 0:
print("Epoch {}/{}".format(epoch + 1, epochs))
epoch_iterator = tqdm(dataset_train, total=steps_per_epoch * accumulated_steps if steps_per_epoch else None)
else:
epoch_iterator = dataset_train
epoch_loss = 0.0
# Training Mode
self.train()
# Epoch training
for step, batch in enumerate(epoch_iterator):
# Load batch to model device
batch = [elt.to(device) for elt in batch]
# Encoder Frozen Steps
if self.encoder_frozen_steps:
if self.scheduler.model_step > self.encoder_frozen_steps:
self.encoder.requires_grad_(True)
else:
self.encoder.requires_grad_(False)
# Automatic Mixed Precision Casting (model prediction + loss computing)
with torch.cuda.amp.autocast(enabled=mixed_precision):
pred = self.forward(batch)
loss_mini = self.criterion(batch, pred)
loss = loss_mini / accumulated_steps
# Accumulate gradients
scaler.scale(loss).backward()
# Update Epoch Variables
acc_step += 1
epoch_loss += loss_mini.detach()
# Continue Accumulating
if acc_step < accumulated_steps:
continue
# Update Parameters, Zero Gradients and Update Learning Rate
scaler.step(self.optimizer)
scaler.update()
self.optimizer.zero_grad()
self.scheduler.step()
acc_step = 0
# Sample Synaptic Noise
if self.vn_start_step is not None:
if self.scheduler.model_step >= self.vn_start_step:
self.decoder.apply(lambda m: sample_synaptic_noise(m, self.is_distributed))
# Step Print
if self.rank == 0:
epoch_iterator.set_description("model step: {} - mean loss {:.4f} - batch loss: {:.4f} - learning rate: {:.6f}".format(self.scheduler.model_step, epoch_loss / (step + 1), loss_mini, self.optimizer.param_groups[0]['lr']))
# Logs Step
if self.rank == 0 and writer is not None and (step + 1) % 10 == 0:
writer.add_scalar('Training/Loss', loss_mini, self.scheduler.model_step)
writer.add_scalar('Training/LearningRate', self.optimizer.param_groups[0]['lr'], self.scheduler.model_step)
# Step per Epoch
if steps_per_epoch is not None:
if step + 1 >= steps_per_epoch * accumulated_steps:
break
# Reduce Epoch Loss among devices
if self.is_distributed:
torch.distributed.barrier()
torch.distributed.all_reduce(epoch_loss)
epoch_loss /= torch.distributed.get_world_size()
# Logs Epoch
if self.rank == 0 and writer is not None:
writer.add_scalar('Training/MeanLoss', epoch_loss / (steps_per_epoch * accumulated_steps if steps_per_epoch is not None else dataset_train.__len__()), epoch + 1)
# Validation
if (epoch + 1) % val_period == 0:
# Validation Dataset
if dataset_val:
# Multiple Validation Datasets
if isinstance(dataset_val, dict):
for dataset_name, dataset in dataset_val.items():
# Evaluate
wer, truths, preds, val_loss = self.evaluate(dataset, val_steps, verbose_val, eval_loss=True)
# Print wer
if self.rank == 0:
print("{} wer : {:.2f}% - loss : {:.4f}".format(dataset_name, 100 * wer, val_loss))
# Logs Validation
if self.rank == 0 and writer is not None:
writer.add_scalar('Validation/WER/{}'.format(dataset_name), 100 * wer, epoch + 1)
writer.add_scalar('Validation/MeanLoss/{}'.format(dataset_name), val_loss, epoch + 1)
writer.add_text('Validation/Predictions/{}'.format(dataset_name), "GroundTruth : " + truths[0] + " / Prediction : " + preds[0], epoch + 1)
else:
# Evaluate
wer, truths, preds, val_loss = self.evaluate(dataset_val, val_steps, verbose_val, eval_loss=True)
# Print wer
if self.rank == 0:
print("Val wer : {:.2f}% - Val loss : {:.4f}".format(100 * wer, val_loss))
# Logs Validation
if self.rank == 0 and writer is not None:
writer.add_scalar('Validation/WER', 100 * wer, epoch + 1)
writer.add_scalar('Validation/MeanLoss', val_loss, epoch + 1)
writer.add_text('Validation/Predictions', "GroundTruth : " + truths[0] + " / Prediction : " + preds[0], epoch + 1)
# Saving Checkpoint
if (epoch + 1) % saving_period == 0:
if callback_path and self.rank == 0:
self.save(callback_path + "checkpoints_" + str(epoch + 1) + ".ckpt")
# Exception Handler
except Exception as e:
if self.is_distributed:
torch.distributed.destroy_process_group()
if self.rank == 0 and writer is not None:
writer.add_text('Exceptions', str(e))
raise e
def save(self, path, save_optimizer=True):
# Save Model Checkpoint
torch.save({
"model_state_dict": self.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict() if save_optimizer else None,
"model_step": self.scheduler.model_step,
"tokenizer": self.tokenizer,
"is_distributed": self.is_distributed or self.is_parallel
}, path)
# Print Model state
if self.rank == 0:
print("model saved at step {} / lr {:.6f}".format(self.scheduler.model_step, self.optimizer.param_groups[0]['lr']))
def load(self, path):
# Load Model Checkpoint
checkpoint = torch.load(path, map_location=next(self.parameters()).device)
# Model State Dict
if checkpoint["is_distributed"] and not self.is_distributed:
self.load_state_dict({key.replace(".module.", "."):value for key, value in checkpoint["model_state_dict"].items()})
else:
self.load_state_dict({key:value for key, value in checkpoint["model_state_dict"].items()})
# Model Step
self.scheduler.model_step = checkpoint["model_step"]
# Optimizer State Dict
if checkpoint["optimizer_state_dict"] is not None:
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
# Tokenizer
self.tokenizer = checkpoint["tokenizer"]
# Print Model state
if self.rank == 0:
print("model loaded at step {} / lr {:.6f}".format(self.scheduler.model_step, self.optimizer.param_groups[0]['lr']))
def evaluate(self, dataset_eval, eval_steps=None, verbose=False, beam_size=1, eval_loss=True):
# Evaluation Mode
self.eval()
# Model Device
device = next(self.parameters()).device
# Groundtruth / Prediction string lists
speech_true = []
speech_pred = []
# Total wer / loss
total_wer = 0.0
total_loss = 0.0
# tqdm Iterator
if self.rank == 0:
eval_iterator = tqdm(dataset_eval, total=eval_steps)
else:
eval_iterator = dataset_eval
# Evaluation Loop
for step, batch in enumerate(eval_iterator):
batch = [elt.to(device) for elt in batch]
# Sequence Prediction
with torch.no_grad():
if beam_size > 1:
outputs_pred = self.beam_search_decoding(batch[0], batch[2], beam_size)
else:
outputs_pred = self.gready_search_decoding(batch[0], batch[2])
# Sequence Truth
outputs_true = self.tokenizer.decode(batch[1].tolist())
# Compute Batch wer and Update total wer
batch_wer = jiwer.wer(outputs_true, outputs_pred, standardize=True)
total_wer += batch_wer
# Update String lists
speech_true += outputs_true
speech_pred += outputs_pred
# Prediction Verbose
if verbose:
print("Groundtruths :\n", outputs_true)
print("Predictions :\n", outputs_pred)
# Eval Loss
if eval_loss:
with torch.no_grad():
pred = self.forward(batch)
batch_loss = self.criterion(batch, pred)
total_loss += batch_loss
# Step print
if self.rank == 0:
if eval_loss:
eval_iterator.set_description("mean batch wer {:.2f}% - batch wer: {:.2f}% - mean loss {:.4f} - batch loss: {:.4f}".format(100 * total_wer / (step + 1), 100 * batch_wer, total_loss / (step + 1), batch_loss))
else:
eval_iterator.set_description("mean batch wer {:.2f}% - batch wer: {:.2f}%".format(100 * total_wer / (step + 1), 100 * batch_wer))
# Evaluation Steps
if eval_steps:
if step + 1 >= eval_steps:
break
# Reduce wer among devices
if self.is_distributed:
# Process Barrier
torch.distributed.barrier()
# All Gather Speech Truths and Predictions
speech_true_gather = [None for _ in range(torch.distributed.get_world_size())]
speech_pred_gather = [None for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather_object(speech_true_gather, speech_true)
torch.distributed.all_gather_object(speech_pred_gather, speech_pred)
speech_true = []
speech_pred = []
for truth in speech_true_gather:
speech_true += truth
for pred in speech_pred_gather:
speech_pred += pred
# All Reduce Total loss
if eval_loss:
torch.distributed.all_reduce(total_loss)
total_loss /= torch.distributed.get_world_size()
# Compute wer
if total_wer / (eval_steps if eval_steps is not None else dataset_eval.__len__()) > 1:
wer = 1
else:
wer = jiwer.wer(speech_true, speech_pred, standardize=True)
# Compute loss
if eval_loss:
loss = total_loss / (eval_steps if eval_steps is not None else dataset_eval.__len__())
# Return word error rate, groundtruths and predictions
return wer, speech_true, speech_pred, loss if eval_loss else None
def swa(self, dataset, callback_path, start_epoch, end_epoch, epochs_list=None, update_steps=None, swa_type="equal", swa_decay=0.9):
# Model device
device = next(self.parameters()).device
# Create SWA Model
if swa_type == "equal":
swa_model = torch.optim.swa_utils.AveragedModel(self)
elif swa_type == "exp":
swa_model = torch.optim.swa_utils.AveragedModel(self, avg_fn=lambda averaged_model_parameter, model_parameter, num_averaged: (1 - swa_decay) * averaged_model_parameter + swa_decay * model_parameter)
if self.rank == 0:
if epochs_list:
print("Stochastic Weight Averaging on checkpoints : {}".format(epochs_list))
else:
print("Stochastic Weight Averaging on checkpoints : {}-{}".format(start_epoch, end_epoch))
# Update SWA Model Params
if epochs_list:
for epoch in epochs_list:
# Load Model Checkpoint
self.load(callback_path + "checkpoints_" + str(epoch) + ".ckpt")
# Update SWA Model
swa_model.update_parameters(self)
else:
for epoch in range(int(start_epoch), int(end_epoch) + 1):
# Load Model Checkpoint
self.load(callback_path + "checkpoints_" + str(epoch) + ".ckpt")
# Update SWA Model
swa_model.update_parameters(self)
# Load SWA Model Params
self.load_state_dict({key[7:]:value for key, value in swa_model.state_dict().items() if key != "n_averaged"})
if self.rank == 0:
print("Updating Batch Normalization Statistics")
# Init
self.train()
if self.rank == 0:
dataset_iterator = tqdm(dataset, total=update_steps)
else:
dataset_iterator = dataset
# Update Batch Normalization Statistics
for step, batch in enumerate(dataset_iterator):
# Load batch to model device
batch = [elt.to(device) for elt in batch]
# Forward Encoder
with torch.cuda.amp.autocast(enabled=True):
with torch.no_grad():
self.encoder.forward(batch[0], batch[2])
# update_steps
if update_steps is not None:
if step + 1 == update_steps:
break
# Save Model
if self.rank == 0:
if epochs_list:
self.save(callback_path + "checkpoints_swa-" + swa_type + "-" + "list" + "-" + epochs_list[0] + "-" + epochs_list[-1] + ".ckpt", save_optimizer=False)
else:
self.save(callback_path + "checkpoints_swa-" + swa_type + "-" + start_epoch + "-" + end_epoch + ".ckpt", save_optimizer=False)
# Barrier
if self.is_distributed:
torch.distributed.barrier()
def eval_time(self, dataset_eval, eval_steps=None, beam_size=1, rnnt_max_consec_dec_steps=None, profiler=False):
def decode():
# Start Timer
start = time.time()
# Evaluation Loop
for step, batch in enumerate(eval_iterator):
batch = [elt.to(device) for elt in batch]
# Sequence Prediction
with torch.no_grad():
if beam_size > 1:
outputs_pred = self.beam_search_decoding(batch[0], batch[2], beam_size)
else:
if rnnt_max_consec_dec_steps is not None:
outputs_pred = self.gready_search_decoding(batch[0], batch[2], rnnt_max_consec_dec_steps)
else:
outputs_pred = self.gready_search_decoding(batch[0], batch[2])
# Evaluation Steps
if eval_steps:
if step + 1 >= eval_steps:
break
# Stop Timer
return time.time() - start
# Model Device
device = next(self.parameters()).device
# Evaluation Mode
self.eval()
# tqdm Iterator
if self.rank == 0:
eval_iterator = tqdm(dataset_eval, total=eval_steps)
else:
eval_iterator = dataset_eval
# Decoding
if profiler:
with torch.autograd.profiler.profile(profile_memory=True) as prof:
with torch.autograd.profiler.record_function("Model Inference"):
timer = decode()
else:
timer = decode()
# Profiler Print
if profiler:
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
# Return Eval Time in s
return timer
def eval_time_encoder(self, dataset_eval, eval_steps=None, profiler=False):
def forward():
# Start Timer
start = time.time()
for step, batch in enumerate(eval_iterator):
batch = [elt.to(device) for elt in batch]
with torch.no_grad():
x, x_len, att = self.encoder.forward(batch[0], batch[2])
# Evaluation Steps
if eval_steps:
if step + 1 >= eval_steps:
break
# Stop Timer
return time.time() - start
# Model Device
device = next(self.parameters()).device
# Evaluation Mode
self.eval()
# tqdm Iterator
if self.rank == 0:
eval_iterator = tqdm(dataset_eval, total=eval_steps)
else:
eval_iterator = dataset_eval
# Forward
if profiler:
with torch.autograd.profiler.profile(profile_memory=True) as prof:
with torch.autograd.profiler.record_function("Model Inference"):
timer = forward()
else:
timer = forward()
# Profiler Print
if profiler:
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
# Return Eval Time in s
return timer
def eval_time_decoder(self, dataset_eval, eval_steps=None, profiler=False):
def forward():
# Start Timer
start = time.time()
for step, batch in enumerate(eval_iterator):
batch = [elt.to(device) for elt in batch]
hidden = None
for i in range(batch[1].size(1)):
with torch.no_grad():
_, hidden = self.decoder.forward(batch[1][:, i:i+1], hidden)
# Evaluation Steps
if eval_steps:
if step + 1 >= eval_steps:
break
# Stop Timer
return time.time() - start
# Model Device
device = next(self.parameters()).device
# Evaluation Mode
self.eval()
# tqdm Iterator
if self.rank == 0:
eval_iterator = tqdm(dataset_eval, total=eval_steps)
else:
eval_iterator = dataset_eval
# Forward
if profiler:
with torch.autograd.profiler.profile(profile_memory=True) as prof:
with torch.autograd.profiler.record_function("Model Inference"):
timer = forward()
else:
timer = forward()
# Profiler Print
if profiler:
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
# Return Eval Time in s
return timer
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
# Base Model
from models.model import Model
# Encoders
from models.encoders import (
ConformerEncoder,
ConformerEncoderInterCTC
)
# Losses
from models.losses import (
LossCTC,
LossInterCTC
)
# CTC Decode Beam Search
from ctcdecode import CTCBeamDecoder
class ModelCTC(Model):
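"""
CTC Model
Conformer encoder followed by a linear vocabulary projection, trained with the CTC loss
Supports greedy decoding and beam search decoding with an optional N-gram language model (ctcdecode)
"""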
def __init__(self, encoder_params, tokenizer_params, training_params, decoding_params, name):
super(ModelCTC, self).__init__(tokenizer_params, training_params, decoding_params, name)
# Encoder
if encoder_params["arch"] == "Conformer":
self.encoder = ConformerEncoder(encoder_params)
else:
raise Exception("Unknown encoder architecture:", encoder_params["arch"])
# FC Layer
self.fc = nn.Linear(encoder_params["dim_model"][-1] if isinstance(encoder_params["dim_model"], list) else encoder_params["dim_model"], tokenizer_params["vocab_size"])
# Criterion
self.criterion = LossCTC()
# Compile
self.compile(training_params)
def forward(self, batch):
# Unpack Batch
x, _, x_len, _ = batch
# Forward Encoder (B, Taud) -> (B, T, Denc)
logits, logits_len, attentions = self.encoder(x, x_len)
# FC Layer (B, T, Denc) -> (B, T, V)
logits = self.fc(logits)
return logits, logits_len, attentions
def distribute_strategy(self, rank):
super(ModelCTC, self).distribute_strategy(rank)
self.encoder = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.encoder)
self.encoder = torch.nn.parallel.DistributedDataParallel(self.encoder, device_ids=[self.rank])
self.fc = torch.nn.parallel.DistributedDataParallel(self.fc, device_ids=[self.rank])
def load_encoder(self, path):
# Load Encoder Params
checkpoint = torch.load(path, map_location=next(self.parameters()).device)
if checkpoint["is_distributed"] and not self.is_distributed:
self.encoder.load_state_dict({key.replace(".module.", ".").replace("encoder.", ""):value for key, value in checkpoint["model_state_dict"].items() if key[:len("encoder")] == "encoder"})
else:
self.encoder.load_state_dict({key.replace("encoder.", ""):value for key, value in checkpoint["model_state_dict"].items() if key[:len("encoder")] == "encoder"})
# Print Encoder state
if self.rank == 0:
print("Model encoder loaded at step {} from {}".format(checkpoint["model_step"], path))
def gready_search_decoding(self, x, x_len):
# Forward Encoder (B, Taud) -> (B, T, Denc)
logits, logits_len = self.encoder(x, x_len)[:2]
# FC Layer (B, T, Denc) -> (B, T, V)
logits = self.fc(logits)
# Softmax -> Log -> Argmax -> (B, T)
preds = logits.log_softmax(dim=-1).argmax(dim=-1)
# Batch Pred List
batch_pred_list = []
# Batch loop
for b in range(logits.size(0)):
# Blank
blank = False
# Pred List
pred_list = []
# Decoding Loop
for t in range(logits_len[b]):
# Blank Prediction
if preds[b, t] == 0:
blank = True
continue
# First Prediction
if len(pred_list) == 0:
pred_list.append(preds[b, t].item())
# New Prediction
elif pred_list[-1] != preds[b, t] or blank:
pred_list.append(preds[b, t].item())
# Update Blank
blank = False
# Append Sequence
batch_pred_list.append(pred_list)
# Decode Sequences
return self.tokenizer.decode(batch_pred_list)
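# Worked example (illustrative): with blank id 0, the frame-level argmax sequence
#   [0, 3, 3, 0, 3, 5, 5, 0]
# collapses to
#   [3, 3, 5]
# since repeated labels are merged unless separated by a blank, and blanks are dropped.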
def beam_search_decoding(self, x, x_len, beam_size=None):
# Overwrite beam size
if beam_size is None:
beam_size = self.beam_size
# Beam Search Decoder
decoder = CTCBeamDecoder(
[chr(idx + self.ngram_offset) for idx in range(self.tokenizer.vocab_size())],
model_path=self.ngram_path,
alpha=self.ngram_alpha,
beta=self.ngram_beta,
cutoff_top_n=self.tokenizer.vocab_size(),
cutoff_prob=1.0,
beam_width=beam_size,
num_processes=8,
blank_id=0,
log_probs_input=True
)
# Forward Encoder (B, Taud) -> (B, T, Denc)
logits, logits_len = self.encoder(x, x_len)[:2]
# FC Layer (B, T, Denc) -> (B, T, V)
logits = self.fc(logits)
# Apply Temperature
logits = logits / self.tmp
# Softmax -> Log
logP = logits.softmax(dim=-1).log()
# Beam Search Decoding
beam_results, beam_scores, timesteps, out_lens = decoder.decode(logP, logits_len)
# Batch Pred List
batch_pred_list = []
# Batch loop
for b in range(logits.size(0)):
batch_pred_list.append(beam_results[b][0][:out_lens[b][0]].tolist())
# Decode Sequences
return self.tokenizer.decode(batch_pred_list)
class InterCTC(ModelCTC):
def __init__(self, encoder_params, tokenizer_params, training_params, decoding_params, name):
super(ModelCTC, self).__init__(tokenizer_params, training_params, decoding_params, name)
# Update Encoder Params
encoder_params["vocab_size"] = tokenizer_params["vocab_size"]
# Encoder
if encoder_params["arch"] == "Conformer":
self.encoder = ConformerEncoderInterCTC(encoder_params)
else:
raise Exception("Unknown encoder architecture:", encoder_params["arch"])
# FC Layer
self.fc = nn.Linear(encoder_params["dim_model"][-1] if isinstance(encoder_params["dim_model"], list) else encoder_params["dim_model"], tokenizer_params["vocab_size"])
# Criterion
self.criterion = LossInterCTC(training_params["interctc_lambda"])
# Compile
self.compile(training_params)
def forward(self, batch):
# Unpack Batch
x, _, x_len, _ = batch
# Forward Encoder (B, Taud) -> (B, T, Denc)
logits, logits_len, attentions, interctc_logits = self.encoder(x, x_len)
# FC Layer (B, T, Denc) -> (B, T, V)
logits = self.fc(logits)
return logits, logits_len, attentions, interctc_logits
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
# Base Model
from models.model import Model
# Encoders
from models.encoders import (
ConformerEncoder
)
# Decoders
from models.decoders import (
ConformerCrossDecoder,
TransformerCrossDecoder
)
# Losses
from models.losses import (
LossCE
)
# Ngram
import kenlm
class ModelS2S(Model):
def __init__(self, encoder_params, decoder_params, tokenizer_params, training_params, decoding_params, name):
super(ModelS2S, self).__init__(tokenizer_params, training_params, decoding_params, name)
# Not Implemented
raise Exception("Sequence-to-sequence model not implemented")
# Encoder
if encoder_params["arch"] == "Conformer":
self.encoder = ConformerEncoder(encoder_params)
else:
raise Exception("Unknown encoder architecture:", encoder_params["arch"])
# Decoder
if decoder_params["arch"] == "Conformer":
self.decoder = ConformerCrossDecoder(decoder_params)
elif decoder_params["arch"] == "Transformer":
self.decoder = TransformerCrossDecoder(decoder_params)
else:
raise Exception("Unknown decoder architecture:", decoder_params["arch"])
# Joint Network
self.fc = nn.Linear(encoder_params["dim_model"][-1] if isinstance(encoder_params["dim_model"], list) else encoder_params["dim_model"], tokenizer_params["vocab_size"])
# Criterion
self.criterion = LossCE()
# Compile
self.compile(training_params)
def forward(self, batch):
# Unpack Batch
x, y, _ = batch
# Audio Encoder (B, Taud) -> (B, T, Denc)
x, _, attentions = self.encoder(x, None)
# Add blank token
y = torch.nn.functional.pad(y, pad=(1, 0, 0, 0), value=0)
# Text Decoder (B, U + 1) -> (B, U + 1, Ddec)
y = self.decoder(x, y)
# FC Layer (B, T, Ddec) -> (B, T, V)
logits = self.fc(y)
return logits, attentions
def distribute_strategy(self, rank):
super(ModelS2S, self).distribute_strategy(rank)
self.encoder = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.encoder)
self.encoder = torch.nn.parallel.DistributedDataParallel(self.encoder, device_ids=[self.rank])
self.decoder = torch.nn.parallel.DistributedDataParallel(self.decoder, device_ids=[self.rank])
self.fc = torch.nn.parallel.DistributedDataParallel(self.fc, device_ids=[self.rank])
def parallel_strategy(self):
super(ModelS2S, self).parallel_strategy()
self.encoder = torch.nn.DataParallel(self.encoder)
self.decoder = torch.nn.DataParallel(self.decoder)
self.fc = torch.nn.DataParallel(self.fc)
def summary(self, show_dict=False):
print(self.name)
print("Model Parameters :", self.num_params() - self.lm.num_params() if self.lm else self.num_params())
print(" - Encoder Parameters :", sum([p.numel() for p in self.encoder.parameters()]))
print(" - Decoder Parameters :", sum([p.numel() for p in self.decoder.parameters()]))
print(" - Joint Parameters :", sum([p.numel() for p in self.joint_network.parameters()]))
if self.lm:
print("LM Parameters :", self.lm.num_params())
if show_dict:
for key, value in self.state_dict().items():
print("{:<64} {:<16} mean {:<16.4f} std {:<16.4f}".format(key, str(tuple(value.size())), value.float().mean(), value.float().std()))
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
import torchaudio
# Attentions
from models.attentions import (
# Abs Attentions
MultiHeadAttention,
GroupedMultiHeadAttention,
LocalMultiHeadAttention,
StridedMultiHeadAttention,
StridedLocalMultiHeadAttention,
MultiHeadLinearAttention,
# Rel Attentions
RelPosMultiHeadSelfAttention,
GroupedRelPosMultiHeadSelfAttention,
LocalRelPosMultiHeadSelfAttention,
StridedRelPosMultiHeadSelfAttention,
StridedLocalRelPosMultiHeadSelfAttention
)
# Layers
from models.layers import (
Linear,
Conv1d,
Transpose,
DepthwiseSeparableConv1d
)
# Activations
from models.activations import (
Swish,
Glu
)
###############################################################################
# Audio Preprocessing
###############################################################################
class AudioPreprocessing(nn.Module):
"""Audio Preprocessing
Computes mel-scale log filter banks spectrogram
Args:
sample_rate: Audio sample rate
n_fft: FFT frame size, creates n_fft // 2 + 1 frequency bins.
win_length_ms: FFT window length in ms; the derived window length in samples must be <= n_fft
hop_length_ms: length of hop between FFT windows in ms
n_mels: number of mel filter banks
normalize: whether to normalize mel spectrograms outputs
mean: training mean
std: training std
Shape:
Input: (batch_size, audio_len)
Output: (batch_size, n_mels, audio_len // hop_length + 1)
"""
def __init__(self, sample_rate, n_fft, win_length_ms, hop_length_ms, n_mels, normalize, mean, std):
super(AudioPreprocessing, self).__init__()
self.win_length = int(sample_rate * win_length_ms) // 1000
self.hop_length = int(sample_rate * hop_length_ms) // 1000
self.Spectrogram = torchaudio.transforms.Spectrogram(n_fft, self.win_length, self.hop_length)
self.MelScale = torchaudio.transforms.MelScale(n_mels, sample_rate, f_min=0, f_max=8000, n_stft=n_fft // 2 + 1)
self.normalize = normalize
self.mean = mean
self.std = std
def forward(self, x, x_len):
# Short Time Fourier Transform (B, T) -> (B, n_fft // 2 + 1, T // hop_length + 1)
x = self.Spectrogram(x)
# Mel Scale (B, n_fft // 2 + 1, T // hop_length + 1) -> (B, n_mels, T // hop_length + 1)
x = self.MelScale(x)
# Log mel energies, computed in float32 with a small epsilon to avoid log(0) and float16 overflow
x = (x.float() + 1e-9).log().type(x.dtype)
# Compute Sequence lengths
if x_len is not None:
x_len = x_len // self.hop_length + 1
# Normalize
if self.normalize:
x = (x - self.mean) / self.std
return x, x_len
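# Usage sketch (illustrative only; parameter values below are arbitrary assumptions):
#
#   import torch
#   preprocessing = AudioPreprocessing(sample_rate=16000, n_fft=512, win_length_ms=25,
#                                      hop_length_ms=10, n_mels=80, normalize=False, mean=0.0, std=1.0)
#   waveforms = torch.randn(2, 16000)              # (batch_size, audio_len)
#   lengths = torch.tensor([16000, 12000])
#   mel, mel_len = preprocessing(waveforms, lengths)
#   # mel: (2, 80, 16000 // 160 + 1) = (2, 80, 101), mel_len: [101, 76]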
class SpecAugment(nn.Module):
"""Spectrogram Augmentation
Args:
spec_augment: whether to apply spec augment
mF: number of frequency masks
F: maximum frequency mask size
mT: number of time masks
pS: maximum time mask size as a proportion of the sequence length (adaptive time masking)
References:
SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition, Park et al.
https://arxiv.org/abs/1904.08779
SpecAugment on Large Scale Datasets, Park et al.
https://arxiv.org/abs/1912.05533
"""
def __init__(self, spec_augment, mF, F, mT, pS):
super(SpecAugment, self).__init__()
self.spec_augment = spec_augment
self.mF = mF
self.F = F
self.mT = mT
self.pS = pS
def forward(self, x, x_len):
# Spec Augment
if self.spec_augment:
# Frequency Masking
for _ in range(self.mF):
x = torchaudio.transforms.FrequencyMasking(freq_mask_param=self.F, iid_masks=False).forward(x)
# Time Masking
for b in range(x.size(0)):
T = int(self.pS * x_len[b])
for _ in range(self.mT):
x[b, :, :x_len[b]] = torchaudio.transforms.TimeMasking(time_mask_param=T).forward(x[b, :, :x_len[b]])
return x
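# Usage sketch (illustrative only; mask counts and sizes below are arbitrary):
#
#   import torch
#   augment = SpecAugment(spec_augment=True, mF=2, F=27, mT=5, pS=0.05)
#   mel = torch.randn(2, 80, 300)          # (batch_size, n_mels, time)
#   mel_len = torch.tensor([300, 250])
#   mel = augment(mel, mel_len)            # same shape, with random frequency / time masks applied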
###############################################################################
# Conv Subsampling Modules
###############################################################################
class Conv1dSubsampling(nn.Module):
"""Conv1d Subsampling Block
Args:
num_layers: number of strided convolution layers
in_dim: input feature dimension
filters: list of convolution layers filters
kernel_size: convolution kernel size
norm: normalization
act: activation function
Shape:
Input: (batch_size, in_dim, in_length)
Output: (batch_size, out_dim, out_length)
"""
def __init__(self, num_layers, in_dim, filters, kernel_size, norm, act):
super(Conv1dSubsampling, self).__init__()
# Assert
assert norm in ["batch", "layer", "none"]
assert act in ["relu", "swish", "none"]
# Layers
self.layers = nn.ModuleList([nn.Sequential(
nn.Conv1d(in_dim if layer_id == 0 else filters[layer_id - 1], filters[layer_id], kernel_size, stride=2, padding=(kernel_size - 1) // 2),
nn.BatchNorm1d(filters[layer_id]) if norm == "batch" else nn.LayerNorm(filters[layer_id]) if norm == "layer" else nn.Identity(),
nn.ReLU() if act == "relu" else Swish() if act == "swish" else nn.Identity()
) for layer_id in range(num_layers)])
def forward(self, x, x_len):
# Layers
for layer in self.layers:
x = layer(x)
# Update Sequence Lengths
if x_len is not None:
x_len = (x_len - 1) // 2 + 1
return x, x_len
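# Shape sketch (illustrative only): each stride-2 Conv1d stage halves the time axis; the filter
# count below is arbitrary.
#
#   import torch
#   subsampling = Conv1dSubsampling(num_layers=1, in_dim=80, filters=[144], kernel_size=3, norm="batch", act="relu")
#   x = torch.randn(4, 80, 100)            # (batch_size, in_dim, in_length)
#   x_len = torch.tensor([100, 90, 80, 70])
#   x, x_len = subsampling(x, x_len)       # x: (4, 144, 50), x_len: [50, 45, 40, 35]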
class Conv2dSubsampling(nn.Module):
"""Conv2d Subsampling Block
Args:
num_layers: number of strided convolution layers
filters: list of convolution layers filters
kernel_size: convolution kernel size
norm: normalization
act: activation function
Shape:
Input: (batch_size, in_dim, in_length)
Output: (batch_size, out_dim, out_length)
"""
def __init__(self, num_layers, filters, kernel_size, norm, act):
super(Conv2dSubsampling, self).__init__()
# Assert
assert norm in ["batch", "layer", "none"]
assert act in ["relu", "swish", "none"]
# Conv 2D Subsampling Layers
self.layers = nn.ModuleList([nn.Sequential(
nn.Conv2d(1 if layer_id == 0 else filters[layer_id - 1], filters[layer_id], kernel_size, stride=2, padding=(kernel_size - 1) // 2),
nn.BatchNorm2d(filters[layer_id]) if norm == "batch" else nn.LayerNorm(filters[layer_id]) if norm == "layer" else nn.Identity(),
nn.ReLU() if act == "relu" else Swish() if act == "swish" else nn.Identity()
) for layer_id in range(num_layers)])
def forward(self, x, x_len):
# (B, D, T) -> (B, 1, D, T)
x = x.unsqueeze(dim=1)
# Layers
for layer in self.layers:
x = layer(x)
# Update Sequence Lengths
if x_len is not None:
x_len = (x_len - 1) // 2 + 1
# (B, C, D // S, T // S) -> (B, C * D // S, T // S)
batch_size, channels, subsampled_dim, subsampled_length = x.size()
x = x.reshape(batch_size, channels * subsampled_dim, subsampled_length)
return x, x_len
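# Shape sketch (illustrative only): two stride-2 Conv2d stages reduce both the feature and time
# axes by 4, and the channel axis is folded back into the feature dimension. Filters are arbitrary.
#
#   import torch
#   subsampling = Conv2dSubsampling(num_layers=2, filters=[32, 64], kernel_size=3, norm="batch", act="relu")
#   x = torch.randn(4, 80, 100)            # (batch_size, n_mels, in_length)
#   x, x_len = subsampling(x, torch.tensor([100, 100, 80, 60]))
#   # x: (4, 64 * 20, 25)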
class Conv2dPoolSubsampling(nn.Module):
"""Conv2d with Max Pooling Subsampling Block
Args:
num_layers: number of strided convolution layers
filters: list of convolution layers filters
kernel_size: convolution kernel size
norm: normalization
act: activation function
Shape:
Input: (batch_size, in_dim, in_length)
Output: (batch_size, out_dim, out_length)
"""
def __init__(self, num_layers, filters, kernel_size, norm, act):
super(Conv2dPoolSubsampling, self).__init__()
# Assert
assert norm in ["batch", "layer", "none"]
assert act in ["relu", "swish", "none"]
# Layers
self.layers = nn.ModuleList([nn.Sequential(
nn.Conv2d(1 if layer_id == 0 else filters[layer_id - 1], filters[layer_id], kernel_size, padding=(kernel_size - 1) // 2),
nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
nn.BatchNorm2d(filters[layer_id]) if norm == "batch" else nn.LayerNorm(filters[layer_id]) if norm == "layer" else nn.Identity(),
nn.ReLU() if act == "relu" else Swish() if act == "swish" else nn.Identity()
) for layer_id in range(num_layers)])
def forward(self, x, x_len):
# (B, D, T) -> (B, 1, D, T)
x = x.unsqueeze(dim=1)
# Layers
for layer in self.layers:
x = layer(x)
# Update Sequence Lengths
if x_len is not None:
x_len = (x_len - 1) // 2 + 1
# (B, C, D // S, T // S) -> (B, C * D // S, T // S)
batch_size, channels, subsampled_dim, subsampled_length = x.size()
x = x.reshape(batch_size, channels * subsampled_dim, subsampled_length)
return x, x_len
class VGGSubsampling(nn.Module):
"""VGG style Subsampling Block
Args:
num_layers: number of strided convolution layers
filters: list of convolution layers filters
kernel_size: convolution kernel size
norm: normalization
act: activation function
Shape:
Input: (batch_size, in_dim, in_length)
Output: (batch_size, out_dim, out_length)
"""
def __init__(self, num_layers, filters, kernel_size, norm, act):
super(VGGSubsampling, self).__init__()
# Assert
assert norm in ["batch", "layer", "none"]
assert act in ["relu", "swish", "none"]
self.layers = nn.ModuleList([nn.Sequential(
# Conv 1
nn.Conv2d(1 if layer_id == 0 else filters[layer_id - 1], filters[layer_id], kernel_size, padding=(kernel_size - 1) // 2),
nn.BatchNorm2d(filters[layer_id]) if norm == "batch" else nn.LayerNorm(filters[layer_id]) if norm == "layer" else nn.Identity(),
nn.ReLU() if act == "relu" else Swish() if act == "swish" else nn.Identity(),
# Conv 2
nn.Conv2d(filters[layer_id], filters[layer_id], kernel_size, padding=(kernel_size - 1) // 2),
nn.BatchNorm2d(filters[layer_id]) if norm == "batch" else nn.LayerNorm(filters[layer_id]) if norm == "layer" else nn.Identity(),
nn.ReLU() if act == "relu" else Swish() if act == "swish" else nn.Identity(),
# Pooling
nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
) for layer_id in range(num_layers)])
def forward(self, x, x_len):
# (B, D, T) -> (B, 1, D, T)
x = x.unsqueeze(dim=1)
# Stages
for layer in self.layers:
x = layer(x)
# Update Sequence Lengths
if x_len is not None:
x_len = x_len // 2
# (B, C, D // S, T // S) -> (B, C * D // S, T // S)
batch_size, channels, subsampled_dim, subsampled_length = x.size()
x = x.reshape(batch_size, channels * subsampled_dim, subsampled_length)
return x, x_len
###############################################################################
# Conformer Modules
###############################################################################
class FeedForwardModule(nn.Module):
"""Transformer Feed Forward Module
Args:
dim_model: model feature dimension
dim_ffn: expanded feature dimension
Pdrop: dropout probability
act: inner activation function
inner_dropout: whether to apply dropout after the inner activation function
Input: (batch size, length, dim_model)
Output: (batch size, length, dim_model)
"""
def __init__(self, dim_model, dim_ffn, Pdrop, act, inner_dropout):
super(FeedForwardModule, self).__init__()
# Assert
assert act in ["relu", "swish"]
# Layers
self.layers = nn.Sequential(
nn.LayerNorm(dim_model, eps=1e-6),
Linear(dim_model, dim_ffn),
Swish() if act=="swish" else nn.ReLU(),
nn.Dropout(p=Pdrop) if inner_dropout else nn.Identity(),
Linear(dim_ffn, dim_model),
nn.Dropout(p=Pdrop)
)
def forward(self, x):
return self.layers(x)
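# Usage sketch (illustrative only): the module preserves the model dimension, so it is normally
# combined with a residual connection; the Conformer encoder uses half-step feed-forward residuals.
#
#   import torch
#   ffn = FeedForwardModule(dim_model=256, dim_ffn=1024, Pdrop=0.1, act="swish", inner_dropout=True)
#   x = torch.randn(2, 50, 256)            # (batch_size, length, dim_model)
#   x = x + 0.5 * ffn(x)                   # half-step residual as in the Conformer block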
class MultiHeadSelfAttentionModule(nn.Module):
"""Multi-Head Self-Attention Module
Args:
dim_model: model feature dimension
num_heads: number of attention heads
Pdrop: residual dropout probability
max_pos_encoding: maximum position
relative_pos_enc: whether to use relative position embeddings
causal: True for causal attention with masked future context
group_size: Attention group size
kernel_size: Attention kernel size
stride: Query stride
linear_att: whether to use multi-head linear self-attention
"""
def __init__(self, dim_model, num_heads, Pdrop, max_pos_encoding, relative_pos_enc, causal, group_size, kernel_size, stride, linear_att):
super(MultiHeadSelfAttentionModule, self).__init__()
# Assert
assert not (group_size > 1 and kernel_size is not None), "Local grouped attention not implemented"
assert not (group_size > 1 and stride > 1), "Strided grouped attention not implemented"
assert not (linear_att and relative_pos_enc), "Linear attention requires absolute positional encodings"
# Pre Norm
self.norm = nn.LayerNorm(dim_model, eps=1e-6)
# Efficient Multi-Head Attention
if linear_att:
self.mhsa = MultiHeadLinearAttention(dim_model, num_heads)
# Grouped Multi-Head Self-Attention
elif group_size > 1:
if relative_pos_enc:
self.mhsa = GroupedRelPosMultiHeadSelfAttention(dim_model, num_heads, causal, max_pos_encoding, group_size)
else:
self.mhsa = GroupedMultiHeadAttention(dim_model, num_heads, group_size)
# Local Multi-Head Self-Attention
elif kernel_size is not None and stride == 1:
if relative_pos_enc:
self.mhsa = LocalRelPosMultiHeadSelfAttention(dim_model, num_heads, causal, kernel_size)
else:
self.mhsa = LocalMultiHeadAttention(dim_model, num_heads, kernel_size)
# Strided Multi-Head Self-Attention
elif kernel_size is None and stride > 1:
if relative_pos_enc:
self.mhsa = StridedRelPosMultiHeadSelfAttention(dim_model, num_heads, causal, max_pos_encoding, stride)
else:
self.mhsa = StridedMultiHeadAttention(dim_model, num_heads, stride)
# Strided Local Multi-Head Self-Attention
elif stride > 1 and kernel_size is not None:
if relative_pos_enc:
self.mhsa = StridedLocalRelPosMultiHeadSelfAttention(dim_model, num_heads, causal, kernel_size, stride)
else:
self.mhsa = StridedLocalMultiHeadAttention(dim_model, num_heads, kernel_size, stride)
# Multi-Head Self-Attention
else:
if relative_pos_enc:
self.mhsa = RelPosMultiHeadSelfAttention(dim_model, num_heads, causal, max_pos_encoding)
else:
self.mhsa = MultiHeadAttention(dim_model, num_heads)
# Dropout
self.dropout = nn.Dropout(Pdrop)
# Module Params
self.rel_pos_enc = relative_pos_enc
self.linear_att = linear_att
def forward(self, x, mask=None, hidden=None):
# Pre Norm
x = self.norm(x)
# Multi-Head Self-Attention
if self.linear_att:
x, attention = self.mhsa(x, x, x)
elif self.rel_pos_enc:
x, attention, hidden = self.mhsa(x, x, x, mask, hidden)
else:
x, attention = self.mhsa(x, x, x, mask)
# Dropout
x = self.dropout(x)
return x, attention, hidden
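# Usage sketch (illustrative only), assuming the attention classes imported above are available:
# with group_size=1, kernel_size=None, stride=1 and linear_att=False, the module falls back to
# standard (relative-position) multi-head self-attention.
#
#   import torch
#   mhsa = MultiHeadSelfAttentionModule(dim_model=256, num_heads=4, Pdrop=0.1, max_pos_encoding=1000,
#                                       relative_pos_enc=True, causal=False, group_size=1,
#                                       kernel_size=None, stride=1, linear_att=False)
#   x = torch.randn(2, 50, 256)            # (batch_size, length, dim_model)
#   out, attention, hidden = mhsa(x, mask=None, hidden=None)
#   x = x + out                            # residual connection applied in the encoder block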
class ConvolutionModule(nn.Module):
"""Conformer Convolution Module
Args:
dim_model: input feature dimension
dim_expand: output feature dimension
kernel_size: 1D depthwise convolution kernel size
Pdrop: residual dropout probability
stride: 1D depthwise convolution stride
padding: "valid", "same" or "causal"
Input: (batch size, input length, dim_model)
Output: (batch size, output length, dim_expand)
"""
def __init__(self, dim_model, dim_expand, kernel_size, Pdrop, stride, padding):
super(ConvolutionModule, self).__init__()
# Layers
self.layers = nn.Sequential(
nn.LayerNorm(dim_model, eps=1e-6),
Transpose(1, 2),
Conv1d(dim_model, 2 * dim_expand, kernel_size=1),
Glu(dim=1),
Conv1d(dim_expand, dim_expand, kernel_size, stride=stride, padding=padding, groups=dim_expand),
nn.BatchNorm1d(dim_expand),
Swish(),
Conv1d(dim_expand, dim_expand, kernel_size=1),
Transpose(1, 2),
nn.Dropout(p=Pdrop)
)
def forward(self, x):
return self.layers(x)
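# Usage sketch (illustrative only), assuming the custom Conv1d layer accepts the string padding
# modes listed in the docstring ("valid", "same", "causal"):
#
#   import torch
#   conv = ConvolutionModule(dim_model=256, dim_expand=256, kernel_size=31, Pdrop=0.1, stride=1, padding="same")
#   x = torch.randn(2, 50, 256)            # (batch_size, length, dim_model)
#   x = x + conv(x)                        # residual connection applied in the encoder block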
###############################################################################
# ContextNet Modules
###############################################################################
class ContextNetBlock(nn.Module):
def __init__(self, num_layers, dim_in, dim_out, kernel_size, stride, causal, se_ratio, residual, padding):
super(ContextNetBlock, self).__init__()
# Conv Layers
self.conv_layers = nn.Sequential(*[
DepthwiseSeparableConv1d(dim_in if layer_id == 0 else dim_out, dim_out, kernel_size, stride if layer_id == num_layers - 1 else 1, causal)
for layer_id in range(num_layers)])
# SE Module
self.se_module = SqueezeAndExcitationModule(dim_out, se_ratio, "swish") if se_ratio is not None else None
# Residual
self.residual = nn.Sequential(
Conv1d(dim_in, dim_out, kernel_size=1, stride=stride, groups=1, padding=padding),
nn.BatchNorm1d(dim_out)
) if residual else None
# Block Act
self.act = Swish()
def forward(self, x):
# Conv Layers
y = self.conv_layers(x)
# SE Module
if self.se_module is not None:
y = self.se_module(y)
# Residual
if self.residual is not None:
y = self.act(y + self.residual(x))
return y
class ContextNetSubsampling(nn.Module):
def __init__(self, n_mels, dim_model, kernel_size, causal):
super(ContextNetSubsampling, self).__init__()
# Blocks
self.blocks = nn.Sequential(*[ContextNetBlock(
num_layers=1 if block_id == 0 else 5,
dim_in=n_mels if block_id == 0 else dim_model,
dim_out=dim_model,
kernel_size=kernel_size,
stride=2 if block_id in [3, 7] else 1,
causal=causal,
se_ratio=None if block_id == 0 else 8,
residual=False if block_id == 0 else True,
padding="causal" if causal else "same",  # assumed padding mode for the residual pointwise conv
) for block_id in range(8)])
def forward(self, x, x_len):
# Blocks
x = self.blocks(x)
# Update Sequence Lengths
if x_len is not None:
x_len = (x_len - 1) // 2 + 1
x_len = (x_len - 1) // 2 + 1
return x, x_len
###############################################################################
# Modules
###############################################################################
class SqueezeAndExcitationModule(nn.Module):
"""Squeeze And Excitation Module
Args:
input_dim: input feature dimension
reduction_ratio: bottleneck reduction ratio
inner_act: bottleneck inner activation function
Input: (batch_size, in_dim, in_length)
Output: (batch_size, out_dim, out_length)
"""
def __init__(self, input_dim, reduction_ratio, inner_act="relu"):
super(SqueezeAndExcitationModule, self).__init__()
assert input_dim % reduction_ratio == 0
self.conv1 = Conv1d(input_dim, input_dim // reduction_ratio, kernel_size=1)
self.conv2 = Conv1d(input_dim // reduction_ratio, input_dim, kernel_size=1)
assert inner_act in ["relu", "swish"]
if inner_act == "relu":
self.inner_act = nn.ReLU()
elif inner_act == "swish":
self.inner_act = Swish()
def forward(self, x):
# Global avg Pooling
scale = x.mean(dim=-1, keepdim=True)
# (B, C, 1) -> (B, C // R, 1)
scale = self.conv1(scale)
# Inner Act
scale = self.inner_act(scale)
# (B, C // R, 1) -> (B, C, 1)
scale = self.conv2(scale)
# Sigmoid
scale = scale.sigmoid()
# Scale
x = x * scale
return x
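# Usage sketch (illustrative only): channel-wise rescaling of a (B, C, T) feature map, as used
# inside the ContextNet blocks above.
#
#   import torch
#   se = SqueezeAndExcitationModule(input_dim=256, reduction_ratio=8, inner_act="swish")
#   x = torch.randn(2, 256, 50)            # (batch_size, channels, length)
#   x = se(x)                              # same shape, each channel scaled by a sigmoid gate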
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
class constant_learning_rate_scheduler:
def __init__(self, optimizer, lr_value):
# Model Optimizer
self.optimizer = optimizer
# Model Step
self.model_step = -1
# Scheduler Params
self.lr_value = lr_value
def step(self):
# Update Model Step
self.model_step += 1
s = self.model_step + 1
# Update LR
self.optimizer.param_groups[0]['lr'] = self.lr_value
class constant_with_decay_learning_rate_scheduler:
def __init__(self, optimizer, lr_values, decay_steps):
# Model Optimizer
self.optimizer = optimizer
# Model Step
self.model_step = -1
# Scheduler Params
self.lr_values = lr_values
self.decay_steps = decay_steps
def step(self):
# Update Model Step
self.model_step += 1
s = self.model_step + 1
# Update LR
lr_value = self.lr_values[0]
for i, step in enumerate(self.decay_steps):
if self.model_step > step:
lr_value = self.lr_values[i + 1]
else:
break
self.optimizer.param_groups[0]['lr'] = lr_value
class cosine_annealing_learning_rate_scheduler:
def __init__(self, optimizer, warmup_steps, lr_max, lr_min, end_step):
# Model Optimizer
self.optimizer = optimizer
# Model Step
self.model_step = -1
# Scheduler Params
self.warmup_steps = warmup_steps
self.lr_max = lr_max
self.lr_min = lr_min
self.end_step = end_step
def step(self):
# Update Model Step
self.model_step += 1
s = self.model_step + 1
# Compute LR
if self.model_step <= self.warmup_steps: # Warmup phase
lr = s / self.warmup_steps * self.lr_max
else: # Annealing phase
lr = (self.lr_max - self.lr_min) * 0.5 * (1 + math.cos(math.pi * (self.model_step - self.warmup_steps) / (self.end_step - self.warmup_steps))) + self.lr_min
# Update LR
self.optimizer.param_groups[0]['lr'] = lr
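# Behaviour sketch (illustrative only): linear warmup to lr_max over warmup_steps, then cosine
# decay towards lr_min at end_step.
#
#   import torch
#   params = [torch.nn.Parameter(torch.zeros(1))]
#   optimizer = torch.optim.Adam(params, lr=0.0)
#   scheduler = cosine_annealing_learning_rate_scheduler(optimizer, warmup_steps=10, lr_max=1e-3, lr_min=1e-5, end_step=100)
#   for _ in range(10):
#       scheduler.step()
#   print(optimizer.param_groups[0]["lr"])   # == lr_max at the end of warmup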
class transformer_learning_rate_scheduler:
def __init__(self, optimizer, dim_model, warmup_steps, K):
# Model Optimizer
self.optimizer = optimizer
# Model Step
self.model_step = -1
# Scheduler Params
self.dim_model = dim_model
self.warmup_steps = warmup_steps
self.K = K
def step(self):
# Update Model Step
self.model_step += 1
s = self.model_step + 1
# Update LR
arg1 = s**-0.5
arg2 = s * (self.warmup_steps**-1.5)
self.optimizer.param_groups[0]['lr'] = self.K * self.dim_model**-0.5 * min(arg1, arg2)
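# Usage sketch (illustrative only): the Noam-style schedule from "Attention Is All You Need",
# scaled by K; the learning rate peaks around step == warmup_steps.
#
#   import torch
#   params = [torch.nn.Parameter(torch.zeros(1))]
#   optimizer = torch.optim.Adam(params, lr=0.0)
#   scheduler = transformer_learning_rate_scheduler(optimizer, dim_model=256, warmup_steps=10000, K=2)
#   scheduler.step()
#   print(optimizer.param_groups[0]["lr"])   # K * dim_model**-0.5 * min(1**-0.5, 1 * 10000**-1.5)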
class exponential_decay_transformer_learning_rate_scheduler:
def __init__(self, optimizer, warmup_steps, lr_max, alpha, end_step):
# Model Optimizer
self.optimizer = optimizer
# Model Step
self.model_step = -1
# Scheduler Params
self.warmup_steps = warmup_steps
self.lr_max = lr_max
self.alpha = alpha
self.end_step = end_step
def step(self):
# Update Model Step
self.model_step += 1
s = self.model_step + 1
# Update LR
arg1 = s / self.warmup_steps * self.lr_max # Warmup phase
arg2 = self.lr_max * self.alpha**((s - self.warmup_steps) / (self.end_step - self.warmup_steps)) # Decay phase
self.optimizer.param_groups[0]['lr'] = min(arg1, arg2)
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
# Base Model
from models.model import Model, init_vn
# Encoders
from models.encoders import (
ConformerEncoder
)
# Decoders
from models.decoders import (
RnnDecoder,
TransformerDecoder,
ConformerDecoder
)
# Joint Network
from models.joint_networks import (
JointNetwork
)
# Language Model
from models.lm import (
LanguageModel
)
# Losses
from models.losses import (
LossRNNT
)
# Ngram
import kenlm
class Transducer(Model):
def __init__(self, encoder_params, decoder_params, joint_params, tokenizer_params, training_params, decoding_params, name):
super(Transducer, self).__init__(tokenizer_params, training_params, decoding_params, name)
# Encoder
if encoder_params["arch"] == "Conformer":
self.encoder = ConformerEncoder(encoder_params)
else:
raise Exception("Unknown encoder architecture:", encoder_params["arch"])
# Decoder
if decoder_params["arch"] == "RNN":
self.decoder = RnnDecoder(decoder_params)
elif decoder_params["arch"] == "Transformer":
self.decoder = TransformerDecoder(decoder_params)
elif decoder_params["arch"] == "Conformer":
self.decoder = ConformerDecoder(decoder_params)
else:
raise Exception("Unknown decoder architecture:", decoder_params["arch"])
# Joint Network
self.joint_network = JointNetwork(encoder_params["dim_model"][-1] if isinstance(encoder_params["dim_model"], list) else encoder_params["dim_model"], decoder_params["dim_model"], decoder_params["vocab_size"], joint_params)
# Init VN
self.decoder.apply(lambda m: init_vn(m, training_params.get("vn_std", None)))
# Criterion
self.criterion = LossRNNT()
# Decoding
self.max_consec_dec_step = decoder_params.get("max_consec_dec_step", 5)
# Compile
self.compile(training_params)
def forward(self, batch):
# Unpack Batch
x, y, x_len, y_len = batch
# Audio Encoder (B, Taud) -> (B, T, Denc)
f, f_len, attentions = self.encoder(x, x_len)
# Add blank token
y = torch.nn.functional.pad(y, pad=(1, 0, 0, 0), value=0)
y_len = y_len + 1
# Text Decoder (B, U + 1) -> (B, U + 1, Ddec)
g, _ = self.decoder(y, None, y_len)
# Joint Network (B, T, Denc) and (B, U + 1, Ddec) -> (B, T, U + 1, V)
logits = self.joint_network(f, g)
return logits, f_len, attentions
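# Loss sketch (illustrative only): the (B, T, U + 1, V) joint lattice is the standard input of an
# RNN-T loss. Assuming a recent torchaudio and that LossRNNT behaves like the functional below,
# the training criterion could be computed as:
#
#   import torchaudio
#   targets = y[:, 1:].int()               # strip the prepended blank token
#   loss = torchaudio.functional.rnnt_loss(logits, targets, f_len.int(), (y_len - 1).int(), blank=0)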
def distribute_strategy(self, rank):
super(Transducer, self).distribute_strategy(rank)
self.encoder = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.encoder)
self.encoder = torch.nn.parallel.DistributedDataParallel(self.encoder, device_ids=[self.rank])
self.decoder = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.decoder)
self.decoder = torch.nn.parallel.DistributedDataParallel(self.decoder, device_ids=[self.rank])
self.joint_network = torch.nn.parallel.DistributedDataParallel(self.joint_network, device_ids=[self.rank])
def parallel_strategy(self):
super(Transducer, self).parallel_strategy()
self.encoder = torch.nn.DataParallel(self.encoder)
self.decoder = torch.nn.DataParallel(self.decoder)
self.joint_network = torch.nn.DataParallel(self.joint_network)
def summary(self, show_dict=False):
print(self.name)
print("Model Parameters :", self.num_params() - self.lm.num_params() if isinstance(self.lm, LanguageModel) else self.num_params())
print(" - Encoder Parameters :", sum([p.numel() for p in self.encoder.parameters()]))
print(" - Decoder Parameters :", sum([p.numel() for p in self.decoder.parameters()]))
print(" - Joint Parameters :", sum([p.numel() for p in self.joint_network.parameters()]))
if isinstance(self.lm, LanguageModel):
print("LM Parameters :", self.lm.num_params())
if show_dict:
for key, value in self.state_dict().items():
print("{:<64} {:<16} mean {:<16.4f} std {:<16.4f}".format(key, str(tuple(value.size())), value.float().mean(), value.float().std()))
def gready_search_decoding(self, x, x_len):
# Predictions String List
preds = []
# Forward Encoder (B, Taud) -> (B, T, Denc)
f, f_len, _ = self.encoder(x, x_len)
# Batch loop
for b in range(x.size(0)): # One sample at a time for now, not batch optimized
# Init y and hidden state
y = x.new_zeros(1, 1, dtype=torch.long)
hidden = None
enc_step = 0
consec_dec_step = 0
# Decoder loop
while enc_step < f_len[b]:
# Forward Decoder (1, 1) -> (1, 1, Ddec)
g, hidden = self.decoder(y[:, -1:], hidden)
# Joint Network loop
while enc_step < f_len[b]:
# Forward Joint Network (1, 1, Denc) and (1, 1, Ddec) -> (1, V)
logits = self.joint_network(f[b:b+1, enc_step], g[:, 0])
# Token Prediction
pred = logits.softmax(dim=-1).log().argmax(dim=-1) # (1)
# Null token or max_consec_dec_step
if pred == 0 or consec_dec_step == self.max_consec_dec_step:
consec_dec_step = 0
enc_step += 1
# Token
else:
consec_dec_step += 1
y = torch.cat([y, pred.unsqueeze(0)], dim=-1)
break
# Decode Label Sequence
pred = self.tokenizer.decode(y[:, 1:].tolist())
preds += pred
return preds
def beam_search_decoding(self, x, x_len, beam_size=None):
# Overwrite beam size
if beam_size is None:
beam_size = self.beam_size
# Load ngram lm
ngram_lm = None
if self.ngram_path is not None:
try:
ngram_lm = kenlm.Model(self.ngram_path)
except Exception:
print("Ngram language model not found...")
# Predictions String List
batch_predictions = []
# Forward Encoder (B, Taud) -> (B, T, Denc)
f, f_len, _ = self.encoder(x, x_len)
# Batch loop
for b in range(x.size(0)):
# Decoder Input
y = torch.ones((1, 1), device=x.device, dtype=torch.long)
# Default Beam hypothesis
B_hyps = [{
"prediction": [0],
"logp_score": 0.0,
"hidden_state": None,
"hidden_state_lm": None,
}]
# Init Ngram LM State
if ngram_lm and self.ngram_alpha > 0:
state1 = kenlm.State()
state2 = kenlm.State()
ngram_lm.NullContextWrite(state1)
B_hyps[0].update({"ngram_lm_state1": state1, "ngram_lm_state2": state2})
# Encoder loop
for enc_step in range(f_len[b]):
A_hyps = B_hyps
B_hyps = []
# While B contains fewer than beam_size hypotheses
while len(B_hyps) < beam_size:
# Most probable hyp in A
A_best_hyp = max(A_hyps, key=lambda x: x["logp_score"] / len(x["prediction"]))
# Remove best hyp from A
A_hyps.remove(A_best_hyp)
# Forward Decoder (1, 1) -> (1, 1, Ddec)
y[0, 0] = A_best_hyp["prediction"][-1]
g, hidden = self.decoder(y, A_best_hyp["hidden_state"])
g = g[:, 0] # (1, Ddec)
# Forward Joint Network (1, Denc) and (1, Ddec) -> (1, V)
logits = self.joint_network(f[b:b+1, enc_step], g)
logits = logits[0] # (V)
# Apply Temperature
logits = logits / self.tmp
# Compute logP
logP = logits.softmax(dim=-1).log()
# LM Prediction
if self.lm and self.lm_weight:
# Forward LM
logits_lm, hidden_lm = self.lm.decode(y, A_best_hyp["hidden_state_lm"]) # (1, 1, V)
logits_lm = logits_lm[0, 0] # (V)
# Apply Temperature
logits_lm = logits_lm / self.lm_tmp
# Compute logP
logP_lm = logits_lm.softmax(dim=-1).log()
# Add LogP
logP += self.lm_weight * logP_lm
# Sorted top k logp and their labels
topk_logP, topk_labels = torch.topk(logP, k=beam_size, dim=-1)
# Extend hyp by selection
for j in range(topk_logP.size(0)):
# Updated hyp with logp
hyp = {
"prediction": A_best_hyp["prediction"][:],
"logp_score": A_best_hyp["logp_score"] + topk_logP[j],
"hidden_state": A_best_hyp["hidden_state"],
}
# Blank Prediction -> Append hyp to B
if topk_labels[j] == 0:
if self.lm and self.lm_weight > 0:
hyp["hidden_state_lm"] = A_best_hyp["hidden_state_lm"]
if ngram_lm and self.ngram_alpha > 0:
hyp["ngram_lm_state1"] = A_best_hyp["ngram_lm_state1"].__deepcopy__()
hyp["ngram_lm_state2"] = A_best_hyp["ngram_lm_state2"].__deepcopy__()
B_hyps.append(hyp)
# Non Blank Prediction -> Update hyp hidden / prediction and append to A
else:
hyp["prediction"].append(topk_labels[j].item())
hyp["hidden_state"] = hidden
if self.lm and self.lm_weight > 0:
hyp["hidden_state_lm"] = hidden_lm
# Ngram LM Rescoring
if ngram_lm and self.ngram_alpha > 0:
state1 = A_best_hyp["ngram_lm_state1"].__deepcopy__()
state2 = A_best_hyp["ngram_lm_state2"].__deepcopy__()
s = chr(topk_labels[j].item() + self.ngram_offset)
lm_score = ngram_lm.BaseScore(state1, s, state2)
hyp["logp_score"] += self.ngram_alpha * lm_score + self.ngram_beta
hyp["ngram_lm_state1"] = state2
hyp["ngram_lm_state2"] = state1
A_hyps.append(hyp)
# Pick best hyp
best_hyp = max(B_hyps, key=lambda x: x["logp_score"] / len(x["prediction"]))
# Decode hyp
batch_predictions.append(self.tokenizer.decode(best_hyp["prediction"][1:]))
return batch_predictions