Unverified Commit 18dbf036 authored by Sehoon Kim's avatar Sehoon Kim Committed by GitHub
Browse files

Squeezeformer Initial Commit



Initial Commit
Co-authored-by: default avatarAlbert Shaw <ashaw596@gmail.com>
Co-authored-by: default avatarNicholas Lee <caldragon18456@berkeley.edu>
Co-authored-by: default avatarani <aninrusimha@berkeley.edu>
Co-authored-by: default avatardragon18456 <nicholas_lee@berkeley.edu>
parent 5d6f1ae4
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from ..utils import shape_util
class SpecAugmentation(tf.keras.Model):
    """SpecAugment-style data augmentation over batched spectrogram features.

    Applies `num_time_masks` random time masks and `num_freq_masks` random
    frequency masks per call. Inputs are rank-4 tensors [B, T, F, 1]
    (inferred from the (B, T, 1, 1) / (B, 1, F, 1) mask reshapes below).
    """

    def __init__(
        self,
        num_freq_masks=2,  # number of frequency masks applied per call
        freq_mask_len=27,  # max width of each frequency mask, in feature bins
        num_time_masks=5,  # number of time masks applied per call
        time_mask_prop=0.05,  # max time-mask width as a fraction of utterance length
        name='specaug',
        **kwargs,
    ):
        super(SpecAugmentation, self).__init__(name=name, **kwargs)
        self.num_freq_masks = num_freq_masks
        self.freq_mask_len = freq_mask_len
        self.num_time_masks = num_time_masks
        self.time_mask_prop = time_mask_prop

    def time_mask(self, inputs, inputs_len):
        """Zero out one random time span per example.

        Args:
            inputs: features [B, T, F, 1].
            inputs_len: per-example valid lengths, shape [B].
        """
        time_max = inputs_len
        B, T, F = tf.shape(inputs)[0], tf.shape(inputs)[1], tf.shape(inputs)[2]
        # Mask width t ~ U(0, time_mask_prop) * valid_length, per example.
        t = tf.random.uniform(shape=tf.shape(time_max), minval=0, maxval=self.time_mask_prop)
        t = tf.cast(tf.cast(time_max, tf.dtypes.float32) * t, 'int32')
        # Mask start t0 ~ U(0, valid_length - t), so the span stays inside the utterance.
        t0 = tf.random.uniform(shape=tf.shape(time_max), minval=0, maxval=1)
        t0 = tf.cast(tf.cast(time_max - t, tf.dtypes.float32) * t0, 'int32')
        # Broadcast widths/starts across the T axis and compare with frame indices.
        t = tf.repeat(tf.reshape(t, (-1, 1)), T, axis=1)
        t0 = tf.repeat(tf.reshape(t0, (-1, 1)), T, axis=1)
        indices = tf.repeat(tf.reshape(tf.range(T), (1, -1)), B, axis=0)
        left_mask = tf.cast(tf.math.greater_equal(indices, t0), 'float32')
        right_mask = tf.cast(tf.math.less(indices, t0 + t), 'float32')
        # mask is 1 outside [t0, t0 + t) and 0 inside; multiplying zeroes the span.
        mask = 1.0 - left_mask * right_mask
        masked_inputs = inputs * tf.reshape(mask, (B, T, 1, 1))
        return masked_inputs

    def frequency_mask(self, inputs, inputs_len):
        """Zero out one random frequency band per example, across all frames."""
        B, T, F = tf.shape(inputs)[0], tf.shape(inputs)[1], tf.shape(inputs)[2]
        # Band width f ~ U(0, freq_mask_len) bins; start f0 ~ U(0, F - f).
        f = tf.random.uniform(shape=tf.shape(inputs_len), minval=0, maxval=self.freq_mask_len, dtype='int32')
        f0 = tf.random.uniform(shape=tf.shape(inputs_len), minval=0, maxval=1)
        f0 = tf.cast(tf.cast(F - f, tf.dtypes.float32) * f0, 'int32')
        f = tf.repeat(tf.reshape(f, (-1, 1)), F, axis=1)
        f0 = tf.repeat(tf.reshape(f0, (-1, 1)), F, axis=1)
        indices = tf.repeat(tf.reshape(tf.range(F), (1, -1)), B, axis=0)
        left_mask = tf.cast(tf.math.greater_equal(indices, f0), 'float32')
        right_mask = tf.cast(tf.math.less(indices, f0 + f), 'float32')
        # 1 outside the band [f0, f0 + f), 0 inside.
        mask = 1.0 - left_mask * right_mask
        masked_inputs = inputs * tf.reshape(mask, (B, 1, F, 1))
        return masked_inputs

    @tf.function
    def call(self, inputs, inputs_len):
        """Apply all time masks, then all frequency masks, sequentially."""
        masked_inputs = inputs
        for _ in range(self.num_time_masks):
            masked_inputs = self.time_mask(masked_inputs, inputs_len)
        for _ in range(self.num_freq_masks):
            masked_inputs = self.frequency_mask(masked_inputs, inputs_len)
        return masked_inputs
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from ..utils import file_util
class DatasetConfig:
    """Typed view over one dataset section of the user config.

    Known keys are popped out of ``config``; whatever remains is attached
    to the instance verbatim.
    """

    def __init__(self, config: dict = None):
        config = {} if not config else config
        self.stage = config.pop("stage", None)
        self.data_paths = file_util.preprocess_paths(config.pop("data_paths", None))
        self.tfrecords_dir = file_util.preprocess_paths(config.pop("tfrecords_dir", None), isdir=True)
        self.tfrecords_shards = config.pop("tfrecords_shards", 16)
        self.shuffle = config.pop("shuffle", False)
        self.cache = config.pop("cache", False)
        self.drop_remainder = config.pop("drop_remainder", True)
        self.buffer_size = config.pop("buffer_size", 10000)
        # Pass any unrecognized keys through unchanged.
        for key, value in config.items():
            setattr(self, key, value)
class RunningConfig:
    """Training-loop settings: batch size, gradient accumulation, epochs."""

    def __init__(self, config: dict = None):
        config = {} if not config else config
        self.batch_size = config.pop("batch_size", 1)
        self.accumulation_steps = config.pop("accumulation_steps", 1)
        self.num_epochs = config.pop("num_epochs", 20)
        # Attach any remaining keys as plain attributes.
        for key, value in config.items():
            setattr(self, key, value)
class LearningConfig:
    """Groups the per-stage dataset configs, optimizer config, and running config."""

    def __init__(self, config: dict = None):
        config = {} if not config else config
        self.train_dataset_config = DatasetConfig(config.pop("train_dataset_config", {}))
        self.eval_dataset_config = DatasetConfig(config.pop("eval_dataset_config", {}))
        self.test_dataset_config = DatasetConfig(config.pop("test_dataset_config", {}))
        self.optimizer_config = config.pop("optimizer_config", {})
        self.running_config = RunningConfig(config.pop("running_config", {}))
        # Keep any extra keys as plain attributes.
        for key, value in config.items():
            setattr(self, key, value)
class Config:
    """User-facing configuration for training, testing, or inference.

    Accepts either an already-parsed dict or a path to a YAML file.
    """

    def __init__(self, data: Union[str, dict]):
        if isinstance(data, dict):
            config = data
        else:
            config = file_util.load_yaml(file_util.preprocess_paths(data))
        self.speech_config = config.pop("speech_config", {})
        self.decoder_config = config.pop("decoder_config", {})
        self.model_config = config.pop("model_config", {})
        self.learning_config = LearningConfig(config.pop("learning_config", {}))
        # Attach any leftover top-level keys verbatim.
        for key, value in config.items():
            setattr(self, key, value)
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import json
import abc
from typing import Union
import tqdm
import numpy as np
import tensorflow as tf
from ..featurizers.speech_featurizers import (
load_and_convert_to_wav,
read_raw_audio,
tf_read_raw_audio,
TFSpeechFeaturizer
)
from ..featurizers.text_featurizers import TextFeaturizer
from ..utils import feature_util, file_util, math_util, data_util
logger = tf.get_logger()
BUFFER_SIZE = 10000
AUTOTUNE = tf.data.experimental.AUTOTUNE
class BaseDataset(metaclass=abc.ABCMeta):
    """Abstract base for all dataset pipelines in this package."""

    def __init__(
        self,
        data_paths: list,
        cache: bool = False,
        shuffle: bool = False,
        buffer_size: int = BUFFER_SIZE,
        indefinite: bool = False,
        drop_remainder: bool = True,
        stage: str = "train",
        **kwargs,
    ):
        self.data_paths = data_paths or []
        if not isinstance(self.data_paths, list):
            raise ValueError('data_paths must be a list of string paths')
        if shuffle and buffer_size <= 0:
            raise ValueError("buffer_size must be positive when shuffle is on")
        self.cache = cache  # cache the transformed dataset in memory
        self.shuffle = shuffle  # shuffle the tf.data.Dataset
        self.buffer_size = buffer_size  # shuffle buffer size
        self.stage = stage  # used to name tfrecords files
        self.drop_remainder = drop_remainder  # drop the last partial batch (multi-GPU)
        self.indefinite = indefinite  # repeat forever to avoid a final partial batch
        self.total_steps = None  # filled in once entries are counted

    @abc.abstractmethod
    def parse(self, *args, **kwargs):
        raise NotImplementedError()

    @abc.abstractmethod
    def create(self, batch_size):
        raise NotImplementedError()
class ASRDataset(BaseDataset):
    """ Dataset for ASR using Generator.

    Reads tab-separated transcript files, tokenizes transcripts once up
    front, and builds a padded-batch tf.data pipeline of
    (inputs dict, labels dict) pairs.
    """

    def __init__(
        self,
        stage: str,
        speech_featurizer: TFSpeechFeaturizer,
        text_featurizer: TextFeaturizer,
        data_paths: list,
        cache: bool = False,
        shuffle: bool = False,
        indefinite: bool = False,
        drop_remainder: bool = True,
        buffer_size: int = BUFFER_SIZE,
        input_padding_length: int = 3300,
        label_padding_length: int = 530,
        **kwargs,
    ):
        super().__init__(
            data_paths=data_paths,
            cache=cache, shuffle=shuffle, stage=stage, buffer_size=buffer_size,
            drop_remainder=drop_remainder, indefinite=indefinite
        )
        self.speech_featurizer = speech_featurizer
        self.text_featurizer = text_featurizer
        # Fixed padded lengths used by padded_batch in process(); presumably
        # sized to exceed the longest utterance / transcript — TODO confirm.
        self.input_padding_length = input_padding_length
        self.label_padding_length = label_padding_length

    # -------------------------------- ENTRIES -------------------------------------

    def read_entries(self):
        """Read every transcript file into self.entries (idempotent).

        Each entry becomes [path, <middle field>, token_indices_string] after
        splitting on at most two tabs; the transcript text is replaced by the
        space-joined token indices from the text featurizer.
        """
        if hasattr(self, "entries") and len(self.entries) > 0: return  # already loaded
        self.entries = []
        for file_path in self.data_paths:
            logger.info(f"Reading {file_path} ...")
            with tf.io.gfile.GFile(file_path, "r") as f:
                temp_lines = f.read().splitlines()
                # Skip the header of tsv file
                self.entries += temp_lines[1:]
        # The files are tab-separated; split at most twice so the transcript keeps any extra tabs.
        self.entries = [line.split("\t", 2) for line in self.entries]
        for i, line in enumerate(self.entries):
            # Tokenize once here so the tf.data pipeline only parses numbers.
            self.entries[i][-1] = " ".join([str(x) for x in self.text_featurizer.extract(line[-1]).numpy()])
        self.entries = np.array(self.entries)
        if self.shuffle: np.random.shuffle(self.entries)  # Mix transcripts.tsv
        self.total_steps = len(self.entries)

    # -------------------------------- LOAD AND PREPROCESS -------------------------------------

    def generator(self):
        """Yield (path, wav_bytes, indices) as bytes for Dataset.from_generator."""
        for path, _, indices in self.entries:
            audio = load_and_convert_to_wav(path).numpy()
            yield bytes(path, "utf-8"), audio, bytes(indices, "utf-8")

    def tf_preprocess(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor):
        """Decode audio, extract features (on CPU), and parse label indices.

        Returns:
            (path, features, input_length, label, label_length,
             prediction, prediction_length)
        """
        with tf.device("/CPU:0"):  # keep feature extraction off accelerators
            signal = tf_read_raw_audio(audio, self.speech_featurizer.sample_rate)
            features = self.speech_featurizer.tf_extract(signal)
            input_length = tf.cast(tf.shape(features)[0], tf.int32)
            label = tf.strings.to_number(tf.strings.split(indices), out_type=tf.int32)
            label_length = tf.cast(tf.shape(label)[0], tf.int32)
            # Transducer prediction-network input: labels with a leading blank.
            prediction = self.text_featurizer.prepand_blank(label)
            prediction_length = tf.cast(tf.shape(prediction)[0], tf.int32)
            return path, features, input_length, label, label_length, prediction, prediction_length

    def parse(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor):
        """
        Returns:
            path, features, input_lengths, labels, label_lengths, pred_inp
        """
        data = self.tf_preprocess(path, audio, indices)
        _, features, input_length, label, label_length, prediction, prediction_length = data
        return (
            data_util.create_inputs(
                inputs=features,
                inputs_length=input_length,
                predictions=prediction,
                predictions_length=prediction_length
            ),
            data_util.create_labels(
                labels=label,
                labels_length=label_length
            )
        )

    def process(self, dataset, batch_size):
        """Map parse(), then cache/shuffle/repeat/pad-batch/prefetch the pipeline."""
        dataset = dataset.map(self.parse, num_parallel_calls=AUTOTUNE)
        self.total_steps = math_util.get_num_batches(self.total_steps, batch_size, drop_remainders=self.drop_remainder)
        if self.cache:
            dataset = dataset.cache()
        if self.shuffle:
            dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=True)
        if self.indefinite and self.total_steps:
            dataset = dataset.repeat()
        # Pad every example to the fixed lengths so all batches have a static shape.
        dataset = dataset.padded_batch(
            batch_size=batch_size,
            padded_shapes=(
                data_util.create_inputs(
                    inputs=tf.TensorShape([self.input_padding_length, 80, 1]),
                    inputs_length=tf.TensorShape([]),
                    predictions=tf.TensorShape([self.label_padding_length]),
                    predictions_length=tf.TensorShape([])
                ),
                data_util.create_labels(
                    labels=tf.TensorShape([self.label_padding_length]),
                    labels_length=tf.TensorShape([])
                ),
            ),
            padding_values=(
                data_util.create_inputs(
                    inputs=0.0,
                    inputs_length=0,
                    predictions=self.text_featurizer.blank,
                    predictions_length=0
                ),
                data_util.create_labels(
                    labels=self.text_featurizer.blank,
                    labels_length=0
                )
            ),
            drop_remainder=self.drop_remainder
        )
        # PREFETCH to improve speed of input length
        dataset = dataset.prefetch(AUTOTUNE)
        return dataset

    def create(self, batch_size: int):
        """Build the full pipeline; prints and returns None when no entries were read."""
        self.read_entries()
        if not self.total_steps or self.total_steps == 0: return print("Couldn't create")
        dataset = tf.data.Dataset.from_generator(
            self.generator,
            output_types=(tf.string, tf.string, tf.string),
            output_shapes=(tf.TensorShape([]), tf.TensorShape([]), tf.TensorShape([]))
        )
        return self.process(dataset, batch_size)
class ASRSliceDataset(ASRDataset):
    """ Dataset for ASR using Slice (from_tensor_slices over the entries array). """

    @staticmethod
    def load(record: tf.Tensor):
        """Map an entry row [path, <middle field>, indices] to (path, wav_bytes, indices).

        Decoding goes through tf.numpy_function because the WAV conversion
        (librosa-based) is not graph-compatible.
        """
        def fn(path: bytes): return load_and_convert_to_wav(path.decode("utf-8")).numpy()
        audio = tf.numpy_function(fn, inp=[record[0]], Tout=tf.string)
        return record[0], audio, record[2]

    def create(self, batch_size: int):
        """Build the pipeline from entry slices; returns None if there are no entries."""
        self.read_entries()
        if not self.total_steps or self.total_steps == 0: return None
        dataset = tf.data.Dataset.from_tensor_slices(self.entries)
        dataset = dataset.map(self.load, num_parallel_calls=AUTOTUNE)
        return self.process(dataset, batch_size)

    def preprocess_dataset(self, tfrecord_path, shard_size=0, max_len=None):
        """Write the loaded dataset out as preprocessed TFRecords.

        NOTE(review): create_preprocessed_tfrecord is not defined in this file
        or in the visible base classes — confirm it exists elsewhere.
        """
        self.read_entries()
        if not self.total_steps or self.total_steps == 0: return None
        logger.info(f"Preprocess dataset")
        dataset = tf.data.Dataset.from_tensor_slices(self.entries)
        dataset = dataset.map(self.load, num_parallel_calls=AUTOTUNE)
        self.create_preprocessed_tfrecord(dataset, tfrecord_path, shard_size, max_len)
# Copyright 2020 Huy Le Nguyen (@usimarit) and Huy Phan (@pquochuy)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import io
import abc
import math
from typing import Union
import numpy as np
import librosa
import soundfile as sf
import tensorflow as tf
import tensorflow_io as tfio
from ..utils import math_util, env_util
def load_and_convert_to_wav(path: str) -> tf.Tensor:
    """Read any librosa-supported audio file and re-encode it as WAV bytes.

    The native sample rate is preserved (sr=None) and channels are mixed to mono.
    """
    samples, sample_rate = librosa.load(os.path.expanduser(path), sr=None, mono=True)
    mono = tf.expand_dims(samples, axis=-1)
    return tf.audio.encode_wav(mono, sample_rate=sample_rate)
def read_raw_audio(audio: Union[str, bytes, np.ndarray], sample_rate=16000) -> np.ndarray:
    """Decode audio into a mono float waveform at `sample_rate`.

    Args:
        audio: a file path, raw encoded bytes (decoded via soundfile), or an
            already-decoded 1-D numpy array (returned unchanged).
        sample_rate: target sampling rate; bytes input is resampled if needed.

    Returns:
        1-D np.ndarray waveform.

    Raises:
        ValueError: if a numpy input has more than one channel, or if the
            input type is unsupported.
    """
    if isinstance(audio, str):
        wave, _ = librosa.load(os.path.expanduser(audio), sr=sample_rate, mono=True)
    elif isinstance(audio, bytes):
        wave, sr = sf.read(io.BytesIO(audio))
        # Mix down multi-channel audio, then resample to the target rate.
        if wave.ndim > 1: wave = np.mean(wave, axis=-1)
        wave = np.asfortranarray(wave)
        if sr != sample_rate: wave = librosa.resample(wave, sr, sample_rate)
    elif isinstance(audio, np.ndarray):
        # BUG FIX: the ValueError was constructed but never raised, silently
        # accepting multi-channel arrays.
        if audio.ndim > 1:
            raise ValueError("input audio must be single channel")
        return audio
    else:
        raise ValueError("input audio must be either a path or bytes")
    return wave
def tf_read_raw_audio(audio: tf.Tensor, sample_rate=16000) -> tf.Tensor:
    """Decode WAV bytes into a flat float waveform, resampling when off-TPU."""
    wave, rate = tf.audio.decode_wav(audio, desired_channels=1, desired_samples=-1)
    if env_util.has_devices("TPU"):
        # tensorflow_io resampling is skipped on TPU.
        return tf.reshape(wave, shape=[-1])  # reshape for using tf.signal
    resampled = tfio.audio.resample(wave, rate_in=tf.cast(rate, dtype=tf.int64), rate_out=sample_rate)
    return tf.reshape(resampled, shape=[-1])  # reshape for using tf.signal
def slice_signal(signal, window_size, stride=0.5) -> np.ndarray:
    """Cut a 1-D signal into windows, hopping by `stride` fractions of the
    window size and zero-padding a short final window.

    Returns a float32 array of shape [num_windows, window_size].
    """
    assert signal.ndim == 1, signal.ndim
    n_samples = signal.shape[0]
    offset = int(window_size * stride)
    starts = range(0, n_samples, offset)
    ends = range(window_size, n_samples + offset, offset)
    windows = []
    for start, end in zip(starts, ends):
        window = signal[start:end]
        short_by = window_size - window.shape[0]
        if short_by > 0:
            # Zero-pad a trailing window that ran off the end of the signal.
            window = np.pad(window, (0, short_by), 'constant', constant_values=0.0)
        if window.shape[0] == window_size:
            windows.append(window)
    return np.array(windows, dtype=np.float32)
def tf_merge_slices(slices: tf.Tensor) -> tf.Tensor:
    """Concatenate [batch, window_size] windows back into one flat signal."""
    # slices shape = [batch, window_size]; K.flatten is a reshape to rank 1.
    return tf.keras.backend.flatten(slices)  # return shape = [-1, ]
def merge_slices(slices: np.ndarray) -> np.ndarray:
    """Flatten [batch, window_size] windows into a single 1-D signal."""
    return slices.reshape(-1)
def tf_normalize_audio_features(audio_feature: tf.Tensor, per_frame=False) -> tf.Tensor:
    """
    TF mean/variance normalization of audio features.

    Args:
        audio_feature: tf.Tensor with shape [T, F]
        per_frame: normalize each frame independently (axis=1) instead of globally

    Returns:
        normalized audio features with shape [T, F]
    """
    reduce_axis = 1 if per_frame else None
    center = tf.reduce_mean(audio_feature, axis=reduce_axis, keepdims=True)
    # Epsilon keeps the division stable for near-constant features.
    scale = tf.math.sqrt(tf.math.reduce_variance(audio_feature, axis=reduce_axis, keepdims=True) + 1e-9)
    return (audio_feature - center) / scale
def tf_normalize_signal(signal: tf.Tensor) -> tf.Tensor:
    """
    TF normalization of a waveform to the [-1, 1] range.

    Args:
        signal: tf.Tensor with shape [None]

    Returns:
        normalized signal with shape [None]
    """
    # Epsilon avoids division by zero on an all-zero signal.
    peak = tf.reduce_max(tf.abs(signal), axis=-1) + 1e-9
    gain = 1.0 / peak
    return signal * gain
def tf_preemphasis(signal: tf.Tensor, coeff=0.97):
    """
    First-order pre-emphasis filter: y[0] = x[0], y[n] = x[n] - coeff * x[n-1].

    Args:
        signal: tf.Tensor with shape [None]
        coeff: preemphasis coefficient; falsy or non-positive disables the filter

    Returns:
        pre-emphasized signal with shape [None]
    """
    if not coeff or coeff <= 0.0:
        return signal
    head = tf.expand_dims(signal[0], axis=-1)
    tail = signal[1:] - coeff * signal[:-1]
    return tf.concat([head, tail], axis=-1)
def tf_depreemphasis(signal: tf.Tensor, coeff=0.97) -> tf.Tensor:
    """
    TF Depreemphasis (inverse of tf_preemphasis): x[n] = coeff * x[n-1] + y[n]
    Args:
        signal: tf.Tensor with shape [B, None]
        coeff: Float that indicates the preemphasis coefficient
    Returns:
        depre-emphasized signal with shape [B, None]
    """
    if not coeff or coeff <= 0.0: return signal

    def map_fn(elem):
        # Sequential scan: each output sample depends on the previous OUTPUT,
        # so the result is rebuilt by concatenation — O(T) graph ops per row.
        x = tf.expand_dims(elem[0], axis=-1)
        for n in range(1, elem.shape[0], 1):
            current = coeff * x[n - 1] + elem[n]
            x = tf.concat([x, [current]], axis=0)
        return x
    # Applied independently to every row of the batch.
    return tf.map_fn(map_fn, signal)
class TFSpeechFeaturizer(metaclass=abc.ABCMeta):
    """Log-mel spectrogram feature extractor implemented entirely with TF ops."""

    def __init__(self, speech_config: dict):
        """
        speech_config = {
            "sample_rate": int,
            "frame_ms": int,
            "stride_ms": int,
            "num_feature_bins": int,
            "feature_type": str,
            "delta": bool,
            "delta_delta": bool,
            "pitch": bool,
            "normalize_signal": bool,
            "normalize_feature": bool,
            "normalize_per_frame": bool
        }
        """
        # Samples
        self.sample_rate = speech_config.get("sample_rate", 16000)
        # Window/hop sizes converted from milliseconds to samples.
        self.frame_length = int(self.sample_rate * (speech_config.get("frame_ms", 25) / 1000))
        self.frame_step = int(self.sample_rate * (speech_config.get("stride_ms", 10) / 1000))
        # Features
        self.num_feature_bins = speech_config.get("num_feature_bins", 80)
        self.feature_type = speech_config.get("feature_type", "log_mel_spectrogram")
        self.preemphasis = speech_config.get("preemphasis", None)
        self.top_db = speech_config.get("top_db", 80.0)  # dynamic-range floor in dB
        # Normalization
        self.normalize_signal = speech_config.get("normalize_signal", True)
        self.normalize_feature = speech_config.get("normalize_feature", True)
        self.normalize_per_frame = speech_config.get("normalize_per_frame", False)
        self.center = speech_config.get("center", True)  # reflect-pad so frames are centered
        # Length
        self.max_length = 0  # longest feature length recorded via update_length()

    @property
    def shape(self) -> list:
        """Feature shape [T or None, F, 1] for building models."""
        length = self.max_length if self.max_length > 0 else None
        return [length, self.num_feature_bins, 1]

    @property
    def nfft(self) -> int:
        """ Number of FFT points: next power of two >= frame_length. """
        return 2 ** (self.frame_length - 1).bit_length()

    def get_length_from_duration(self, duration):
        """Predict the number of output frames for an audio duration in seconds."""
        nsamples = math.ceil(float(duration) * self.sample_rate)
        if self.center: nsamples += self.nfft  # account for the reflect padding in stft()
        return 1 + (nsamples - self.nfft) // self.frame_step  # https://www.tensorflow.org/api_docs/python/tf/signal/frame

    def update_length(self, length: int):
        self.max_length = max(self.max_length, length)

    def reset_length(self):
        self.max_length = 0

    def stft(self, signal):
        """Power spectrogram |STFT|^2 using a Hann window zero-padded to nfft."""
        if self.center: signal = tf.pad(signal, [[self.nfft // 2, self.nfft // 2]], mode="REFLECT")
        window = tf.signal.hann_window(self.frame_length, periodic=True)
        # Center the frame_length-sized window inside the nfft-long frame.
        left_pad = (self.nfft - self.frame_length) // 2
        right_pad = self.nfft - self.frame_length - left_pad
        window = tf.pad(window, [[left_pad, right_pad]])
        framed_signals = tf.signal.frame(signal, frame_length=self.nfft, frame_step=self.frame_step)
        framed_signals *= window
        return tf.square(tf.abs(tf.signal.rfft(framed_signals, [self.nfft])))

    def power_to_db(self, S, amin=1e-10):
        """Convert a power spectrogram to dB, clamped to top_db below the peak."""
        log_spec = 10.0 * math_util.log10(tf.maximum(amin, S))
        # Reference level ref=1.0, so this term is 0; kept for symmetry with librosa.
        log_spec -= 10.0 * math_util.log10(tf.maximum(amin, 1.0))
        if self.top_db is not None:
            if self.top_db < 0:
                raise ValueError('top_db must be non-negative')
            log_spec = tf.maximum(log_spec, tf.reduce_max(log_spec) - self.top_db)
        return log_spec

    def extract(self, signal: np.ndarray) -> np.ndarray:
        """Numpy convenience wrapper around tf_extract."""
        signal = np.asfortranarray(signal)
        features = self.tf_extract(tf.convert_to_tensor(signal, dtype=tf.float32))
        return features.numpy()

    def tf_extract(self, signal: tf.Tensor) -> tf.Tensor:
        """
        Extract speech features from signals (for using in tflite)
        Args:
            signal: tf.Tensor with shape [None]
        Returns:
            features: tf.Tensor with shape [T, F, 1]
        """
        if self.normalize_signal:
            signal = tf_normalize_signal(signal)
        signal = tf_preemphasis(signal, self.preemphasis)
        if self.feature_type == "log_mel_spectrogram":
            features = self.compute_log_mel_spectrogram(signal)
        else:
            raise ValueError("feature_type must be 'log_mel_spectrogram'")
        features = tf.expand_dims(features, axis=-1)
        if self.normalize_feature:
            features = tf_normalize_audio_features(features, per_frame=self.normalize_per_frame)
        return features

    def compute_log_mel_spectrogram(self, signal):
        """Project the power spectrogram onto the mel scale and convert to dB."""
        spectrogram = self.stft(signal)
        linear_to_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins=self.num_feature_bins,
            num_spectrogram_bins=spectrogram.shape[-1],
            sample_rate=self.sample_rate,
            lower_edge_hertz=0.0, upper_edge_hertz=(self.sample_rate / 2)
        )
        mel_spectrogram = tf.tensordot(spectrogram, linear_to_weight_matrix, 1)
        return self.power_to_db(mel_spectrogram)
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import abc
import codecs
import unicodedata
from multiprocessing import cpu_count
import sentencepiece as sp
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tds
from ..utils import file_util
ENGLISH_CHARACTERS = [" ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m",
"n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"]
class TextFeaturizer(metaclass=abc.ABCMeta):
    """Abstract base for featurizers that map text <-> integer token indices."""

    def __init__(self):
        self.scorer = None  # optional external scorer attached via add_scorer()
        self.blank = None  # blank token index, set by subclasses
        self.tokens2indices = {}
        self.tokens = []
        self.num_classes = None  # vocabulary size, set by subclasses
        self.max_length = 0  # longest label length recorded via update_length()

    @property
    def shape(self) -> list:
        """Label shape [max_length or None] for model building."""
        return [self.max_length if self.max_length > 0 else None]

    @property
    def prepand_shape(self) -> list:
        """Shape after a blank is prepended (see prepand_blank)."""
        return [self.max_length + 1 if self.max_length > 0 else None]

    def update_length(self, length: int):
        self.max_length = max(self.max_length, length)

    def reset_length(self):
        self.max_length = 0

    def preprocess_text(self, text):
        """Lowercase, NFC-normalize, and strip trailing newlines."""
        text = unicodedata.normalize("NFC", text.lower())
        return text.strip("\n")  # remove trailing newline

    def add_scorer(self, scorer: any = None):
        """ Add scorer to this instance """
        self.scorer = scorer

    def normalize_indices(self, indices: tf.Tensor) -> tf.Tensor:
        """
        Remove -1 in indices by replacing them with blanks
        Args:
            indices (tf.Tensor): shape any
        Returns:
            tf.Tensor: normalized indices with shape same as indices
        """
        with tf.name_scope("normalize_indices"):
            minus_one = -1 * tf.ones_like(indices, dtype=tf.int32)
            blank_like = self.blank * tf.ones_like(indices, dtype=tf.int32)
            return tf.where(indices == minus_one, blank_like, indices)

    def prepand_blank(self, text: tf.Tensor) -> tf.Tensor:
        """ Prepand blank index for transducer models """
        return tf.concat([[self.blank], text], axis=0)

    # FIX: abc.abstractclassmethod is deprecated (since Python 3.3) and wrongly
    # declared these as classmethods even though subclasses implement them as
    # instance methods; plain abstractmethod is the correct marker.
    @abc.abstractmethod
    def extract(self, text):
        raise NotImplementedError()

    @abc.abstractmethod
    def iextract(self, indices):
        raise NotImplementedError()

    @abc.abstractmethod
    def indices2upoints(self, indices):
        raise NotImplementedError()
class SentencePieceFeaturizer(TextFeaturizer):
    """
    Extract text feature based on sentence piece package.
    """
    # Special-token ids follow the SentencePiece training defaults.
    UNK_TOKEN, UNK_TOKEN_ID = "<unk>", 1
    BOS_TOKEN, BOS_TOKEN_ID = "<s>", 2
    EOS_TOKEN, EOS_TOKEN_ID = "</s>", 3
    PAD_TOKEN, PAD_TOKEN_ID = "<pad>", 0  # unused, by default

    def __init__(self, decoder_config: dict, model=None):
        super(SentencePieceFeaturizer, self).__init__()
        self.vocabulary = decoder_config['vocabulary']
        # Load a processor from disk unless one is injected (see load_from_file).
        self.model = self.__load_model() if model is None else model
        self.blank = 0  # treats blank as 0 (pad)
        # vocab size
        self.num_classes = self.model.get_piece_size()
        self.__init_vocabulary()

    def __load_model(self):
        """Load the .model file sharing the vocabulary file's path prefix."""
        filename_prefix = os.path.splitext(self.vocabulary)[0]
        processor = sp.SentencePieceProcessor()
        processor.load(filename_prefix + ".model")
        return processor

    def __init_vocabulary(self):
        """Build the token list (id 0 -> "") and the unicode code-point table."""
        self.tokens = []
        for idx in range(1, self.num_classes):
            self.tokens.append(self.model.decode_ids([idx]))
        self.non_blank_tokens = self.tokens.copy()
        self.tokens.insert(0, "")  # blank/pad decodes to the empty string
        self.upoints = tf.strings.unicode_decode(self.tokens, "UTF-8")
        self.upoints = self.upoints.to_tensor()  # [num_classes, max_subword_length]

    @classmethod
    def load_from_file(cls, decoder_config: dict, filename: str = None):
        """Alternate constructor: load the processor from `filename`, or from
        decoder_config["output_path_prefix"] when no filename is given."""
        if filename is not None:
            filename_prefix = os.path.splitext(file_util.preprocess_paths(filename))[0]
        else:
            filename_prefix = decoder_config.get("output_path_prefix", None)
        processor = sp.SentencePieceProcessor()
        processor.load(filename_prefix + ".model")
        return cls(decoder_config, processor)

    def extract(self, text: str) -> tf.Tensor:
        """
        Convert string to a list of integers
        # encode: text => id
        sp.encode_as_pieces('This is a test') --> ['▁This', '▁is', '▁a', '▁t', 'est']
        sp.encode_as_ids('This is a test') --> [209, 31, 9, 375, 586]
        Args:
            text: string (sequence of characters)
        Returns:
            sequence of ints in tf.Tensor
        """
        text = self.preprocess_text(text)
        text = text.strip()  # remove trailing space
        indices = self.model.encode_as_ids(text)
        return tf.convert_to_tensor(indices, dtype=tf.int32)

    def iextract(self, indices: tf.Tensor) -> tf.Tensor:
        """
        Convert list of indices to string
        # decode: id => text
        sp.decode_pieces(['▁This', '▁is', '▁a', '▁t', 'est']) --> This is a test
        sp.decode_ids([209, 31, 9, 375, 586]) --> This is a test
        Args:
            indices: tf.Tensor with dim [B, None]
        Returns:
            transcripts: tf.Tensor of dtype tf.string with dim [B]
        """
        indices = self.normalize_indices(indices)
        with tf.device("/CPU:0"):  # string data is not supported on GPU

            def decode(x):
                # Drop a leading blank before decoding to text.
                if x[0] == self.blank: x = x[1:]
                return self.model.decode_ids(x.tolist())

            text = tf.map_fn(
                lambda x: tf.numpy_function(decode, inp=[x], Tout=tf.string),
                indices,
                fn_output_signature=tf.TensorSpec([], dtype=tf.string)
            )
        return text

    @tf.function(
        input_signature=[
            tf.TensorSpec([None], dtype=tf.int32)
        ]
    )
    def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor:
        """
        Transform Predicted Indices to Unicode Code Points (for using tflite)
        Args:
            indices: tf.Tensor of Classes in shape [None]
        Returns:
            unicode code points transcript with dtype tf.int32 and shape [None]
        """
        with tf.name_scope("indices2upoints"):
            indices = self.normalize_indices(indices)
            upoints = tf.gather_nd(self.upoints, tf.expand_dims(indices, axis=-1))
            # Zeros are padding in the upoints table; filter them out.
            return tf.gather_nd(upoints, tf.where(tf.not_equal(upoints, 0)))
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
class CtcLoss(tf.keras.losses.Loss):
    """Keras loss wrapping tf.nn.ctc_loss, averaged across the global batch."""

    def __init__(self, blank=0, name=None):
        # NONE reduction: compute_average_loss handles distribution-aware averaging.
        super(CtcLoss, self).__init__(reduction=tf.keras.losses.Reduction.NONE, name=name)
        self.blank = blank

    def call(self, y_true, y_pred):
        per_example_loss = ctc_loss(
            y_true=y_true["labels"],
            y_pred=y_pred["logits"],
            input_length=y_pred["logits_length"],
            label_length=y_true["labels_length"],
            blank=self.blank,
            name=self.name,
        )
        return tf.nn.compute_average_loss(per_example_loss)
@tf.function
def ctc_loss(y_true, y_pred, input_length, label_length, blank, name=None):
    """Per-example CTC loss over batch-major logits [B, T, V]."""
    labels = tf.cast(y_true, tf.int32)
    logits = tf.cast(y_pred, tf.float32)
    return tf.nn.ctc_loss(
        labels=labels,
        logits=logits,
        label_length=tf.cast(label_length, tf.int32),
        logit_length=tf.cast(input_length, tf.int32),
        logits_time_major=False,
        blank_index=blank,
        name=name,
    )
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow.keras import mixed_precision as mxp
from ..utils import file_util, env_util
class BaseModel(tf.keras.Model):
    def __init__(self, *args, **kwargs):
        """Set up the metric registry and loss-scale flag on top of tf.keras.Model."""
        super().__init__(*args, **kwargs)
        self._metrics = {}  # name -> tf.keras.metrics.Metric
        self.use_loss_scale = False  # set True in compile() when not on TPU
    def save(
        self,
        filepath,
        overwrite=True,
        include_optimizer=True,
        save_format=None,
        signatures=None,
        options=None,
        save_traces=True,
    ):
        """Save the model, resolving `filepath` through file_util.save_file
        (presumably staging remote paths to a local file — see file_util)."""
        with file_util.save_file(filepath) as path:
            super().save(
                filepath=path,
                overwrite=overwrite,
                include_optimizer=include_optimizer,
                save_format=save_format,
                signatures=signatures,
                options=options,
                save_traces=save_traces,
            )
    def save_weights(
        self,
        filepath,
        overwrite=True,
        save_format=None,
        options=None,
    ):
        """Save weights only, resolving `filepath` through file_util.save_file."""
        with file_util.save_file(filepath) as path:
            super().save_weights(
                filepath=path,
                overwrite=overwrite,
                save_format=save_format,
                options=options,
            )
    def load_weights(
        self,
        filepath,
        by_name=False,
        skip_mismatch=False,
        options=None,
    ):
        """Load weights, resolving `filepath` through file_util.read_file."""
        with file_util.read_file(filepath) as path:
            super().load_weights(
                filepath=path,
                by_name=by_name,
                skip_mismatch=skip_mismatch,
                options=options,
            )
    @property
    def metrics(self):
        """Metrics tracked by this model (a live dict_values view, not a list)."""
        return self._metrics.values()
    def add_metric(self, metric: tf.keras.metrics.Metric):
        """Register `metric` under its own name, replacing any previous entry.

        NOTE(review): shadows tf.keras.Model.add_metric, which has a different
        signature — confirm this override is intentional.
        """
        self._metrics[metric.name] = metric
    def make(self, *args, **kwargs):
        """ Custom function for building model (uses self.build so cannot overwrite that function) """
        # Subclasses must implement; presumably builds weights from dummy inputs — see subclasses.
        raise NotImplementedError()
    def compile(self, loss, optimizer, run_eagerly=None, **kwargs):
        """Compile with dynamic loss scaling everywhere except on TPU.

        NOTE(review): mxp.experimental.LossScaleOptimizer is the older
        experimental mixed-precision API — confirm the pinned TF version
        still provides it.
        """
        if not env_util.has_devices("TPU"):
            # Wrap the optimizer for dynamic loss scaling (mixed precision).
            optimizer = mxp.experimental.LossScaleOptimizer(tf.keras.optimizers.get(optimizer), "dynamic")
            self.use_loss_scale = True
        loss_metric = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)
        # NOTE(review): this resets the metric registry, discarding metrics
        # registered via add_metric() before compile() — confirm intended.
        self._metrics = {loss_metric.name: loss_metric}
        super().compile(optimizer=optimizer, loss=loss, run_eagerly=run_eagerly, **kwargs)
# -------------------------------- STEP FUNCTIONS -------------------------------------
def gradient_step(self, inputs, y_true):
with tf.GradientTape() as tape:
y_pred = self(inputs, training=True)
loss = self.loss(y_true, y_pred)
if self.use_loss_scale:
scaled_loss = self.optimizer.get_scaled_loss(loss)
if self.use_loss_scale:
gradients = tape.gradient(scaled_loss, self.trainable_weights)
gradients = self.optimizer.get_unscaled_gradients(gradients)
else:
gradients = tape.gradient(loss, self.trainable_weights)
return loss, y_pred, gradients
def train_step(self, batch):
    """Run one optimization step on `batch`.

    Args:
        batch ([tf.Tensor]): a batch of training data
    Returns:
        Dict[tf.Tensor]: a dict mapping metric names to their current values
    """
    inputs, y_true = batch
    loss, y_pred, grads = self.gradient_step(inputs, y_true)
    self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
    self._metrics["loss"].update_state(loss)
    # Optional metrics are only updated when registered via add_metric.
    optional_updates = (
        ('step_loss', (loss,)),
        ('WER', (y_true, y_pred)),
        ('labels', (y_true,)),
        ('logits', (y_pred,)),
        # NOTE(review): 'logits_len' is fed y_pred, not a length — confirm intended.
        ('logits_len', (y_pred,)),
    )
    for metric_name, update_args in optional_updates:
        if metric_name in self._metrics:
            self._metrics[metric_name].update_state(*update_args)
    return {m.name: m.result() for m in self.metrics}
def test_step(self, batch):
    """Evaluate the model on one validation batch (no weight update).

    Args:
        batch ([tf.Tensor]): a batch of validation data
    Returns:
        Dict[tf.Tensor]: a dict mapping metric names to their current values
    """
    inputs, y_true = batch
    y_pred = self(inputs, training=False)
    loss = self.loss(y_true, y_pred)
    self._metrics["loss"].update_state(loss)
    # Optional metrics are only updated when registered via add_metric.
    optional_updates = (
        ('step_loss', (loss,)),
        ('WER', (y_true, y_pred)),
        ('labels', (y_true,)),
        ('logits', (y_pred,)),
        # NOTE(review): 'logits_len' is fed y_pred, not a length — confirm intended.
        ('logits_len', (y_pred,)),
    )
    for metric_name, update_args in optional_updates:
        if metric_name in self._metrics:
            self._metrics[metric_name].update_state(*update_args)
    return {m.name: m.result() for m in self.metrics}
def predict_step(self, batch):
    """Decode one test batch into text.

    Args:
        batch ([tf.Tensor]): a batch of testing data
    Returns:
        [tf.Tensor]: stacked tensor of shape [B, 3]; each row is
        [ground truth, greedy decoding, beam-search decoding]
    """
    inputs, y_true = batch
    labels = self.text_featurizer.iextract(y_true["labels"])
    greedy_decoding = self.recognize(inputs)
    beam_enabled = self.text_featurizer.decoder_config.beam_width != 0
    if beam_enabled:
        beam_search_decoding = self.recognize_beam(inputs)
    else:
        # Beam search disabled: emit empty strings shaped like `labels`.
        beam_search_decoding = tf.map_fn(lambda _: tf.convert_to_tensor("", dtype=tf.string), labels)
    return tf.stack([labels, greedy_decoding, beam_search_decoding], axis=-1)
# -------------------------------- INFERENCE FUNCTIONS -------------------------------------
def recognize(self, *args, **kwargs):
    """Greedy decoding entry point used by `predict_step`.

    Subclasses must provide the implementation.
    """
    raise NotImplementedError()
def recognize_beam(self, *args, **kwargs):
    """Beam-search decoding entry point used by `predict_step`.

    Subclasses must provide the implementation.
    """
    raise NotImplementedError()
import tensorflow as tf
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.framework import ops
from tensorflow.python.eager import def_function
from .ctc import CtcModel
from .conformer_encoder import ConformerEncoder
from ..augmentations.augmentation import SpecAugmentation
from ..utils import math_util
from ..utils.training_utils import (
_minimum_control_deps,
reduce_per_replica,
write_scalar_summaries,
)
class ConformerCtc(CtcModel):
    """CTC model with a Conformer encoder and a pointwise-Conv1D logits decoder.

    `make_train_function` / `make_test_function` are overridden so that the
    per-step outputs are reduced across replicas, allowing results to be
    gathered from multiple TPU cores.
    """

    def __init__(
        self,
        vocabulary_size: int,
        encoder_subsampling: dict,
        encoder_dmodel: int = 144,
        encoder_num_blocks: int = 16,
        encoder_head_size: int = 36,
        encoder_num_heads: int = 4,
        encoder_mha_type: str = "relmha",
        encoder_kernel_size: int = 32,
        encoder_fc_factor: float = 0.5,
        encoder_dropout: float = 0,
        encoder_time_reduce_idx: list = None,
        encoder_time_recover_idx: list = None,
        encoder_conv_use_glu: bool = False,
        encoder_ds_subsample: bool = False,
        encoder_no_post_ln: bool = False,
        encoder_adaptive_scale: bool = False,
        encoder_fixed_arch: list = None,
        augmentation_config=None,
        name: str = "conformer",
        **kwargs,
    ):
        # The attention width must factor exactly into heads x head_size.
        assert encoder_dmodel == encoder_num_heads * encoder_head_size
        # A flat arch spec is broadcast to every block; a list of lists is a
        # per-block spec. Bug fix: guard the documented default of None,
        # which previously crashed on `encoder_fixed_arch[0]`.
        if encoder_fixed_arch is not None and not isinstance(encoder_fixed_arch[0], list):
            encoder_fixed_arch = [encoder_fixed_arch] * encoder_num_blocks
        super().__init__(
            encoder=ConformerEncoder(
                subsampling=encoder_subsampling,
                dmodel=encoder_dmodel,
                num_blocks=encoder_num_blocks,
                head_size=encoder_head_size,
                num_heads=encoder_num_heads,
                mha_type=encoder_mha_type,
                kernel_size=encoder_kernel_size,
                fc_factor=encoder_fc_factor,
                dropout=encoder_dropout,
                time_reduce_idx=encoder_time_reduce_idx,
                time_recover_idx=encoder_time_recover_idx,
                conv_use_glu=encoder_conv_use_glu,
                ds_subsample=encoder_ds_subsample,
                no_post_ln=encoder_no_post_ln,
                adaptive_scale=encoder_adaptive_scale,
                fixed_arch=encoder_fixed_arch,
                name=f"{name}_encoder",
            ),
            # A kernel-size-1 convolution == per-frame linear projection to
            # the vocabulary.
            decoder=tf.keras.layers.Conv1D(
                filters=vocabulary_size, kernel_size=1,
                strides=1, padding="same",
                name=f"{name}_logits"
            ),
            augmentation=SpecAugmentation(
                num_freq_masks=augmentation_config['freq_masking']['num_masks'],
                freq_mask_len=augmentation_config['freq_masking']['mask_factor'],
                num_time_masks=augmentation_config['time_masking']['num_masks'],
                time_mask_prop=augmentation_config['time_masking']['p_upperbound'],
                name=f"{name}_specaug"
            ) if augmentation_config is not None else None,
            vocabulary_size=vocabulary_size,
            name=name,
            **kwargs
        )
        self.time_reduction_factor = self.encoder.conv_subsampling.time_reduction_factor
        self.dmodel = encoder_dmodel

    # The following functions override the original Keras implementations
    # in order to gather the outputs from multiple TPU cores.

    def make_train_function(self):
        """Build (and cache) the distributed training step function."""
        if self.train_function is not None:
            return self.train_function

        def step_function(model, iterator):
            """Runs a single training step."""

            def run_step(data):
                outputs = model.train_step(data)
                # Ensure counter is updated only if `train_step` succeeds.
                with ops.control_dependencies(_minimum_control_deps(outputs)):
                    model._train_counter.assign_add(1)  # pylint: disable=protected-access
                return outputs

            data = next(iterator)
            outputs = model.distribute_strategy.run(run_step, args=(data,))
            # Gather per-replica outputs (e.g. from each TPU core).
            outputs = reduce_per_replica(outputs, self.distribute_strategy)
            write_scalar_summaries(outputs, step=model._train_counter)  # pylint: disable=protected-access
            return outputs

        if self._steps_per_execution.numpy().item() == 1:

            def train_function(iterator):
                """Runs a training execution with one step."""
                return step_function(self, iterator)

        else:

            def train_function(iterator):
                """Runs a training execution with multiple steps."""
                # Bug fix: `math_ops` was never imported in this module
                # (NameError at trace time); use tf.range instead.
                for _ in tf.range(self._steps_per_execution):
                    outputs = step_function(self, iterator)
                return outputs

        if not self.run_eagerly:
            train_function = def_function.function(
                train_function, experimental_relax_shapes=True)
        self.train_function = train_function
        if self._cluster_coordinator:
            self.train_function = lambda iterator: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
                train_function, args=(iterator,))
        return self.train_function

    def make_test_function(self):
        """Build (and cache) the distributed evaluation step function."""
        if self.test_function is not None:
            return self.test_function

        def step_function(model, iterator):
            """Runs a single evaluation step."""

            def run_step(data):
                outputs = model.test_step(data)
                # Ensure counter is updated only if `test_step` succeeds.
                with ops.control_dependencies(_minimum_control_deps(outputs)):
                    model._test_counter.assign_add(1)  # pylint: disable=protected-access
                return outputs

            data = next(iterator)
            outputs = model.distribute_strategy.run(run_step, args=(data,))
            # Gather per-replica outputs (e.g. from each TPU core).
            outputs = reduce_per_replica(outputs, self.distribute_strategy)
            return outputs

        if self._steps_per_execution.numpy().item() == 1:

            def test_function(iterator):
                """Runs an evaluation execution with one step."""
                return step_function(self, iterator)

        else:

            def test_function(iterator):
                """Runs an evaluation execution with multiple steps."""
                # Bug fix: `math_ops` was never imported in this module
                # (NameError at trace time); use tf.range instead.
                for _ in tf.range(self._steps_per_execution):
                    outputs = step_function(self, iterator)
                return outputs

        if not self.run_eagerly:
            test_function = def_function.function(test_function, experimental_relax_shapes=True)
        self.test_function = test_function
        if self._cluster_coordinator:
            self.test_function = lambda iterator: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
                test_function, args=(iterator,))
        return self.test_function
class ConformerCtcAccumulate(ConformerCtc):
    """ConformerCtc variant that accumulates gradients over `n_gradients`
    micro-batches before applying one optimizer update (simulating a larger
    effective batch size)."""

    def __init__(self, n_gradients: int = 1, **kwargs) -> object:
        super().__init__(**kwargs)
        self.time_reduction_factor = self.encoder.conv_subsampling.time_reduction_factor
        # Number of micro-batches folded into one optimizer step.
        self.n_gradients = tf.constant(n_gradients, dtype=tf.int32, name="conformer/num_accumulated_gradients")
        # Counter of micro-batches seen since the last apply/reset.
        self.n_acum_step = tf.Variable(0, dtype=tf.int32, trainable=False, name="conformer/accumulate_step")

    def make(self, input_shape, batch_size=None):
        """Build the model, then allocate one accumulator variable per
        trainable variable (must run after weights exist)."""
        super().make(input_shape, batch_size)
        self.gradient_accumulation = [
            tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False, name=f"{v.name}/cached_accumulated_gradient") for v in self.trainable_variables
        ]

    def train_step(self, batch):
        """Accumulate this micro-batch's gradients; apply them only every
        `n_gradients` calls.

        Args:
            batch ([tf.Tensor]): a batch of training data
        Returns:
            Dict[tf.Tensor]: a dict mapping metric names to current values
        """
        self.n_acum_step.assign_add(1)
        inputs, y_true = batch
        loss, y_pred, gradients = self.gradient_step(inputs, y_true)
        # Add this micro-batch's mean contribution (divide as we accumulate).
        for i in range(len(self.gradient_accumulation)):
            self.gradient_accumulation[i].assign_add(gradients[i] / tf.cast(self.n_gradients, tf.float32))
        # Apply + reset every n_gradients steps; otherwise a no-op branch.
        tf.cond(tf.equal(self.n_acum_step, self.n_gradients), self.apply_accu_gradients, lambda: None)
        self._metrics["loss"].update_state(loss)
        if 'WER' in self._metrics:
            self._metrics['WER'].update_state(y_true, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def apply_accu_gradients(self):
        # Apply accumulated gradients
        self.optimizer.apply_gradients(zip(self.gradient_accumulation,
                                           self.trainable_variables))
        # Reset the step counter and zero every accumulator slot.
        self.n_acum_step.assign(0)
        for i in range(len(self.gradient_accumulation)):
            self.gradient_accumulation[i].assign(
                tf.zeros_like(self.trainable_variables[i], dtype=tf.float32)
            )
This diff is collapsed.
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Union
import numpy as np
import tensorflow as tf
from .base_model import BaseModel
from ..featurizers.speech_featurizers import TFSpeechFeaturizer
from ..featurizers.text_featurizers import TextFeaturizer
from ..utils import math_util, shape_util, data_util
from ..losses.ctc_loss import CtcLoss
logger = tf.get_logger()
class CtcModel(BaseModel):
    """Generic encoder/decoder model trained with CTC loss.

    `call` returns logits plus their reduced lengths; greedy and beam-search
    decoding are delegated to the external `ctc_decoders` package through
    `tf.numpy_function`.
    """

    def __init__(
        self,
        encoder: tf.keras.Model,
        decoder: Union[tf.keras.Model, tf.keras.layers.Layer] = None,
        augmentation: tf.keras.Model = None,
        vocabulary_size: int = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.encoder = encoder
        if decoder is None:
            # Default decoder: per-frame linear projection to the vocabulary.
            assert vocabulary_size is not None, "vocabulary_size must be set"
            self.decoder = tf.keras.layers.Dense(units=vocabulary_size, name=f"{self.name}_logits")
        else:
            self.decoder = decoder
        self.augmentation = augmentation
        # Ratio of input frames to output frames; subclasses overwrite this
        # after building their encoder (see ConformerCtc).
        self.time_reduction_factor = 1

    def make(self, input_shape, batch_size=None):
        """Build the model by tracing one forward pass on symbolic inputs."""
        inputs = tf.keras.Input(input_shape, batch_size=batch_size, dtype=tf.float32)
        inputs_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32)
        self(
            data_util.create_inputs(
                inputs=inputs,
                inputs_length=inputs_length
            ),
            training=False
        )

    def compile(self, optimizer, blank=0, run_eagerly=None, **kwargs):
        """Compile with a CTC loss built for the given blank-token index."""
        loss = CtcLoss(blank=blank)
        super().compile(loss=loss, optimizer=optimizer, run_eagerly=run_eagerly, **kwargs)

    def add_featurizers(
        self,
        speech_featurizer: TFSpeechFeaturizer,
        text_featurizer: TextFeaturizer,
    ):
        # Attached post-construction; the decoding methods below rely on
        # self.text_featurizer being set before they are called.
        self.speech_featurizer = speech_featurizer
        self.text_featurizer = text_featurizer

    def call(self, inputs, training=False, **kwargs):
        """Forward pass: optional SpecAugment -> encoder -> logits decoder."""
        x, x_length = inputs["inputs"], inputs["inputs_length"]
        if training and self.augmentation is not None:
            x = self.augmentation(x, x_length)
        logits = self.encoder(x, x_length, training=training, **kwargs)
        logits = self.decoder(logits, training=training, **kwargs)
        return data_util.create_logits(
            logits=logits,
            logits_length=math_util.get_reduced_length(x_length, self.time_reduction_factor)
        )

    # -------------------------------- GREEDY -------------------------------------

    @tf.function
    def recognize_from_logits(self, logits: tf.Tensor, lengths: tf.Tensor):
        """Greedy-decode already-computed logits into a batch of strings."""
        probs = tf.nn.softmax(logits)
        # blank is in the first index of `probs`, where `ctc_greedy_decoder` supposes it to be in the last index.
        # therefore, we move the first column to the last column to be compatible with `ctc_greedy_decoder`
        probs = tf.concat([probs[:, :, 1:], tf.expand_dims(probs[:, :, 0], -1)], axis=-1)
        def _map(elems): return tf.numpy_function(self._perform_greedy, inp=[elems[0], elems[1]], Tout=tf.string)
        return tf.map_fn(_map, (probs, lengths), fn_output_signature=tf.TensorSpec([], dtype=tf.string))

    @tf.function
    def recognize(self, inputs: Dict[str, tf.Tensor]):
        """Run the model on `inputs` and greedy-decode the logits."""
        logits = self(inputs, training=False)
        probs = tf.nn.softmax(logits["logits"])
        # send the first index (skip token) to the last index
        # for compatibility with the ctc_decoders library
        probs = tf.concat([probs[:, :, 1:], tf.expand_dims(probs[:, :, 0], -1)], axis=-1)
        lengths = logits["logits_length"]
        def map_fn(elem): return tf.numpy_function(self._perform_greedy, inp=[elem[0], elem[1]], Tout=tf.string)
        return tf.map_fn(map_fn, [probs, lengths], fn_output_signature=tf.TensorSpec([], dtype=tf.string))

    def _perform_greedy(self, probs: np.ndarray, length):
        # Runs eagerly inside tf.numpy_function; `probs` arrives as numpy.
        from ctc_decoders import ctc_greedy_decoder
        decoded = ctc_greedy_decoder(probs[:length], vocabulary=self.text_featurizer.non_blank_tokens)
        return tf.convert_to_tensor(decoded, dtype=tf.string)

    # -------------------------------- BEAM SEARCH -------------------------------------

    @tf.function
    def recognize_beam(self, inputs: Dict[str, tf.Tensor], lm: bool = False):
        """Run the model and beam-search-decode, optionally LM-rescored."""
        logits = self(inputs, training=False)
        probs = tf.nn.softmax(logits["logits"])
        def map_fn(prob): return tf.numpy_function(self._perform_beam_search, inp=[prob, lm], Tout=tf.string)
        return tf.map_fn(map_fn, probs, dtype=tf.string)

    def _perform_beam_search(self, probs: np.ndarray, lm: bool = False):
        # Runs eagerly inside tf.numpy_function; `probs` arrives as numpy.
        from ctc_decoders import ctc_beam_search_decoder
        decoded = ctc_beam_search_decoder(
            probs_seq=probs,
            vocabulary=self.text_featurizer.non_blank_tokens,
            beam_size=self.text_featurizer.decoder_config.beam_width,
            ext_scoring_func=self.text_featurizer.scorer if lm else None
        )
        # Best hypothesis: take the text of the top-scoring beam entry.
        decoded = decoded[0][-1]
        return tf.convert_to_tensor(decoded, dtype=tf.string)
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
class GLU(tf.keras.layers.Layer):
    """Gated Linear Unit activation.

    Splits the input into two halves along `axis` and gates the first half
    with the sigmoid of the second: `out = a * sigmoid(b)`.
    """

    def __init__(self,
                 axis=-1,
                 name="glu_activation",
                 **kwargs):
        super(GLU, self).__init__(name=name, **kwargs)
        self.axis = axis

    def call(self, inputs, **kwargs):
        linear_part, gate_part = tf.split(inputs, 2, axis=self.axis)
        return tf.multiply(linear_part, tf.nn.sigmoid(gate_part))

    def get_config(self):
        config = super(GLU, self).get_config()
        config.update({"axis": self.axis})
        return config
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
import tensorflow as tf
from src.utils import shape_util
logger = tf.get_logger()
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head scaled dot-product attention layer.

    `call` takes `inputs = [query, key, value]` plus an optional attention
    `mask` and returns the attended output (and the attention coefficients
    when `return_attn_coef=True`).
    """

    def __init__(
        self,
        num_heads,
        head_size,
        output_size: int = None,
        dropout: float = 0.0,
        use_projection_bias: bool = True,
        return_attn_coef: bool = False,
        **kwargs,
    ):
        super(MultiHeadAttention, self).__init__(**kwargs)
        if output_size is not None and output_size < 1:
            raise ValueError("output_size must be a positive number")
        self.head_size = head_size
        self.num_heads = num_heads
        self.output_size = output_size
        self.use_projection_bias = use_projection_bias
        self.return_attn_coef = return_attn_coef
        self.dropout = tf.keras.layers.Dropout(dropout, name="dropout")
        # Raw rate kept for get_config (attribute-name typo preserved as-is
        # because get_config reads this exact name).
        self._droput_rate = dropout

    def build(self, input_shape):
        # input_shape is a list: [query_shape, key_shape, (value_shape)].
        num_query_features = input_shape[0][-1]
        num_key_features = input_shape[1][-1]
        num_value_features = (
            input_shape[2][-1] if len(input_shape) > 2 else num_key_features
        )
        output_size = (
            self.output_size if self.output_size is not None else num_value_features
        )
        # Uniform-init bound of 1/sqrt(num_heads * head_size) for all weights.
        input_max = (self.num_heads * self.head_size) ** -0.5
        self.query = tf.keras.layers.Dense(
            self.num_heads * self.head_size, activation=None,
            kernel_initializer=tf.keras.initializers.RandomUniform(minval=-input_max, maxval=input_max),
            bias_initializer=tf.keras.initializers.RandomUniform(minval=-input_max, maxval=input_max),
        )
        self.key = tf.keras.layers.Dense(
            self.num_heads * self.head_size, activation=None,
            kernel_initializer=tf.keras.initializers.RandomUniform(minval=-input_max, maxval=input_max),
            bias_initializer=tf.keras.initializers.RandomUniform(minval=-input_max, maxval=input_max),
        )
        self.value = tf.keras.layers.Dense(
            self.num_heads * self.head_size, activation=None,
            kernel_initializer=tf.keras.initializers.RandomUniform(minval=-input_max, maxval=input_max),
            bias_initializer=tf.keras.initializers.RandomUniform(minval=-input_max, maxval=input_max),
        )
        # Output projection recombining the heads: [H, head_size, output_size].
        self.projection_kernel = self.add_weight(
            name="projection_kernel",
            shape=[self.num_heads, self.head_size, output_size],
            initializer=tf.keras.initializers.RandomUniform(minval=-input_max, maxval=input_max),
        )
        if self.use_projection_bias:
            self.projection_bias = self.add_weight(
                name="projection_bias",
                shape=[output_size],
                initializer=tf.keras.initializers.RandomUniform(minval=-input_max, maxval=input_max),
            )
        else:
            self.projection_bias = None

    def call_qkv(self, query, key, value, training=False):
        """Project q/k/v and split features into heads:
        [..., T, H*head_size] -> [..., T, H, head_size]."""
        # verify shapes
        if key.shape[-2] != value.shape[-2]:
            raise ValueError(
                "the number of elements in 'key' must be equal to "
                "the same as the number of elements in 'value'"
            )
        # Linear transformations
        query = self.query(query)
        B, T, E = shape_util.shape_list(query)
        query = tf.reshape(query, [B, T, self.num_heads, self.head_size])
        key = self.key(key)
        B, T, E = shape_util.shape_list(key)
        key = tf.reshape(key, [B, T, self.num_heads, self.head_size])
        value = self.value(value)
        B, T, E = shape_util.shape_list(value)
        value = tf.reshape(value, [B, T, self.num_heads, self.head_size])
        return query, key, value

    def call_attention(self, query, key, value, logits, training=False, mask=None):
        """Apply mask + softmax + dropout to `logits`, weight `value`, then
        run the output projection. Returns (output, attn_coef)."""
        # mask = attention mask with shape [B, Tquery, Tkey] with 1 is for positions we want to attend, 0 for masked
        if mask is not None:
            if len(mask.shape) < 2:
                raise ValueError("'mask' must have at least 2 dimensions")
            if query.shape[-3] != mask.shape[-2]:
                raise ValueError(
                    "mask's second to last dimension must be equal to "
                    "the number of elements in 'query'"
                )
            if key.shape[-3] != mask.shape[-1]:
                raise ValueError(
                    "mask's last dimension must be equal to the number of elements in 'key'"
                )
        # apply mask
        if mask is not None:
            mask = tf.cast(mask, tf.float32)
            # possibly expand on the head dimension so broadcasting works
            if len(mask.shape) != len(logits.shape):
                mask = tf.expand_dims(mask, -3)
            # Large negative additive bias drives masked softmax weights to ~0.
            logits += -10e9 * (1.0 - mask)
        attn_coef = tf.nn.softmax(logits)
        # attention dropout
        attn_coef_dropout = self.dropout(attn_coef, training=training)
        # attention * value
        multihead_output = tf.einsum("...HNM,...MHI->...NHI", attn_coef_dropout, value)
        # Run the outputs through another linear projection layer. Recombining heads
        # is automatically done.
        output = tf.einsum("...NHI,HIO->...NO", multihead_output, self.projection_kernel)
        if self.projection_bias is not None:
            output += self.projection_bias
        return output, attn_coef

    def call(self, inputs, training=False, mask=None, **kwargs):
        """Standard (absolute-position) multi-head attention forward pass."""
        query, key, value = inputs
        query, key, value = self.call_qkv(query, key, value, training=training)
        # Scale dot-product, doing the division to either query or key
        # instead of their product saves some computation
        depth = tf.constant(self.head_size, dtype=tf.float32)
        query /= tf.sqrt(depth)
        # Calculate dot product attention
        logits = tf.einsum("...NHO,...MHO->...HNM", query, key)
        output, attn_coef = self.call_attention(query, key, value, logits,
                                                training=training, mask=mask)
        if self.return_attn_coef:
            return output, attn_coef
        else:
            return output

    def compute_output_shape(self, input_shape):
        """Static shape inference mirroring `call`'s output structure."""
        num_value_features = (
            input_shape[2][-1] if len(input_shape) > 2 else input_shape[1][-1]
        )
        output_size = (
            self.output_size if self.output_size is not None else num_value_features
        )
        output_shape = input_shape[0][:-1] + (output_size,)
        if self.return_attn_coef:
            num_query_elements = input_shape[0][-2]
            num_key_elements = input_shape[1][-2]
            attn_coef_shape = input_shape[0][:-2] + (
                self.num_heads,
                num_query_elements,
                num_key_elements,
            )
            return output_shape, attn_coef_shape
        else:
            return output_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            head_size=self.head_size,
            num_heads=self.num_heads,
            output_size=self.output_size,
            dropout=self._droput_rate,
            use_projection_bias=self.use_projection_bias,
            return_attn_coef=self.return_attn_coef,
        )
        return config
class RelPositionMultiHeadAttention(MultiHeadAttention):
    """Multi-head attention with relative positional encoding.

    `call` consumes `inputs = [query, key, value, pos]`, where `pos` is a
    positional-encoding tensor (see PositionalEncoding).
    """

    def __init__(self, kernel_sizes=None, strides=None, **kwargs):
        # NOTE(review): kernel_sizes/strides are accepted but unused here —
        # presumably kept for config compatibility; confirm with callers.
        super(RelPositionMultiHeadAttention, self).__init__(**kwargs)

    def build(self, input_shape):
        # The last entry of input_shape belongs to the positional tensor.
        num_pos_features = input_shape[-1][-1]
        input_max = (self.num_heads * self.head_size) ** -0.5
        self.pos_kernel = self.add_weight(
            name="pos_kernel",
            shape=[self.num_heads, num_pos_features, self.head_size],
            initializer=tf.keras.initializers.RandomUniform(minval=-input_max, maxval=input_max),
        )
        # Learned query biases for the content (u) and position (v) terms.
        self.pos_bias_u = self.add_weight(
            name="pos_bias_u",
            shape=[self.num_heads, self.head_size],
            initializer=tf.keras.initializers.Zeros(),
        )
        self.pos_bias_v = self.add_weight(
            name="pos_bias_v",
            shape=[self.num_heads, self.head_size],
            initializer=tf.keras.initializers.Zeros(),
        )
        # Build the base q/k/v projections on [query, key, value] only.
        super(RelPositionMultiHeadAttention, self).build(input_shape[:-1])

    @staticmethod
    def relative_shift(x):
        # Pad-and-reshape trick converting position-indexed logits into
        # relative-position logits (Transformer-XL-style relative attention).
        x_shape = tf.shape(x)
        x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [1, 0]])
        x = tf.reshape(x, [x_shape[0], x_shape[1], x_shape[3] + 1, x_shape[2]])
        x = tf.reshape(x[:, :, 1:, :], x_shape)
        return x

    def call(self, inputs, training=False, mask=None, **kwargs):
        """Relative-position attention forward pass; same output contract
        as MultiHeadAttention.call."""
        query, key, value, pos = inputs
        query, key, value = self.call_qkv(query, key, value, training=training)
        pos = tf.einsum("...MI,HIO->...MHO", pos, self.pos_kernel)
        query_with_u = query + self.pos_bias_u
        query_with_v = query + self.pos_bias_v
        logits_with_u = tf.einsum("...NHO,...MHO->...HNM", query_with_u, key)
        logits_with_v = tf.einsum("...NHO,...MHO->...HNM", query_with_v, pos)
        logits_with_v = self.relative_shift(logits_with_v)
        # Trim the position term to the key length before combining.
        logits = logits_with_u + logits_with_v[:, :, :, :tf.shape(logits_with_u)[3]]
        depth = tf.constant(self.head_size, dtype=tf.float32)
        logits /= tf.sqrt(depth)
        output, attn_coef = self.call_attention(query, key, value, logits,
                                                training=training, mask=mask)
        if self.return_attn_coef:
            return output, attn_coef
        else:
            return output
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import tensorflow as tf
from src.utils.shape_util import shape_list
class PositionalEncoding(tf.keras.layers.Layer):
    '''
    Same positional encoding method as NeMo library.

    Precomputes a table over relative positions
    [max_len-1, ..., 0, ..., -(max_len-1)] at construction time; `call`
    slices out the window matching the input length.
    '''

    def __init__(self, d_model, max_len=5000, name="positional_encoding_nemo", **kwargs):
        super().__init__(trainable=False, name=name, **kwargs)
        self.max_len = max_len
        # Stored so get_config can serialize the full constructor signature.
        self.d_model = d_model
        positions = tf.expand_dims(tf.range(self.max_len - 1, -max_len, -1.0, dtype=tf.float32), axis=1)
        pos_length = tf.shape(positions)[0]
        pe = np.zeros([pos_length, d_model], 'float32')
        div_term = np.exp(
            tf.range(0, d_model, 2, dtype=tf.float32) * -(math.log(10000.0) / d_model)
        )
        # Interleave sin/cos across the feature dimension (standard
        # transformer sinusoidal encoding).
        pe[:, 0::2] = np.sin(positions * div_term)
        pe[:, 1::2] = np.cos(positions * div_term)
        pe = tf.convert_to_tensor(pe)
        self.pe = tf.expand_dims(pe, 0)  # shape [1, 2*max_len - 1, d_model]

    def call(self, inputs, **kwargs):
        # inputs shape [B, T, V]
        _, length, dmodel = shape_list(inputs)
        # Slice a symmetric window of 2*length - 1 positions around center.
        center_pos = tf.shape(self.pe)[1] // 2
        start_pos = center_pos - length + 1
        end_pos = center_pos + length
        pos_emb = self.pe[:, start_pos:end_pos]
        return tf.cast(pos_emb, dtype=inputs.dtype)

    def get_config(self):
        conf = super().get_config()
        # Bug fix: dict.update returns None, so the previous
        # `return conf.update({...})` returned None and broke layer
        # serialization. Update in place, then return the dict.
        conf.update({"max_len": self.max_len, "d_model": self.d_model})
        return conf
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment