Commit d75c2dc7 authored by burchim

Efficient Conformer

# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
# Models
from models.transducer import Transducer
from models.model_ctc import ModelCTC, InterCTC
from models.lm import LanguageModel
# Datasets
from utils.datasets import (
LibriSpeechDataset,
LibriSpeechCorpusDataset
)
# Preprocessing
from utils.preprocessing import (
collate_fn_pad
)
def create_model(config):
# Create Model
if config["model_type"] == "Transducer":
model = Transducer(
encoder_params=config["encoder_params"],
decoder_params=config["decoder_params"],
joint_params=config["joint_params"],
tokenizer_params=config["tokenizer_params"],
training_params=config["training_params"],
decoding_params=config["decoding_params"],
name=config["model_name"]
)
elif config["model_type"] == "CTC":
model = ModelCTC(
encoder_params=config["encoder_params"],
tokenizer_params=config["tokenizer_params"],
training_params=config["training_params"],
decoding_params=config["decoding_params"],
name=config["model_name"]
)
elif config["model_type"] == "InterCTC":
model = InterCTC(
encoder_params=config["encoder_params"],
tokenizer_params=config["tokenizer_params"],
training_params=config["training_params"],
decoding_params=config["decoding_params"],
name=config["model_name"]
)
elif config["model_type"] == "LM":
model = LanguageModel(
lm_params=config["lm_params"],
tokenizer_params=config["tokenizer_params"],
training_params=config["training_params"],
decoding_params=config["decoding_params"],
name=config["model_name"]
)
else:
raise Exception("Unknown model type")
return model
def load_datasets(training_params, tokenizer_params, args):
# Training Datasets
training_datasets = {
"LibriSpeech": {
"class": LibriSpeechDataset,
"split": {
"training": "train",
"training-clean": "train-clean",
"validation-clean": None,
"validation-other": None,
"test-clean": None,
"test-other": None,
"eval_time": None,
"eval_time_encoder": None,
"eval_time_decoder": None,
}
},
"LibriSpeechCorpus": {
"class": LibriSpeechCorpusDataset,
"split": {
"training": "train",
"validation-clean": None,
"validation-other": None,
"test-clean": None,
"test-other": None,
"eval_time": None,
"eval_time_encoder": None,
"eval_time_decoder": None,
}
}
}
# Evaluation Datasets
evaluation_datasets = {
"LibriSpeech": {
"class": LibriSpeechDataset,
"split": {
"training": ["dev-clean", "dev-other"],
"training-clean": ["dev-clean", "dev-other"],
"validation-clean": "dev-clean",
"validation-other": "dev-other",
"test-clean": "test-clean",
"test-other": "test-other",
"eval_time": "dev-clean",
"eval_time_encoder": "dev-clean",
"eval_time_decoder": "dev-clean",
}
},
"LibriSpeechCorpus": {
"class": LibriSpeechCorpusDataset,
"split": {
"training": "val",
"validation-clean": "val",
"validation-other": "val",
"test-clean": "test",
"test-other": "test",
"eval_time": "val",
"eval_time_encoder": "val",
"eval_time_decoder": "val",
}
}
}
# Select Dataset and Split
training_dataset = training_datasets[training_params["training_dataset"]]["class"]
training_split = training_datasets[training_params["training_dataset"]]["split"][args.mode]
evaluation_dataset = evaluation_datasets[training_params["evaluation_dataset"]]["class"]
evaluation_split = evaluation_datasets[training_params["evaluation_dataset"]]["split"][args.mode]
# Training Dataset
if training_split:
if args.rank == 0:
print("Loading training dataset : {} {}".format(training_params["training_dataset"], training_split))
dataset_train = training_dataset(training_params["training_dataset_path"], training_params, tokenizer_params, training_split, args)
if args.distributed:
sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, num_replicas=args.world_size, rank=args.rank)
else:
sampler = None
dataset_train = torch.utils.data.DataLoader(dataset_train, batch_size=training_params["batch_size"], shuffle=(not args.distributed), num_workers=args.num_workers, collate_fn=collate_fn_pad, drop_last=True, sampler=sampler, pin_memory=False)
if args.rank == 0:
print("Loaded :", dataset_train.dataset.__len__(), "samples", "/", dataset_train.__len__(), "batches")
else:
dataset_train = None
# Evaluation Dataset
if evaluation_split:
# Multiple Evaluation datasets
if isinstance(evaluation_split, list):
dataset_eval = {}
for split in evaluation_split:
if args.rank == 0:
print("Loading evaluation dataset : {} {}".format(training_params["evaluation_dataset"], split))
dataset = evaluation_dataset(training_params["evaluation_dataset_path"], training_params, tokenizer_params, split, args)
if args.distributed:
sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=args.world_size, rank=args.rank)
else:
sampler = None
dataset = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size_eval, shuffle=(not args.distributed), num_workers=args.num_workers, collate_fn=collate_fn_pad, sampler=sampler, pin_memory=False)
if args.rank == 0:
print("Loaded :", dataset.dataset.__len__(), "samples", "/", dataset.__len__(), "batches")
dataset_eval[split] = dataset
# One Evaluation dataset
else:
if args.rank == 0:
print("Loading evaluation dataset : {} {}".format(training_params["evaluation_dataset"], evaluation_split))
dataset_eval = evaluation_dataset(training_params["evaluation_dataset_path"], training_params, tokenizer_params, evaluation_split, args)
if args.distributed:
sampler = torch.utils.data.distributed.DistributedSampler(dataset_eval, num_replicas=args.world_size, rank=args.rank)
else:
sampler = None
dataset_eval = torch.utils.data.DataLoader(dataset_eval, batch_size=args.batch_size_eval, shuffle=(not args.distributed), num_workers=args.num_workers, collate_fn=collate_fn_pad, sampler=sampler, pin_memory=False)
if args.rank == 0:
print("Loaded :", dataset_eval.dataset.__len__(), "samples", "/", dataset_eval.__len__(), "batches")
else:
dataset_eval = None
return dataset_train, dataset_eval
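# Usage sketch (illustrative only, hypothetical toy dataset) of the sampler / loader
# pattern used by load_datasets() above: in the distributed case the DataLoader must
# not shuffle itself, the DistributedSampler shards and shuffles the data per rank.
if __name__ == "__main__":

    class ToyDataset(torch.utils.data.Dataset):
        def __init__(self, n=100):
            self.data = torch.arange(n, dtype=torch.float32)
        def __len__(self):
            return len(self.data)
        def __getitem__(self, idx):
            return self.data[idx]

    distributed = False
    dataset = ToyDataset()
    # sampler = torch.utils.data.distributed.DistributedSampler(dataset) when distributed
    sampler = None
    loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=(not distributed), sampler=sampler, drop_last=True)
    print(len(loader.dataset), "samples /", len(loader), "batches")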
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
# Functions and Utils
from functions import *
from utils.preprocessing import *
# Other
import json
import argparse
import os
def main(rank, args):
# Process rank
args.rank = rank
# Distributed Computing
if args.distributed:
torch.cuda.set_device(args.rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=args.rank)
# Load Config
with open(args.config_file) as json_config:
config = json.load(json_config)
# Device
device = torch.device("cuda:" + str(args.rank) if torch.cuda.is_available() and not args.cpu else "cpu")
print("Device:", device)
# Create Tokenizer
if args.create_tokenizer:
if args.rank == 0:
print("Creating Tokenizer")
create_tokenizer(config["training_params"], config["tokenizer_params"])
if args.distributed:
torch.distributed.barrier()
# Create Model
model = create_model(config).to(device)
# Load Model
if args.initial_epoch is not None:
model.load(config["training_params"]["callback_path"] + "checkpoints_" + str(args.initial_epoch) + ".ckpt")
else:
args.initial_epoch = 0
# Load Encoder Only
if args.initial_epoch_encoder is not None:
model.load_encoder(config["training_params"]["callback_path_encoder"] + "checkpoints_" + str(args.initial_epoch_encoder) + ".ckpt")
# Load LM
if args.initial_epoch_lm:
# Load LM Config
with open(config["decoding_params"]["lm_config"]) as json_config:
config_lm = json.load(json_config)
# Create LM
model.lm = create_model(config_lm).to(device)
# Load LM
model.lm.load(config_lm["training_params"]["callback_path"] + "checkpoints_" + str(args.initial_epoch_lm) + ".ckpt")
# Model Summary
if args.rank == 0:
model.summary(show_dict=args.show_dict)
# Distribute Strategy
if args.distributed:
if args.rank == 0:
print("Parallelize model on", args.world_size, "GPUs")
model.distribute_strategy(args.rank)
# Parallel Strategy
if args.parallel and not args.distributed:
print("Parallelize model on", torch.cuda.device_count(), "GPUs")
model.parallel_strategy()
# Prepare Dataset
if args.prepare_dataset:
if args.rank == 0:
print("Preparing dataset")
prepare_dataset(config["training_params"], config["tokenizer_params"], model.tokenizer)
if args.distributed:
torch.distributed.barrier()
# Load Dataset
dataset_train, dataset_val = load_datasets(config["training_params"], config["tokenizer_params"], args)
###############################################################################
# Modes
###############################################################################
# Stochastic Weight Averaging
if args.swa:
model.swa(dataset_train, callback_path=config["training_params"]["callback_path"], start_epoch=args.swa_epochs[0] if args.swa_epochs else None, end_epoch=args.swa_epochs[1] if args.swa_epochs else None, epochs_list=args.swa_epochs_list, update_steps=args.steps_per_epoch, swa_type=args.swa_type)
# Training
elif args.mode.split("-")[0] == "training":
model.fit(dataset_train,
config["training_params"]["epochs"],
dataset_val=dataset_val,
val_steps=args.val_steps,
verbose_val=args.verbose_val,
initial_epoch=int(args.initial_epoch),
callback_path=config["training_params"]["callback_path"],
steps_per_epoch=args.steps_per_epoch,
mixed_precision=config["training_params"]["mixed_precision"],
accumulated_steps=config["training_params"]["accumulated_steps"],
saving_period=args.saving_period,
val_period=args.val_period)
# Evaluation
elif args.mode.split("-")[0] == "validation" or args.mode.split("-")[0] == "test":
# Greedy Search Evaluation
if args.gready or model.beam_size is None:
if args.rank == 0:
print("Gready Search Evaluation")
wer, _, _, _ = model.evaluate(dataset_val, eval_steps=args.val_steps, verbose=args.verbose_val, beam_size=1, eval_loss=args.eval_loss)
if args.rank == 0:
print("Geady Search WER : {:.2f}%".format(100 * wer))
# Beam Search Evaluation
else:
if args.rank == 0:
print("Beam Search Evaluation")
wer, _, _, _ = model.evaluate(dataset_val, eval_steps=args.val_steps, verbose=args.verbose_val, beam_size=model.beam_size, eval_loss=False)
if args.rank == 0:
print("Beam Search WER : {:.2f}%".format(100 * wer))
# Eval Time
elif args.mode.split("-")[0] == "eval_time":
print("Model Eval Time")
inf_time = model.eval_time(dataset_val, eval_steps=args.val_steps, beam_size=1, rnnt_max_consec_dec_steps=args.rnnt_max_consec_dec_steps, profiler=args.profiler)
print("eval time : {:.2f}s".format(inf_time))
elif args.mode.split("-")[0] == "eval_time_encoder":
print("Encoder Eval Time")
enc_time = model.eval_time_encoder(dataset_val, eval_steps=args.val_steps, profiler=args.profiler)
print("eval time : {:.2f}s".format(enc_time))
elif args.mode.split("-")[0] == "eval_time_decoder":
print("Decoder Eval Time")
dec_time = model.eval_time_decoder(dataset_val, eval_steps=args.val_steps, profiler=args.profiler)
print("eval time : {:.2f}s".format(dec_time))
# Destroy Process Group
if args.distributed:
torch.distributed.destroy_process_group()
if __name__ == "__main__":
# Args
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config_file", type=str, default="configs/EfficientConformerCTCSmall.json", help="Json configuration file containing model hyperparameters")
parser.add_argument("-m", "--mode", type=str, default="training", help="Mode : training, validation-clean, test-clean, eval_time-dev-clean, ...")
parser.add_argument("-d", "--distributed", action="store_true", help="Distributed data parallelization")
parser.add_argument("-i", "--initial_epoch", type=str, default=None, help="Load model from checkpoint")
parser.add_argument("--initial_epoch_lm", type=str, default=None, help="Load language model from checkpoint")
parser.add_argument("--initial_epoch_encoder", type=str, default=None, help="Load model encoder from encoder checkpoint")
parser.add_argument("-p", "--prepare_dataset", action="store_true", help="Prepare dataset for training")
parser.add_argument("-j", "--num_workers", type=int, default=8, help="Number of data loading workers")
parser.add_argument("--create_tokenizer", action="store_true", help="Create model tokenizer")
parser.add_argument("--batch_size_eval", type=int, default=8, help="Evaluation batch size")
parser.add_argument("--verbose_val", action="store_true", help="Evaluation verbose")
parser.add_argument("--val_steps", type=int, default=None, help="Number of validation steps")
parser.add_argument("--steps_per_epoch", type=int, default=None, help="Number of steps per epoch")
parser.add_argument("--world_size", type=int, default=torch.cuda.device_count(), help="Number of available GPUs")
parser.add_argument("--cpu", action="store_true", help="Load model on cpu")
parser.add_argument("--show_dict", action="store_true", help="Show model dict summary")
parser.add_argument("--swa", action="store_true", help="Stochastic weight averaging")
parser.add_argument("--swa_epochs", nargs="+", default=None, help="Start epoch / end epoch for swa")
parser.add_argument("--swa_epochs_list", nargs="+", default=None, help="List of checkpoints epochs for swa")
parser.add_argument("--swa_type", type=str, default="equal", help="Stochastic weight averaging type (equal/exp)")
parser.add_argument("--parallel", action="store_true", help="Parallelize model using data parallelization")
parser.add_argument("--rnnt_max_consec_dec_steps", type=int, default=None, help="Number of maximum consecutive transducer decoder steps during inference")
parser.add_argument("--eval_loss", action="store_true", help="Compute evaluation loss during evaluation")
parser.add_argument("--gready", action="store_true", help="Proceed to a gready search evaluation")
parser.add_argument("--saving_period", type=int, default=1, help="Model saving every 'n' epochs")
parser.add_argument("--val_period", type=int, default=1, help="Model validation every 'n' epochs")
parser.add_argument("--profiler", action="store_true", help="Enable eval time profiler")
# Parse Args
args = parser.parse_args()
# Run main
if args.distributed:
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '8888'
torch.multiprocessing.spawn(main, nprocs=args.world_size, args=(args,))
else:
main(0, args)
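# Example invocations (illustrative; assumes the script above is saved as main.py and
# run from the repository root, with the parser defaults shown above):
#   python main.py -c configs/EfficientConformerCTCSmall.json -m training
#   python main.py -c configs/EfficientConformerCTCSmall.json -m test-clean --batch_size_eval 8
#   python main.py -c configs/EfficientConformerCTCSmall.json -m training -d   # one process per GPU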
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
###############################################################################
# Activation Functions
###############################################################################
class Swish(nn.Module):
def __init__(self):
super(Swish, self).__init__()
def forward(self, x):
return x * x.sigmoid()
class Glu(nn.Module):
def __init__(self, dim):
super(Glu, self).__init__()
self.dim = dim
def forward(self, x):
x_in, x_gate = x.chunk(2, dim=self.dim)
return x_in * x_gate.sigmoid()
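# Quick sanity check of the activations above (illustrative shapes only):
if __name__ == "__main__":
    x = torch.randn(2, 4, 8)              # (batch, time, features)
    print(Swish()(x).shape)               # torch.Size([2, 4, 8]) : x * sigmoid(x)
    print(Glu(dim=-1)(x).shape)           # torch.Size([2, 4, 4]) : first half gated by sigmoid of second half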
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
# Modules
from models.modules import (
FeedForwardModule,
MultiHeadSelfAttentionModule,
ConvolutionModule
)
# Layers
from models.layers import (
Conv1d,
Transpose
)
class ConformerBlock(nn.Module):
def __init__(
self,
dim_model,
dim_expand,
ff_ratio,
num_heads,
kernel_size,
att_group_size,
att_kernel_size,
linear_att,
Pdrop,
relative_pos_enc,
max_pos_encoding,
conv_stride,
att_stride,
causal
):
super(ConformerBlock, self).__init__()
# Feed Forward Module 1
self.feed_forward_module1 = FeedForwardModule(
dim_model=dim_model,
dim_ffn=dim_model * ff_ratio,
Pdrop=Pdrop,
act="swish",
inner_dropout=True
)
# Multi-Head Self-Attention Module
self.multi_head_self_attention_module = MultiHeadSelfAttentionModule(
dim_model=dim_model,
num_heads=num_heads,
Pdrop=Pdrop,
max_pos_encoding=max_pos_encoding,
relative_pos_enc=relative_pos_enc,
causal=causal,
group_size=att_group_size,
kernel_size=att_kernel_size,
stride=att_stride,
linear_att=linear_att
)
# Convolution Module
self.convolution_module = ConvolutionModule(
dim_model=dim_model,
dim_expand=dim_expand,
kernel_size=kernel_size,
Pdrop=Pdrop,
stride=conv_stride,
padding="causal" if causal else "same"
)
# Feed Forward Module 2
self.feed_forward_module2 = FeedForwardModule(
dim_model=dim_expand,
dim_ffn=dim_expand * ff_ratio,
Pdrop=Pdrop,
act="swish",
inner_dropout=True
)
# Block Norm
self.norm = nn.LayerNorm(dim_expand, eps=1e-6)
# Attention Residual
self.att_res = nn.Sequential(
Transpose(1, 2),
nn.MaxPool1d(kernel_size=1, stride=att_stride),
Transpose(1, 2)
) if att_stride > 1 else nn.Identity()
# Convolution Residual
self.conv_res = nn.Sequential(
Transpose(1, 2),
Conv1d(dim_model, dim_expand, kernel_size=1, stride=conv_stride),
Transpose(1, 2)
) if dim_model != dim_expand else nn.Sequential(
Transpose(1, 2),
nn.MaxPool1d(kernel_size=1, stride=conv_stride),
Transpose(1, 2)
) if conv_stride > 1 else nn.Identity()
# Block Stride
self.stride = conv_stride * att_stride
def forward(self, x, mask=None, hidden=None):
# FFN Module 1
x = x + 1/2 * self.feed_forward_module1(x)
# MHSA Module
x_att, attention, hidden = self.multi_head_self_attention_module(x, mask, hidden)
x = self.att_res(x) + x_att
# Conv Module
x = self.conv_res(x) + self.convolution_module(x)
# FFN Module 2
x = x + 1/2 * self.feed_forward_module2(x)
# Block Norm
x = self.norm(x)
return x, attention, hidden
class TransformerBlock(nn.Module):
def __init__(self, dim_model, ff_ratio, num_heads, Pdrop, max_pos_encoding, relative_pos_enc, causal):
super(TransformerBlock, self).__init__()
# Multi-Head Self-Attention Module
self.multi_head_self_attention_module = MultiHeadSelfAttentionModule(
dim_model=dim_model,
num_heads=num_heads,
Pdrop=Pdrop,
max_pos_encoding=max_pos_encoding,
relative_pos_enc=relative_pos_enc,
causal=causal,
group_size=1,
kernel_size=1,
stride=1,
linear_att=False
)
# Feed Forward Module
self.feed_forward_module = FeedForwardModule(
dim_model=dim_model,
dim_ffn=dim_model * ff_ratio,
Pdrop=Pdrop,
act="relu",
inner_dropout=False
)
def forward(self, x, mask=None, hidden=None):
# Multi-Head Self-Attention Module
x_att, attention, hidden = self.multi_head_self_attention_module(x, mask, hidden)
x = x + x_att
# Feed Forward Module
x = x + self.feed_forward_module(x)
return x, attention, hidden
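# Illustration (hypothetical stride of 2) of how the residual branches above keep
# shapes consistent with a strided attention or convolution module: a MaxPool1d
# with kernel_size=1 simply drops every other frame of the residual path.
if __name__ == "__main__":
    x = torch.randn(2, 100, 64)                                            # (B, T, D)
    res = nn.Sequential(Transpose(1, 2), nn.MaxPool1d(kernel_size=1, stride=2), Transpose(1, 2))
    print(res(x).shape)                                                    # torch.Size([2, 50, 64])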
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
# Positional Encodings and Masks
from models.attentions import (
SinusoidalPositionalEncoding,
StreamingMask
)
# Blocks
from models.blocks import (
TransformerBlock,
ConformerBlock
)
# Layers
from models.layers import (
Embedding,
LSTM
)
###############################################################################
# Decoder Models
###############################################################################
class RnnDecoder(nn.Module):
def __init__(self, params):
super(RnnDecoder, self).__init__()
self.embedding = Embedding(params["vocab_size"], params["dim_model"], padding_idx=0)
self.rnn = LSTM(input_size=params["dim_model"], hidden_size=params["dim_model"], num_layers=params["num_layers"], batch_first=True, bidirectional=False)
def forward(self, y, hidden, y_len=None):
# Sequence Embedding (B, U + 1) -> (B, U + 1, D)
y = self.embedding(y)
# Pack padded batch sequences
if y_len is not None:
y = nn.utils.rnn.pack_padded_sequence(y, y_len.cpu(), batch_first=True, enforce_sorted=False)
# Hidden state provided
if hidden is not None:
y, hidden = self.rnn(y, hidden)
# None hidden state
else:
y, hidden = self.rnn(y)
# Pad packed batch sequences
if y_len is not None:
y, _ = nn.utils.rnn.pad_packed_sequence(y, batch_first=True)
# return last layer steps outputs and every layer last step hidden state
return y, hidden
class TransformerDecoder(nn.Module):
def __init__(self, params):
super(TransformerDecoder, self).__init__()
# Look Ahead Mask
self.look_ahead_mask = StreamingMask(left_context=params.get("left_context", params["max_pos_encoding"]), right_context=0)
# Embedding Layer
self.embedding = nn.Embedding(params["vocab_size"], params["dim_model"], padding_idx=0)
# Dropout
self.dropout = nn.Dropout(p=params["Pdrop"])
# Sinusoidal Positional Encodings
self.pos_enc = None if params["relative_pos_enc"] else SinusoidalPositionalEncoding(params["max_pos_encoding"], params["dim_model"])
# Transformer Blocks
self.blocks = nn.ModuleList([TransformerBlock(
dim_model=params["dim_model"],
ff_ratio=params["ff_ratio"],
num_heads=params["num_heads"],
Pdrop=params["Pdrop"],
max_pos_encoding=params["max_pos_encoding"],
relative_pos_enc=params["relative_pos_enc"],
causal=True
) for block_id in range(params["num_blocks"])])
def forward(self, y, hidden, y_len=None):
# Look Ahead Mask
if hidden is None:
mask = self.look_ahead_mask(y, y_len)
else:
mask = None
# Linear Proj
y = self.embedding(y)
# Dropout
y = self.dropout(y)
# Sinusoidal Positional Encodings
if self.pos_enc is not None:
y = y + self.pos_enc(y.size(0), y.size(1))
# Transformer Blocks
attentions = []
hidden_new = []
for block_id, block in enumerate(self.blocks):
# Hidden State Provided
if hidden is not None:
y, attention, block_hidden = block(y, mask, hidden[block_id])
else:
y, attention, block_hidden = block(y, mask)
# Update Hidden / Att Maps
if not self.training:
attentions.append(attention)
hidden_new.append(block_hidden)
return y, hidden_new
class ConformerDecoder(nn.Module):
def __init__(self, params):
super(ConformerDecoder, self).__init__()
# Look Ahead Mask
self.look_ahead_mask = StreamingMask(left_context=params.get("left_context", params["max_pos_encoding"]), right_context=0)
# Embedding Layer
self.embedding = nn.Embedding(params["vocab_size"], params["dim_model"], padding_idx=0)
# Dropout
self.dropout = nn.Dropout(p=params["Pdrop"])
# Sinusoidal Positional Encodings
self.pos_enc = None if params["relative_pos_enc"] else SinusoidalPositionalEncoding(params["max_pos_encoding"], params["dim_model"])
# Conformer Layers
self.blocks = nn.ModuleList([ConformerBlock(
dim_model=params["dim_model"],
dim_expand=params["dim_model"],
ff_ratio=params["ff_ratio"],
num_heads=params["num_heads"],
kernel_size=params["kernel_size"],
att_group_size=1,
att_kernel_size=None,
linear_att=params.get("linear_att", False),
Pdrop=params["Pdrop"],
relative_pos_enc=params["relative_pos_enc"],
max_pos_encoding=params["max_pos_encoding"],
conv_stride=1,
att_stride=1,
causal=True
) for block_id in range(params["num_blocks"])])
def forward(self, y, hidden, y_len=None):
# Hidden state provided
if hidden is not None:
y = torch.cat([hidden, y], axis=1)
# Look Ahead Mask
mask = self.look_ahead_mask(y, y_len)
# Update Hidden
hidden_new = y
# Linear Proj
y = self.embedding(y)
# Dropout
y = self.dropout(y)
# Sinusoidal Positional Encodings
if self.pos_enc is not None:
y = y + self.pos_enc(y.size(0), y.size(1))
# Transformer Blocks
attentions = []
for block in self.blocks:
y, attention, _ = block(y, mask)
attentions.append(attention)
if hidden is not None:
y = y[:, -1:]
return y, hidden_new
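# Self-contained sketch of the pack / pad flow in RnnDecoder.forward above, using
# plain nn.Embedding and nn.LSTM in place of the variational-noise layers:
if __name__ == "__main__":
    emb = nn.Embedding(32, 16, padding_idx=0)
    rnn = nn.LSTM(input_size=16, hidden_size=16, num_layers=1, batch_first=True)
    y = torch.tensor([[5, 7, 9, 0], [3, 4, 0, 0]])                         # padded token ids (B, U)
    y_len = torch.tensor([3, 2])
    packed = nn.utils.rnn.pack_padded_sequence(emb(y), y_len, batch_first=True, enforce_sorted=False)
    out, hidden = rnn(packed)
    out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
    print(out.shape)                                                       # torch.Size([2, 3, 16]) : padded back to max(y_len)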
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
# Blocks
from models.blocks import (
ConformerBlock
)
# Modules
from models.modules import (
AudioPreprocessing,
SpecAugment,
Conv1dSubsampling,
Conv2dSubsampling,
Conv2dPoolSubsampling,
VGGSubsampling
)
# Positional Encodings and Masks
from models.attentions import (
SinusoidalPositionalEncoding,
StreamingMask
)
###############################################################################
# Encoder Models
###############################################################################
class ConformerEncoder(nn.Module):
def __init__(self, params):
super(ConformerEncoder, self).__init__()
# Audio Preprocessing
self.preprocessing = AudioPreprocessing(params["sample_rate"], params["n_fft"], params["win_length_ms"], params["hop_length_ms"], params["n_mels"], params["normalize"], params["mean"], params["std"])
# Spec Augment
self.augment = SpecAugment(params["spec_augment"], params["mF"], params["F"], params["mT"], params["pS"])
# Subsampling Module
if params["subsampling_module"] == "Conv1d":
self.subsampling_module = Conv1dSubsampling(params["subsampling_layers"], params["n_mels"], params["subsampling_filters"], params["subsampling_kernel_size"], params["subsampling_norm"], params["subsampling_act"])
elif params["subsampling_module"] == "Conv2d":
self.subsampling_module = Conv2dSubsampling(params["subsampling_layers"], params["subsampling_filters"], params["subsampling_kernel_size"], params["subsampling_norm"], params["subsampling_act"])
elif params["subsampling_module"] == "Conv2dPool":
self.subsampling_module = Conv2dPoolSubsampling(params["subsampling_layers"], params["subsampling_filters"], params["subsampling_kernel_size"], params["subsampling_norm"], params["subsampling_act"])
elif params["subsampling_module"] == "VGG":
self.subsampling_module = VGGSubsampling(params["subsampling_layers"], params["subsampling_filters"], params["subsampling_kernel_size"], params["subsampling_norm"], params["subsampling_act"])
else:
raise Exception("Unknown subsampling module:", params["subsampling_module"])
# Padding Mask
self.padding_mask = StreamingMask(left_context=params.get("left_context", params["max_pos_encoding"]), right_context=0 if params.get("causal", False) else params.get("right_context", params["max_pos_encoding"]))
# Linear Proj
self.linear = nn.Linear(params["subsampling_filters"][-1] * params["n_mels"] // 2**params["subsampling_layers"], params["dim_model"][0] if isinstance(params["dim_model"], list) else params["dim_model"])
# Dropout
self.dropout = nn.Dropout(p=params["Pdrop"])
# Sinusoidal Positional Encodings
self.pos_enc = None if params["relative_pos_enc"] else SinusoidalPositionalEncoding(params["max_pos_encoding"], params["dim_model"][0] if isinstance(params["dim_model"], list) else params["dim_model"])
# Conformer Blocks
self.blocks = nn.ModuleList([ConformerBlock(
dim_model=params["dim_model"][(block_id > torch.tensor(params.get("expand_blocks", []))).sum()] if isinstance(params["dim_model"], list) else params["dim_model"],
dim_expand=params["dim_model"][(block_id >= torch.tensor(params.get("expand_blocks", []))).sum()] if isinstance(params["dim_model"], list) else params["dim_model"],
ff_ratio=params["ff_ratio"],
num_heads=params["num_heads"][(block_id > torch.tensor(params.get("expand_blocks", []))).sum()] if isinstance(params["num_heads"], list) else params["num_heads"],
kernel_size=params["kernel_size"][(block_id >= torch.tensor(params.get("expand_blocks", []))).sum()] if isinstance(params["kernel_size"], list) else params["kernel_size"],
att_group_size=params["att_group_size"][(block_id > torch.tensor(params.get("strided_blocks", []))).sum()] if isinstance(params.get("att_group_size", 1), list) else params.get("att_group_size", 1),
att_kernel_size=params["att_kernel_size"][(block_id > torch.tensor(params.get("strided_blocks", []))).sum()] if isinstance(params.get("att_kernel_size", None), list) else params.get("att_kernel_size", None),
linear_att=params.get("linear_att", False),
Pdrop=params["Pdrop"],
relative_pos_enc=params["relative_pos_enc"],
max_pos_encoding=params["max_pos_encoding"] // params.get("stride", 2)**int((block_id > torch.tensor(params.get("strided_blocks", []))).sum()),
conv_stride=(params["conv_stride"][(block_id > torch.tensor(params.get("strided_blocks", []))).sum()] if isinstance(params["conv_stride"], list) else params["conv_stride"]) if block_id in params.get("strided_blocks", []) else 1,
att_stride=(params["att_stride"][(block_id > torch.tensor(params.get("strided_blocks", []))).sum()] if isinstance(params["att_stride"], list) else params["att_stride"]) if block_id in params.get("strided_blocks", []) else 1,
causal=params.get("causal", False)
) for block_id in range(params["num_blocks"])])
def forward(self, x, x_len=None):
# Audio Preprocessing
x, x_len = self.preprocessing(x, x_len)
# Spec Augment
if self.training:
x = self.augment(x, x_len)
# Subsampling Module
x, x_len = self.subsampling_module(x, x_len)
# Padding Mask
mask = self.padding_mask(x, x_len)
# Transpose (B, D, T) -> (B, T, D)
x = x.transpose(1, 2)
# Linear Projection
x = self.linear(x)
# Dropout
x = self.dropout(x)
# Sinusoidal Positional Encodings
if self.pos_enc is not None:
x = x + self.pos_enc(x.size(0), x.size(1))
# Conformer Blocks
attentions = []
for block in self.blocks:
x, attention, hidden = block(x, mask)
attentions.append(attention)
# Strided Block
if block.stride > 1:
# Stride Mask (B, 1, T // S, T // S)
if mask is not None:
mask = mask[:, :, ::block.stride, ::block.stride]
# Update Seq Lengths
if x_len is not None:
x_len = (x_len - 1) // block.stride + 1
return x, x_len, attentions
class ConformerEncoderInterCTC(ConformerEncoder):
def __init__(self, params):
super(ConformerEncoderInterCTC, self).__init__(params)
# Inter CTC blocks
self.interctc_blocks = params["interctc_blocks"]
for block_id in params["interctc_blocks"]:
self.__setattr__(
name="linear_expand_" + str(block_id),
value=nn.Linear(
in_features=params["dim_model"][(block_id >= torch.tensor(params.get("expand_blocks", []))).sum()] if isinstance(params["dim_model"], list) else params["dim_model"],
out_features=params["vocab_size"]))
self.__setattr__(
name="linear_proj_" + str(block_id),
value=nn.Linear(
in_features=params["vocab_size"],
out_features=params["dim_model"][(block_id >= torch.tensor(params.get("expand_blocks", []))).sum()] if isinstance(params["dim_model"], list) else params["dim_model"]))
def forward(self, x, x_len=None):
# Audio Preprocessing
x, x_len = self.preprocessing(x, x_len)
# Spec Augment
if self.training:
x = self.augment(x, x_len)
# Subsampling Module
x, x_len = self.subsampling_module(x, x_len)
# Padding Mask
mask = self.padding_mask(x, x_len)
# Transpose (B, D, T) -> (B, T, D)
x = x.transpose(1, 2)
# Linear Projection
x = self.linear(x)
# Dropout
x = self.dropout(x)
# Sinusoidal Positional Encodings
if self.pos_enc is not None:
x = x + self.pos_enc(x.size(0), x.size(1))
# Conformer Blocks
attentions = []
interctc_probs = []
for block_id, block in enumerate(self.blocks):
x, attention, hidden = block(x, mask)
attentions.append(attention)
# Strided Block
if block.stride > 1:
# Stride Mask (B, 1, T // S, T // S)
if mask is not None:
mask = mask[:, :, ::block.stride, ::block.stride]
# Update Seq Lengths
if x_len is not None:
x_len = (x_len - 1) // block.stride + 1
# Inter CTC Block
if block_id in self.interctc_blocks:
interctc_prob = self.__getattr__("linear_expand_" + str(block_id))(x).softmax(dim=-1)
interctc_probs.append(interctc_prob)
x = x + self.__getattr__("linear_proj_" + str(block_id))(interctc_prob)
return x, x_len, attentions, interctc_probs
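# Illustrative computation (hypothetical block layout) of the stage index used in
# ConformerEncoder.__init__ above to select per-block widths when dim_model is a
# list: blocks listed in expand_blocks increase the output width of that block.
if __name__ == "__main__":
    expand_blocks = [4, 10]                                                # hypothetical expansion points
    dim_model = [120, 168, 240]                                            # hypothetical width per stage
    for block_id in range(15):
        dim_in = dim_model[int((block_id > torch.tensor(expand_blocks)).sum())]    # block input width
        dim_out = dim_model[int((block_id >= torch.tensor(expand_blocks)).sum())]  # block output width (dim_expand)
        print(block_id, dim_in, "->", dim_out)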
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
# Layers
from models.layers import (
Linear
)
# Activations Functions
from models.activations import (
Swish
)
###############################################################################
# Joint Networks
###############################################################################
class JointNetwork(nn.Module):
def __init__(self, dim_encoder, dim_decoder, vocab_size, params):
super(JointNetwork, self).__init__()
assert params["act"] in ["tanh", "relu", "swish", None]
assert params["joint_mode"] in ["concat", "sum"]
# Model layers
if params["dim_model"] is not None:
# Linear Layers
self.linear_encoder = Linear(dim_encoder, params["dim_model"])
self.linear_decoder = Linear(dim_decoder, params["dim_model"])
# Joint Mode
if params["joint_mode"] == "concat":
self.joint_mode = "concat"
self.linear_joint = Linear(2 * params["dim_model"], vocab_size)
elif params["joint_mode"] == "sum":
self.joint_mode = 'sum'
self.linear_joint = Linear(params["dim_model"], vocab_size)
else:
# Linear Layers
self.linear_encoder = nn.Identity()
self.linear_decoder = nn.Identity()
# Joint Mode
if params["joint_mode"] == "concat":
self.joint_mode = "concat"
self.linear_joint = Linear(dim_encoder + dim_decoder, vocab_size)
elif params["joint_mode"] == "sum":
assert dim_encoder == dim_decoder
self.joint_mode = 'sum'
self.linear_joint = Linear(dim_encoder, vocab_size)
# Model Act Function
if params["act"] == "tanh":
self.act = nn.Tanh()
elif params["act"] == "relu":
self.act = nn.ReLU()
elif params["act"] == "swish":
self.act = Swish()
else:
self.act = nn.Identity()
def forward(self, f, g):
f = self.linear_encoder(f)
g = self.linear_decoder(g)
# Training or Eval Loss
if self.training or (len(f.size()) == 3 and len(g.size()) == 3):
f = f.unsqueeze(2) # (B, T, 1, D)
g = g.unsqueeze(1) # (B, 1, U + 1, D)
f = f.repeat([1, 1, g.size(2), 1]) # (B, T, U + 1, D)
g = g.repeat([1, f.size(1), 1, 1]) # (B, T, U + 1, D)
# Joint Encoder and Decoder
if self.joint_mode == "concat":
joint = torch.cat([f, g], dim=-1) # Training : (B, T, U + 1, 2D) / Decoding : (B, 2D)
elif self.joint_mode == "sum":
joint = f + g # Training : (B, T, U + 1, D) / Decoding : (B, D)
# Act Function
joint = self.act(joint)
# Output Linear Projection
outputs = self.linear_joint(joint) # Training : (B, T, U + 1, V) / Decoding : (B, V)
return outputs
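# Shape walk-through (illustrative dimensions) of the training-time broadcast in
# JointNetwork.forward above: encoder frames (B, T, D) and prediction-network
# states (B, U + 1, D) are tiled to (B, T, U + 1, D) before the joint projection.
if __name__ == "__main__":
    B, T, U, D, V = 2, 50, 10, 144, 256
    f = torch.randn(B, T, D).unsqueeze(2).repeat(1, 1, U + 1, 1)           # (B, T, U + 1, D)
    g = torch.randn(B, U + 1, D).unsqueeze(1).repeat(1, T, 1, 1)           # (B, T, U + 1, D)
    logits = nn.Linear(2 * D, V)(torch.cat([f, g], dim=-1))                # "concat" joint mode
    print(logits.shape)                                                    # torch.Size([2, 50, 11, 256])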
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch._VF as _VF
from torch.nn.modules.utils import _single, _pair
# Activation Functions
from models.activations import (
Swish
)
###############################################################################
# Layers
###############################################################################
class Linear(nn.Linear):
def __init__(self, in_features, out_features, bias = True):
super(Linear, self).__init__(
in_features=in_features,
out_features=out_features,
bias=bias)
# Variational Noise
self.noise = None
self.vn_std = None
def init_vn(self, vn_std):
# Variational Noise
self.vn_std = vn_std
def sample_synaptic_noise(self, distributed):
# Sample Noise
self.noise = torch.normal(mean=0.0, std=1.0, size=self.weight.size(), device=self.weight.device, dtype=self.weight.dtype)
# Broadcast Noise
if distributed:
torch.distributed.broadcast(self.noise, 0)
def forward(self, input):
# Weight
weight = self.weight
# Add Noise
if self.noise is not None and self.training:
weight = weight + self.vn_std * self.noise
# Apply Weight
return F.linear(input, weight, self.bias)
class Conv1d(nn.Conv1d):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride = 1,
padding = "same",
dilation = 1,
groups = 1,
bias = True
):
super(Conv1d, self).__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=0,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode="zeros")
# Assert
assert padding in ["valid", "same", "causal"]
# Padding
if padding == "valid":
self.pre_padding = None
elif padding == "same":
self.pre_padding = nn.ConstantPad1d(padding=((kernel_size - 1) // 2, (kernel_size - 1) // 2), value=0)
elif padding == "causal":
self.pre_padding = nn.ConstantPad1d(padding=(kernel_size - 1, 0), value=0)
# Variational Noise
self.noise = None
self.vn_std = None
def init_vn(self, vn_std):
# Variational Noise
self.vn_std = vn_std
def sample_synaptic_noise(self, distributed):
# Sample Noise
self.noise = torch.normal(mean=0.0, std=1.0, size=self.weight.size(), device=self.weight.device, dtype=self.weight.dtype)
# Broadcast Noise
if distributed:
torch.distributed.broadcast(self.noise, 0)
def forward(self, input):
# Weight
weight = self.weight
# Add Noise
if self.noise is not None and self.training:
weight = weight + self.vn_std * self.noise
# Padding
if self.pre_padding is not None:
input = self.pre_padding(input)
# Apply Weight
return F.conv1d(input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
class Conv2d(nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, groups = 1, bias = True, padding_mode = 'zeros'):
super(Conv2d, self).__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode)
# Variational Noise
self.noise = None
self.vn_std = None
def init_vn(self, vn_std):
# Variational Noise
self.vn_std = vn_std
def sample_synaptic_noise(self, distributed):
# Sample Noise
self.noise = torch.normal(mean=0.0, std=1.0, size=self.weight.size(), device=self.weight.device, dtype=self.weight.dtype)
# Broadcast Noise
if distributed:
torch.distributed.broadcast(self.noise, 0)
def forward(self, input):
# Weight
weight = self.weight
# Add Noise
if self.noise is not None and self.training:
weight = weight + self.vn_std * self.noise
# Apply Weight
if self.padding_mode != 'zeros':
return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), weight, self.bias, self.stride, _pair(0), self.dilation, self.groups)
return F.conv2d(input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
class LSTM(nn.LSTM):
def __init__(self, input_size, hidden_size, num_layers, batch_first, bidirectional):
super(LSTM, self).__init__(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=batch_first,
bidirectional=bidirectional)
# Variational Noise
self.noises = None
self.vn_std = None
def init_vn(self, vn_std):
# Variational Noise
self.vn_std = vn_std
def sample_synaptic_noise(self, distributed):
# Sample Noise
self.noises = []
for i in range(0, len(self._flat_weights), 4):
self.noises.append(torch.normal(mean=0.0, std=1.0, size=self._flat_weights[i].size(), device=self._flat_weights[i].device, dtype=self._flat_weights[i].dtype))
self.noises.append(torch.normal(mean=0.0, std=1.0, size=self._flat_weights[i+1].size(), device=self._flat_weights[i+1].device, dtype=self._flat_weights[i+1].dtype))
# Broadcast Noise
if distributed:
for noise in self.noises:
torch.distributed.broadcast(noise, 0)
def forward(self, input, hx=None): # noqa: F811
orig_input = input
# xxx: isinstance check needs to be in conditional for TorchScript to compile
if isinstance(orig_input, nn.utils.rnn.PackedSequence):
input, batch_sizes, sorted_indices, unsorted_indices = input
max_batch_size = batch_sizes[0]
max_batch_size = int(max_batch_size)
else:
batch_sizes = None
max_batch_size = input.size(0) if self.batch_first else input.size(1)
sorted_indices = None
unsorted_indices = None
if hx is None:
num_directions = 2 if self.bidirectional else 1
zeros = torch.zeros(self.num_layers * num_directions,
max_batch_size, self.hidden_size,
dtype=input.dtype, device=input.device)
hx = (zeros, zeros)
else:
# Each batch of the hidden state should match the input sequence that
# the user believes he/she is passing in.
hx = self.permute_hidden(hx, sorted_indices)
# Add Noise
if self.noises is not None and self.training:
weight = []
for i in range(0, len(self.noises), 2):
weight.append(self._flat_weights[2*i] + self.vn_std * self.noises[i])
weight.append(self._flat_weights[2*i+1] + self.vn_std * self.noises[i+1])
weight.append(self._flat_weights[2*i+2])
weight.append(self._flat_weights[2*i+3])
else:
weight = self._flat_weights
self.check_forward_args(input, hx, batch_sizes)
if batch_sizes is None:
result = _VF.lstm(input, hx, weight, self.bias, self.num_layers,
self.dropout, self.training, self.bidirectional, self.batch_first)
else:
result = _VF.lstm(input, batch_sizes, hx, weight, self.bias,
self.num_layers, self.dropout, self.training, self.bidirectional)
output = result[0]
hidden = result[1:]
# xxx: isinstance check needs to be in conditional for TorchScript to compile
if isinstance(orig_input, nn.utils.rnn.PackedSequence):
output_packed = nn.utils.rnn.PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:
return output, self.permute_hidden(hidden, unsorted_indices)
class Embedding(nn.Embedding):
def __init__(self, num_embeddings, embedding_dim, padding_idx = None):
super(Embedding, self).__init__(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
padding_idx=padding_idx)
# Variational Noise
self.noise = None
self.vn_std = None
def init_vn(self, vn_std):
# Variational Noise
self.vn_std = vn_std
def sample_synaptic_noise(self, distributed):
# Sample Noise
self.noise = torch.normal(mean=0.0, std=1.0, size=self.weight.size(), device=self.weight.device, dtype=self.weight.dtype)
# Broadcast Noise
if distributed:
torch.distributed.broadcast(self.noise, 0)
def forward(self, input):
# Weight
weight = self.weight
# Add Noise
if self.noise is not None and self.training:
weight = weight + self.vn_std * self.noise
# Apply Weight
return F.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
class IdentityProjection(nn.Module):
def __init__(self, input_dim, output_dim):
super(IdentityProjection, self).__init__()
assert output_dim > input_dim
self.linear = Linear(input_dim, output_dim - input_dim)
def forward(self, x):
# (B, T, Dout - Din)
proj = self.linear(x)
# (B, T, Dout)
x = torch.cat([x, proj], dim=-1)
return x
class DepthwiseSeparableConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
super(DepthwiseSeparableConv1d, self).__init__()
# Layers
self.layers = nn.Sequential(
Conv1d(in_channels, in_channels, kernel_size, padding=padding, groups=in_channels, stride=stride),
Conv1d(in_channels, out_channels, kernel_size=1),
nn.BatchNorm1d(out_channels),
Swish()
)
def forward(self, x):
return self.layers(x)
class Transpose(nn.Module):
def __init__(self, dim0, dim1):
super(Transpose, self).__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x):
return x.transpose(self.dim0, self.dim1)
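# Minimal sketch of the variational (synaptic) noise mechanism shared by the layers
# above: a noise tensor shaped like the weight is sampled by sample_synaptic_noise()
# and added, scaled by vn_std, at every training forward. The vn_std value here is
# hypothetical.
if __name__ == "__main__":
    layer = Linear(8, 4)
    layer.init_vn(vn_std=0.075)
    layer.sample_synaptic_noise(distributed=False)
    layer.train()
    _ = layer(torch.randn(3, 8))       # forward uses weight + vn_std * noise
    layer.eval()
    _ = layer(torch.randn(3, 8))       # forward uses the clean weight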
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
# Base Model
from models.model import Model
# Decoder
from models.decoders import (
RnnDecoder,
TransformerDecoder
)
# Losses
from models.losses import (
LossCE
)
class LanguageModel(Model):
def __init__(self, lm_params, tokenizer_params, training_params, decoding_params, name):
super(LanguageModel, self).__init__(tokenizer_params, training_params, decoding_params, name)
# Language Model
if lm_params["arch"] == "RNN":
self.decoder = RnnDecoder(lm_params)
elif lm_params["arch"] == "Transformer":
self.decoder = TransformerDecoder(lm_params)
else:
raise Exception("Unknown model architecture:", lm_params["arch"])
# FC Layer
self.fc = nn.Linear(lm_params["dim_model"], tokenizer_params["vocab_size"])
# Criterion
self.criterion = LossCE()
# Compile
self.compile(training_params)
def decode(self, x, hidden):
# Text Decoder (1, 1) -> (1, 1, Dlm)
logits, hidden = self.decoder(x, hidden)
# FC Layer (1, 1, Dlm) -> (1, 1, V)
logits = self.fc(logits)
return logits, hidden
def forward(self, batch):
# Unpack Batch
x, x_len, y = batch
# Add blank token
x = torch.nn.functional.pad(x, pad=(1, 0, 0, 0), value=0)
if x_len is not None:
x_len = x_len + 1
# Text Decoder (B, U + 1) -> (B, U + 1, Dlm)
logits, _ = self.decoder(x, None, x_len)
# FC Layer (B, U + 1, Dlm) -> (B, U + 1, V)
logits = self.fc(logits)
return logits
def gready_search_decoding(self, x, x_len):
return [""]
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# PyTorch
import torch
import torch.nn as nn
# RNN-T Loss
import warp_rnnt
class LossRNNT(nn.Module):
def __init__(self):
super(LossRNNT, self).__init__()
def forward(self, batch, pred):
# Unpack Batch
x, y, x_len, y_len = batch
# Unpack Predictions
outputs_pred, f_len, _ = pred
# Compute Loss
loss = warp_rnnt.rnnt_loss(
log_probs=torch.nn.functional.log_softmax(outputs_pred, dim=-1),
labels=y.int(),
frames_lengths=f_len.int(),
labels_lengths=y_len.int(),
average_frames=False,
reduction='mean',
blank=0,
gather=True)
return loss
class LossCTC(nn.Module):
def __init__(self):
super(LossCTC, self).__init__()
# CTC Loss
self.loss = nn.CTCLoss(blank=0, reduction="none", zero_infinity=False)
def forward(self, batch, pred):
# Unpack Batch
x, y, x_len, y_len = batch
# Unpack Predictions
outputs_pred, f_len, _ = pred
# Compute Loss
loss = self.loss(
log_probs=torch.nn.functional.log_softmax(outputs_pred, dim=-1).transpose(0, 1),
targets=y,
input_lengths=f_len,
target_lengths=y_len).mean()
return loss
class LossInterCTC(nn.Module):
def __init__(self, interctc_lambda):
super(LossInterCTC, self).__init__()
# CTC Loss
self.loss = nn.CTCLoss(blank=0, reduction="none", zero_infinity=False)
# InterCTC Lambda
self.interctc_lambda = interctc_lambda
def forward(self, batch, pred):
# Unpack Batch
x, y, x_len, y_len = batch
# Unpack Predictions
outputs_pred, f_len, _, interctc_probs = pred
# Compute CTC Loss
loss_ctc = self.loss(
log_probs=torch.nn.functional.log_softmax(outputs_pred, dim=-1).transpose(0, 1),
targets=y,
input_lengths=f_len,
target_lengths=y_len)
# Compute Inter Loss
loss_inter = sum(self.loss(
log_probs=interctc_prob.log().transpose(0, 1),
targets=y,
input_lengths=f_len,
target_lengths=y_len) for interctc_prob in interctc_probs) / len(interctc_probs)
# Compute total Loss
loss = (1 - self.interctc_lambda) * loss_ctc + self.interctc_lambda * loss_inter
loss = loss.mean()
return loss
class LossCE(nn.Module):
def __init__(self):
super(LossCE, self).__init__()
# CE Loss
self.loss = nn.CrossEntropyLoss(ignore_index=-1, reduction='mean')
def forward(self, batch, pred):
# Unpack Batch
x, x_len, y = batch
# Unpack Predictions
outputs_pred = pred
# Compute Loss
loss = self.loss(
input=outputs_pred.transpose(1, 2),
target=y)
return loss
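# Self-contained sketch (illustrative shapes) of the tensor layout expected by
# LossCTC above: nn.CTCLoss wants (T, B, V) log-probabilities, hence the
# transpose(0, 1) of the (B, T, V) model output.
if __name__ == "__main__":
    B, T, V = 2, 50, 32
    logits = torch.randn(B, T, V)
    y = torch.randint(1, V, (B, 10))                                       # targets, blank id 0 excluded
    f_len = torch.full((B,), T)
    y_len = torch.full((B,), 10)
    ctc = nn.CTCLoss(blank=0, reduction="none", zero_infinity=False)
    loss = ctc(logits.log_softmax(dim=-1).transpose(0, 1), y, f_len, y_len).mean()
    print(loss.item())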