Commit d90f5580 authored by Haoliang Zhang, committed by Yanhui Liang

Update deep speech model with pure tensorflow API implementation (#4730)

* update

* update

* update

* update

* update
parent 37ba2304
# DeepSpeech2 Model
## Overview
This is an implementation of the [DeepSpeech2](https://arxiv.org/pdf/1512.02595.pdf) model. The current implementation is based on the code from the authors' [DeepSpeech code](https://github.com/PaddlePaddle/DeepSpeech) and the implementation in the [MLPerf Repo](https://github.com/mlperf/reference/tree/master/speech_recognition).
DeepSpeech2 is an end-to-end deep neural network for automatic speech
recognition (ASR). It consists of 2 convolutional layers, 5 bidirectional RNN
layers and a fully connected layer. The input feature is a linear spectrogram
extracted from the audio. The network uses Connectionist Temporal Classification ([CTC](https://www.cs.toronto.edu/~graves/icml_2006.pdf)) as its loss function.
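As a rough illustration of that layer stack, the model can be sketched with `tf.keras` layers as shown below. This is not the implementation in this commit (which builds the graph with lower-level TensorFlow ops); the layer sizes follow the flag defaults where known, and `num_classes` is an assumed placeholder.
```
import tensorflow as tf

def build_sketch_model(num_feature_bins=161, num_classes=29, rnn_hidden_size=800):
  """Illustrative DeepSpeech2-style stack: 2 conv layers, 5 BiRNN layers, 1 FC layer."""
  # Input spectrogram: [time_steps, feature_bins, channels].
  inputs = tf.keras.layers.Input(shape=(None, num_feature_bins, 1))
  x = inputs
  # Two 2D convolution + batch normalization blocks.
  for kernel, strides in [((41, 11), (2, 2)), ((21, 11), (2, 1))]:
    x = tf.keras.layers.Conv2D(32, kernel, strides=strides, padding="same",
                               activation="relu")(x)
    x = tf.keras.layers.BatchNormalization()(x)
  # Collapse the frequency/channel dimensions so each time step becomes a vector.
  x = tf.keras.layers.TimeDistributed(tf.keras.layers.Flatten())(x)
  # Five bidirectional RNN layers.
  for _ in range(5):
    x = tf.keras.layers.Bidirectional(
        tf.keras.layers.GRU(rnn_hidden_size, return_sequences=True))(x)
  # Fully connected layer producing per-time-step character logits (trained with CTC).
  logits = tf.keras.layers.Dense(num_classes)(x)
  return tf.keras.Model(inputs, logits)
```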
## Dataset
The [OpenSLR LibriSpeech Corpus](http://www.openslr.org/12/) is used for model training and evaluation.
The training data is a combination of train-clean-100 and train-clean-360 (~130k
examples in total). The validation set is dev-clean, which contains about 2.7k examples.
The download script preprocesses the data into three columns: wav_filename,
wav_filesize, and transcript. data/dataset.py parses the csv file and builds a
tf.data.Dataset object to feed data to the model. Within each epoch (except for the
first epoch when sortagrad is enabled), the training data is shuffled batch-wise.
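For reference, each generated csv file is tab-separated and starts with a header row; the entry below is purely illustrative (hypothetical path, size, and transcript):
```
wav_filename    wav_filesize    transcript
/tmp/librispeech_data/.../example-0001.wav    262528    the quick brown fox jumps over the lazy dog
```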
## Running Code
### Configure Python path
Add the top-level /models folder to the Python path with the command:
```
export PYTHONPATH="$PYTHONPATH:/path/to/models"
```
### Install dependencies
First, install the shared dependencies before running the code by issuing the following command:
```
pip3 install -r requirements.txt
```
or
```
pip install -r requirements.txt
```
### Download and preprocess dataset
To download the dataset, issue the following command:
```
python data/download.py
```
Arguments:
* `--data_dir`: Directory in which to download and save the preprocessed data. By default, it is `/tmp/librispeech_data`.
Use the `--help` or `-h` flag to get a full list of possible arguments.
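For example, to store the data under a custom directory:
```
python data/download.py --data_dir=/your/local/path/librispeech_data
```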
### Train and evaluate model
To train and evaluate the model, issue the following command:
```
python deep_speech.py
```
Arguments:
* `--model_dir`: Directory to save model training checkpoints. By default, it is `/tmp/deep_speech_model/`.
* `--train_data_dir`: Directory of the training dataset.
* `--eval_data_dir`: Directory of the evaluation dataset.
* `--num_gpus`: Number of GPUs to use (specify -1 if you want to use all available GPUs).
There are other arguments that control the DeepSpeech2 model and the training/evaluation process. Use the `--help` or `-h` flag to get a full list of possible arguments with detailed descriptions.
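For example, a single-GPU run that points at the csv files produced by the download step might look like the following (the csv paths are illustrative; substitute the files generated on your machine):
```
python deep_speech.py \
  --model_dir=/tmp/deep_speech_model/ \
  --train_data_dir=/tmp/librispeech_data/train-clean-100/LibriSpeech/train-clean-100.csv \
  --eval_data_dir=/tmp/librispeech_data/dev-clean/LibriSpeech/dev-clean.csv \
  --num_gpus=1
```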
@@ -17,13 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

-import functools
-import multiprocessing
+import math
+import random
+# pylint: disable=g-bad-import-order
import numpy as np
-import scipy.io.wavfile as wavfile
from six.moves import xrange  # pylint: disable=redefined-builtin
+import soundfile
import tensorflow as tf
+# pylint: enable=g-bad-import-order

import data.featurizer as featurizer  # pylint: disable=g-bad-import-order
@@ -33,40 +34,37 @@ class AudioConfig(object):

  def __init__(self,
               sample_rate,
-              frame_length,
-              frame_step,
-              fft_length=None,
-              normalize=False,
-              spect_type="linear"):
+              window_ms,
+              stride_ms,
+              normalize=False):
    """Initialize the AudioConfig class.

    Args:
      sample_rate: an integer denoting the sample rate of the input waveform.
-      frame_length: an integer for the length of a spectrogram frame, in ms.
-      frame_step: an integer for the frame stride, in ms.
-      fft_length: an integer for the number of fft bins.
+      window_ms: an integer for the length of a spectrogram frame, in ms.
+      stride_ms: an integer for the frame stride, in ms.
      normalize: a boolean for whether apply normalization on the audio feature.
-      spect_type: a string for the type of spectrogram to be extracted.
    """
    self.sample_rate = sample_rate
-    self.frame_length = frame_length
-    self.frame_step = frame_step
-    self.fft_length = fft_length
+    self.window_ms = window_ms
+    self.stride_ms = stride_ms
    self.normalize = normalize
-    self.spect_type = spect_type


class DatasetConfig(object):
  """Config class for generating the DeepSpeechDataset."""

-  def __init__(self, audio_config, data_path, vocab_file_path):
+  def __init__(self, audio_config, data_path, vocab_file_path, sortagrad):
    """Initialize the configs for deep speech dataset.

    Args:
      audio_config: AudioConfig object specifying the audio-related configs.
      data_path: a string denoting the full path of a manifest file.
      vocab_file_path: a string specifying the vocabulary file path.
+      sortagrad: a boolean, if set to true, audio sequences will be fed by
+                 increasing length in the first training epoch, which will
+                 expedite network convergence.

    Raises:
      RuntimeError: file path not exist.
@@ -77,6 +75,7 @@ class DatasetConfig(object):
    assert tf.gfile.Exists(vocab_file_path)

    self.data_path = data_path
    self.vocab_file_path = vocab_file_path
+    self.sortagrad = sortagrad


def _normalize_audio_feature(audio_feature):
...@@ -95,30 +94,23 @@ def _normalize_audio_feature(audio_feature): ...@@ -95,30 +94,23 @@ def _normalize_audio_feature(audio_feature):
return normalized return normalized
def _preprocess_audio( def _preprocess_audio(audio_file_path, audio_featurizer, normalize):
audio_file_path, audio_sample_rate, audio_featurizer, normalize): """Load the audio file and compute spectrogram feature."""
"""Load the audio file in memory and compute spectrogram feature.""" data, _ = soundfile.read(audio_file_path)
tf.logging.info(
"Extracting spectrogram feature for {}".format(audio_file_path))
sample_rate, data = wavfile.read(audio_file_path)
assert sample_rate == audio_sample_rate
if data.dtype not in [np.float32, np.float64]:
data = data.astype(np.float32) / np.iinfo(data.dtype).max
feature = featurizer.compute_spectrogram_feature( feature = featurizer.compute_spectrogram_feature(
data, audio_featurizer.frame_length, audio_featurizer.frame_step, data, audio_featurizer.sample_rate, audio_featurizer.stride_ms,
audio_featurizer.fft_length) audio_featurizer.window_ms)
# Feature normalization
if normalize: if normalize:
feature = _normalize_audio_feature(feature) feature = _normalize_audio_feature(feature)
return feature
def _preprocess_transcript(transcript, token_to_index): # Adding Channel dimension for conv2D input.
"""Process transcript as label features.""" feature = np.expand_dims(feature, axis=2)
return featurizer.compute_label_feature(transcript, token_to_index) return feature
def _preprocess_data(dataset_config, audio_featurizer, token_to_index): def _preprocess_data(file_path):
"""Generate a list of waveform, transcript pair. """Generate a list of tuples (wav_filename, wav_filesize, transcript).
Each dataset file contains three columns: "wav_filename", "wav_filesize", Each dataset file contains three columns: "wav_filename", "wav_filesize",
and "transcript". This function parses the csv file and stores each example and "transcript". This function parses the csv file and stores each example
...@@ -127,42 +119,23 @@ def _preprocess_data(dataset_config, audio_featurizer, token_to_index): ...@@ -127,42 +119,23 @@ def _preprocess_data(dataset_config, audio_featurizer, token_to_index):
mini-batch have similar length. mini-batch have similar length.
Args: Args:
dataset_config: an instance of DatasetConfig. file_path: a string specifying the csv file path for a dataset.
audio_featurizer: an instance of AudioFeaturizer.
token_to_index: the mapping from character to its index
Returns: Returns:
features and labels array processed from the audio/text input. A list of tuples (wav_filename, wav_filesize, transcript) sorted by
file_size.
""" """
tf.logging.info("Loading data set {}".format(file_path))
file_path = dataset_config.data_path
sample_rate = dataset_config.audio_config.sample_rate
normalize = dataset_config.audio_config.normalize
with tf.gfile.Open(file_path, "r") as f: with tf.gfile.Open(file_path, "r") as f:
lines = f.read().splitlines() lines = f.read().splitlines()
lines = [line.split("\t") for line in lines] # Skip the csv header in lines[0].
# Skip the csv header.
lines = lines[1:] lines = lines[1:]
# Sort input data by the length of waveform. # The metadata file is tab separated.
lines = [line.split("\t", 2) for line in lines]
# Sort input data by the length of audio sequence.
lines.sort(key=lambda item: int(item[1])) lines.sort(key=lambda item: int(item[1]))
# Use multiprocessing for feature/label extraction return [tuple(line) for line in lines]
num_cores = multiprocessing.cpu_count()
pool = multiprocessing.Pool(processes=num_cores)
features = pool.map(
functools.partial(
_preprocess_audio, audio_sample_rate=sample_rate,
audio_featurizer=audio_featurizer, normalize=normalize),
[line[0] for line in lines])
labels = pool.map(
functools.partial(
_preprocess_transcript, token_to_index=token_to_index),
[line[2] for line in lines])
pool.terminate()
return features, labels
class DeepSpeechDataset(object): class DeepSpeechDataset(object):
...@@ -178,22 +151,52 @@ class DeepSpeechDataset(object): ...@@ -178,22 +151,52 @@ class DeepSpeechDataset(object):
# Instantiate audio feature extractor. # Instantiate audio feature extractor.
self.audio_featurizer = featurizer.AudioFeaturizer( self.audio_featurizer = featurizer.AudioFeaturizer(
sample_rate=self.config.audio_config.sample_rate, sample_rate=self.config.audio_config.sample_rate,
frame_length=self.config.audio_config.frame_length, window_ms=self.config.audio_config.window_ms,
frame_step=self.config.audio_config.frame_step, stride_ms=self.config.audio_config.stride_ms)
fft_length=self.config.audio_config.fft_length)
# Instantiate text feature extractor. # Instantiate text feature extractor.
self.text_featurizer = featurizer.TextFeaturizer( self.text_featurizer = featurizer.TextFeaturizer(
vocab_file=self.config.vocab_file_path) vocab_file=self.config.vocab_file_path)
self.speech_labels = self.text_featurizer.speech_labels self.speech_labels = self.text_featurizer.speech_labels
self.features, self.labels = _preprocess_data( self.entries = _preprocess_data(self.config.data_path)
self.config, # The generated spectrogram will have 161 feature bins.
self.audio_featurizer, self.num_feature_bins = 161
self.text_featurizer.token_to_idx
)
self.num_feature_bins = ( def batch_wise_dataset_shuffle(entries, epoch_index, sortagrad, batch_size):
self.features[0].shape[1] if len(self.features) else None) """Batch-wise shuffling of the data entries.
Each data entry is in the format of (audio_file, file_size, transcript).
If epoch_index is 0 and sortagrad is true, we don't perform shuffling and
return entries in sorted file_size order. Otherwise, do batch_wise shuffling.
Args:
entries: a list of data entries.
epoch_index: an integer of epoch index
sortagrad: a boolean to control whether sorting the audio in the first
training epoch.
batch_size: an integer for the batch size.
Returns:
The shuffled data entries.
"""
shuffled_entries = []
if epoch_index == 0 and sortagrad:
# No need to shuffle.
shuffled_entries = entries
else:
# Shuffle entries batch-wise.
max_buckets = int(math.floor(len(entries) / batch_size))
total_buckets = [i for i in xrange(max_buckets)]
random.shuffle(total_buckets)
shuffled_entries = []
for i in total_buckets:
shuffled_entries.extend(entries[i * batch_size : (i + 1) * batch_size])
# If the last batch doesn't contain enough batch_size examples,
# just append it to the shuffled_entries.
shuffled_entries.extend(entries[max_buckets * batch_size:])
return shuffled_entries
def input_fn(batch_size, deep_speech_dataset, repeat=1): def input_fn(batch_size, deep_speech_dataset, repeat=1):
...@@ -207,49 +210,66 @@ def input_fn(batch_size, deep_speech_dataset, repeat=1): ...@@ -207,49 +210,66 @@ def input_fn(batch_size, deep_speech_dataset, repeat=1):
Returns: Returns:
a tf.data.Dataset object for model to consume. a tf.data.Dataset object for model to consume.
""" """
features = deep_speech_dataset.features # Dataset properties
labels = deep_speech_dataset.labels data_entries = deep_speech_dataset.entries
num_feature_bins = deep_speech_dataset.num_feature_bins num_feature_bins = deep_speech_dataset.num_feature_bins
audio_featurizer = deep_speech_dataset.audio_featurizer
feature_normalize = deep_speech_dataset.config.audio_config.normalize
text_featurizer = deep_speech_dataset.text_featurizer
def _gen_data(): def _gen_data():
for i in xrange(len(features)): """Dataset generator function."""
feature = np.expand_dims(features[i], axis=2) for audio_file, _, transcript in data_entries:
input_length = [features[i].shape[0]] features = _preprocess_audio(
label_length = [len(labels[i])] audio_file, audio_featurizer, feature_normalize)
yield { labels = featurizer.compute_label_feature(
"features": feature, transcript, text_featurizer.token_to_index)
"labels": labels[i], input_length = [features.shape[0]]
"input_length": input_length, label_length = [len(labels)]
"label_length": label_length # Yield a tuple of (features, labels) where features is a dict containing
} # all info about the actual data features.
yield (
{
"features": features,
"input_length": input_length,
"label_length": label_length
},
labels)
dataset = tf.data.Dataset.from_generator( dataset = tf.data.Dataset.from_generator(
_gen_data, _gen_data,
output_types={ output_types=(
"features": tf.float32, {
"labels": tf.int32, "features": tf.float32,
"input_length": tf.int32, "input_length": tf.int32,
"label_length": tf.int32 "label_length": tf.int32
}, },
output_shapes={ tf.int32),
"features": tf.TensorShape([None, num_feature_bins, 1]), output_shapes=(
"labels": tf.TensorShape([None]), {
"input_length": tf.TensorShape([1]), "features": tf.TensorShape([None, num_feature_bins, 1]),
"label_length": tf.TensorShape([1]) "input_length": tf.TensorShape([1]),
}) "label_length": tf.TensorShape([1])
},
tf.TensorShape([None]))
)
# Repeat and batch the dataset # Repeat and batch the dataset
dataset = dataset.repeat(repeat) dataset = dataset.repeat(repeat)
# Padding the features to its max length dimensions. # Padding the features to its max length dimensions.
dataset = dataset.padded_batch( dataset = dataset.padded_batch(
batch_size=batch_size, batch_size=batch_size,
padded_shapes={ padded_shapes=(
"features": tf.TensorShape([None, num_feature_bins, 1]), {
"labels": tf.TensorShape([None]), "features": tf.TensorShape([None, num_feature_bins, 1]),
"input_length": tf.TensorShape([1]), "input_length": tf.TensorShape([1]),
"label_length": tf.TensorShape([1]) "label_length": tf.TensorShape([1])
}) },
tf.TensorShape([None]))
)
# Prefetch to improve speed of input pipeline. # Prefetch to improve speed of input pipeline.
dataset = dataset.prefetch(1) dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
return dataset return dataset
@@ -19,20 +19,51 @@ from __future__ import print_function

import codecs
import numpy as np
-from scipy import signal


-def compute_spectrogram_feature(waveform, frame_length, frame_step, fft_length):
-  """Compute the spectrograms for the input waveform."""
-  _, _, stft = signal.stft(
-      waveform,
-      nperseg=frame_length,
-      noverlap=frame_step,
-      nfft=fft_length)
-  # Perform transpose to set its shape as [time_steps, feature_num_bins]
-  spectrogram = np.transpose(np.absolute(stft), (1, 0))
-  return spectrogram
+def compute_spectrogram_feature(samples, sample_rate, stride_ms=10.0,
+                                window_ms=20.0, max_freq=None, eps=1e-14):
+  """Compute the spectrograms for the input samples(waveforms).
+
+  More about spectrogram computation, please refer to:
+  https://en.wikipedia.org/wiki/Short-time_Fourier_transform.
+  """
+  if max_freq is None:
+    max_freq = sample_rate / 2
+  if max_freq > sample_rate / 2:
+    raise ValueError("max_freq must not be greater than half of sample rate.")
+
+  if stride_ms > window_ms:
+    raise ValueError("Stride size must not be greater than window size.")
+
+  stride_size = int(0.001 * sample_rate * stride_ms)
+  window_size = int(0.001 * sample_rate * window_ms)
+
+  # Extract strided windows
+  truncate_size = (len(samples) - window_size) % stride_size
+  samples = samples[:len(samples) - truncate_size]
+  nshape = (window_size, (len(samples) - window_size) // stride_size + 1)
+  nstrides = (samples.strides[0], samples.strides[0] * stride_size)
+  windows = np.lib.stride_tricks.as_strided(
+      samples, shape=nshape, strides=nstrides)
+  assert np.all(
+      windows[:, 1] == samples[stride_size:(stride_size + window_size)])
+
+  # Window weighting, squared Fast Fourier Transform (fft), scaling
+  weighting = np.hanning(window_size)[:, None]
+  fft = np.fft.rfft(windows * weighting, axis=0)
+  fft = np.absolute(fft)
+  fft = fft**2
+  scale = np.sum(weighting**2) * sample_rate
+  fft[1:-1, :] *= (2.0 / scale)
+  fft[(0, -1), :] /= scale
+
+  # Prepare fft frequency list
+  freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
+
+  # Compute spectrogram feature
+  ind = np.where(freqs <= max_freq)[0][-1] + 1
+  specgram = np.log(fft[:ind, :] + eps)
+  return np.transpose(specgram, (1, 0))


class AudioFeaturizer(object):
@@ -40,21 +71,18 @@ class AudioFeaturizer(object):

  def __init__(self,
               sample_rate=16000,
-              frame_length=25,
-              frame_step=10,
-              fft_length=None):
+              window_ms=20.0,
+              stride_ms=10.0):
    """Initialize the audio featurizer class according to the configs.

    Args:
      sample_rate: an integer specifying the sample rate of the input waveform.
-      frame_length: an integer for the length of a spectrogram frame, in ms.
-      frame_step: an integer for the frame stride, in ms.
-      fft_length: an integer for the number of fft bins.
+      window_ms: an integer for the length of a spectrogram frame, in ms.
+      stride_ms: an integer for the frame stride, in ms.
    """
-    self.frame_length = int(sample_rate * frame_length / 1e3)
-    self.frame_step = int(sample_rate * frame_step / 1e3)
-    self.fft_length = fft_length if fft_length else int(2**(np.ceil(
-        np.log2(self.frame_length))))
+    self.sample_rate = sample_rate
+    self.window_ms = window_ms
+    self.stride_ms = stride_ms


def compute_label_feature(text, token_to_idx):
@@ -75,16 +103,16 @@ class TextFeaturizer(object):
    lines = []
    with codecs.open(vocab_file, "r", "utf-8") as fin:
      lines.extend(fin.readlines())
-    self.token_to_idx = {}
-    self.idx_to_token = {}
+    self.token_to_index = {}
+    self.index_to_token = {}
    self.speech_labels = ""
-    idx = 0
+    index = 0
    for line in lines:
      line = line[:-1]  # Strip the '\n' char.
      if line.startswith("#"):
        # Skip from reading comment line.
        continue
-      self.token_to_idx[line] = idx
-      self.idx_to_token[idx] = line
+      self.token_to_index[line] = index
+      self.index_to_token[index] = line
      self.speech_labels += line
-      idx += 1
+      index += 1
# List of alphabets (utf-8 encoded). Note that '#' starts a comment line, which
# will be ignored by the parser.
# begin of vocabulary
a
b
c
@@ -28,6 +29,5 @@ x
y
z
'
-
# end of vocabulary
...@@ -18,189 +18,78 @@ from __future__ import absolute_import ...@@ -18,189 +18,78 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import itertools
from nltk.metrics import distance from nltk.metrics import distance
from six.moves import xrange import numpy as np
import tensorflow as tf
class DeepSpeechDecoder(object): class DeepSpeechDecoder(object):
"""Basic decoder class from which all other decoders inherit. """Greedy decoder implementation for Deep Speech model."""
Implements several helper functions. Subclasses should implement the decode()
method.
"""
def __init__(self, labels, blank_index=28, space_index=27): def __init__(self, labels, blank_index=28):
"""Decoder initialization. """Decoder initialization.
Arguments: Arguments:
labels (string): mapping from integers to characters. labels: a string specifying the speech labels for the decoder to use.
blank_index (int, optional): index for the blank '_' character. blank_index: an integer specifying index for the blank character.
Defaults to 0.
space_index (int, optional): index for the space ' ' character.
Defaults to 28. Defaults to 28.
""" """
# e.g. labels = "[a-z]' _" # e.g. labels = "[a-z]' _"
self.labels = labels self.labels = labels
self.int_to_char = dict([(i, c) for (i, c) in enumerate(labels)])
self.blank_index = blank_index self.blank_index = blank_index
self.space_index = space_index self.int_to_char = dict([(i, c) for (i, c) in enumerate(labels)])
def convert_to_strings(self, sequences, sizes=None):
"""Given a list of numeric sequences, returns the corresponding strings."""
strings = []
for x in xrange(len(sequences)):
seq_len = sizes[x] if sizes is not None else len(sequences[x])
string = self._convert_to_string(sequences[x], seq_len)
strings.append(string)
return strings
def _convert_to_string(self, sequence, sizes):
return ''.join([self.int_to_char[sequence[i]] for i in range(sizes)])
def process_strings(self, sequences, remove_repetitions=False):
"""Process strings.
Given a list of strings, removes blanks and replace space character with def convert_to_string(self, sequence):
space. Option to remove repetitions (e.g. 'abbca' -> 'abca'). """Convert a sequence of indexes into corresponding string."""
return ''.join([self.int_to_char[i] for i in sequence])
Arguments: def wer(self, decode, target):
sequences: list of 1-d array of integers
remove_repetitions (boolean, optional): If true, repeating characters
are removed. Defaults to False.
Returns:
The processed string.
"""
processed_strings = []
for sequence in sequences:
string = self.process_string(remove_repetitions, sequence).strip()
processed_strings.append(string)
return processed_strings
def process_string(self, remove_repetitions, sequence):
"""Process each given sequence."""
seq_string = ''
for i, char in enumerate(sequence):
if char != self.int_to_char[self.blank_index]:
# if this char is a repetition and remove_repetitions=true,
# skip.
if remove_repetitions and i != 0 and char == sequence[i - 1]:
pass
elif char == self.labels[self.space_index]:
seq_string += ' '
else:
seq_string += char
return seq_string
def wer(self, output, target):
"""Computes the Word Error Rate (WER). """Computes the Word Error Rate (WER).
WER is defined as the edit distance between the two provided sentences after WER is defined as the edit distance between the two provided sentences after
tokenizing to words. tokenizing to words.
Args: Args:
output: string of the decoded output. decode: string of the decoded output.
target: a string for the true transcript. target: a string for the ground truth label.
Returns: Returns:
A float number for the WER of the current sentence pair. A float number for the WER of the current decode-target pair.
""" """
# Map each word to a new char. # Map each word to a new char.
words = set(output.split() + target.split()) words = set(decode.split() + target.split())
word2char = dict(zip(words, range(len(words)))) word2char = dict(zip(words, range(len(words))))
new_output = [chr(word2char[w]) for w in output.split()] new_decode = [chr(word2char[w]) for w in decode.split()]
new_target = [chr(word2char[w]) for w in target.split()] new_target = [chr(word2char[w]) for w in target.split()]
return distance.edit_distance(''.join(new_output), ''.join(new_target)) return distance.edit_distance(''.join(new_decode), ''.join(new_target))
def cer(self, output, target): def cer(self, decode, target):
"""Computes the Character Error Rate (CER). """Computes the Character Error Rate (CER).
CER is defined as the edit distance between the given strings. CER is defined as the edit distance between the two given strings.
Args: Args:
output: a string of the decoded output. decode: a string of the decoded output.
target: a string for the ground truth transcript. target: a string for the ground truth label.
Returns: Returns:
A float number denoting the CER for the current sentence pair. A float number denoting the CER for the current sentence pair.
""" """
return distance.edit_distance(output, target) return distance.edit_distance(decode, target)
def batch_wer(self, decoded_output, targets): def decode(self, logits):
"""Compute the aggregate WER for each batch. """Decode the best guess from logits using greedy algorithm."""
# Choose the class with maximimum probability.
Args: best = list(np.argmax(logits, axis=1))
decoded_output: 2d array of integers for the decoded output of a batch. # Merge repeated chars.
targets: 2d array of integers for the labels of a batch. merge = [k for k, _ in itertools.groupby(best)]
# Remove the blank index in the decoded sequence.
Returns: merge_remove_blank = []
A float number for the aggregated WER for the current batch output. for k in merge:
""" if k != self.blank_index:
# Convert numeric representation to string. merge_remove_blank.append(k)
decoded_strings = self.convert_to_strings(decoded_output)
decoded_strings = self.process_strings( return self.convert_to_string(merge_remove_blank)
decoded_strings, remove_repetitions=True)
target_strings = self.convert_to_strings(targets)
target_strings = self.process_strings(
target_strings, remove_repetitions=True)
wer = 0
for i in xrange(len(decoded_strings)):
wer += self.wer(decoded_strings[i], target_strings[i]) / float(
len(target_strings[i].split()))
return wer
def batch_cer(self, decoded_output, targets):
"""Compute the aggregate CER for each batch.
Args:
decoded_output: 2d array of integers for the decoded output of a batch.
targets: 2d array of integers for the labels of a batch.
Returns:
A float number for the aggregated CER for the current batch output.
"""
# Convert numeric representation to string.
decoded_strings = self.convert_to_strings(decoded_output)
decoded_strings = self.process_strings(
decoded_strings, remove_repetitions=True)
target_strings = self.convert_to_strings(targets)
target_strings = self.process_strings(
target_strings, remove_repetitions=True)
cer = 0
for i in xrange(len(decoded_strings)):
cer += self.cer(decoded_strings[i], target_strings[i]) / float(
len(target_strings[i]))
return cer
def decode(self, sequences, sizes=None):
"""Perform sequence decoding.
Given a matrix of character probabilities, returns the decoder's best guess
of the transcription.
Arguments:
sequences: 2D array of character probabilities, where sequences[c, t] is
the probability of character c at time t.
sizes(optional): Size of each sequence in the mini-batch.
Returns:
string: sequence of the model's best guess for the transcription.
"""
strings = self.convert_to_strings(sequences, sizes)
return self.process_strings(strings, remove_repetitions=True)
class GreedyDecoder(DeepSpeechDecoder):
"""Greedy decoder."""
def decode(self, logits, seq_len):
# Reshape to [max_time, batch_size, num_classes]
logits = tf.transpose(logits, (1, 0, 2))
decoded, _ = tf.nn.ctc_greedy_decoder(logits, seq_len)
decoded_dense = tf.Session().run(tf.sparse_to_dense(
decoded[0].indices, decoded[0].dense_shape, decoded[0].values))
result = self.convert_to_strings(decoded_dense)
return self.process_strings(result, remove_repetitions=True), decoded_dense
...@@ -41,8 +41,50 @@ _WER_KEY = "WER" ...@@ -41,8 +41,50 @@ _WER_KEY = "WER"
_CER_KEY = "CER" _CER_KEY = "CER"
def evaluate_model( def compute_length_after_conv(max_time_steps, ctc_time_steps, input_length):
estimator, batch_size, speech_labels, targets, input_fn_eval): """Computes the time_steps/ctc_input_length after convolution.
Suppose that the original feature contains two parts:
1) Real spectrogram signals, spanning input_length steps.
2) Padded part with all 0s.
The total length of those two parts is denoted as max_time_steps, which is
the padded length of the current batch. After convolution layers, the time
steps of a spectrogram feature will be decreased. As we know the percentage
of its original length within the entire length, we can compute the time steps
for the signal after conv as follows (using ctc_input_length to denote):
ctc_input_length = (input_length / max_time_steps) * output_length_of_conv.
This length is then fed into ctc loss function to compute loss.
Args:
max_time_steps: max_time_steps for the batch, after padding.
ctc_time_steps: number of timesteps after convolution.
input_length: actual length of the original spectrogram, without padding.
Returns:
the ctc_input_length after convolution layer.
"""
ctc_input_length = tf.to_float(tf.multiply(
input_length, ctc_time_steps))
return tf.to_int32(tf.floordiv(
ctc_input_length, tf.to_float(max_time_steps)))
def ctc_loss(label_length, ctc_input_length, labels, logits):
"""Computes the ctc loss for the current batch of predictions."""
label_length = tf.to_int32(tf.squeeze(label_length))
ctc_input_length = tf.to_int32(tf.squeeze(ctc_input_length))
sparse_labels = tf.to_int32(
tf.keras.backend.ctc_label_dense_to_sparse(labels, label_length))
y_pred = tf.log(tf.transpose(
logits, perm=[1, 0, 2]) + tf.keras.backend.epsilon())
return tf.expand_dims(
tf.nn.ctc_loss(labels=sparse_labels, inputs=y_pred,
sequence_length=ctc_input_length),
axis=1)
def evaluate_model(estimator, speech_labels, entries, input_fn_eval):
"""Evaluate the model performance using WER anc CER as metrics. """Evaluate the model performance using WER anc CER as metrics.
WER: Word Error Rate WER: Word Error Rate
...@@ -50,44 +92,34 @@ def evaluate_model( ...@@ -50,44 +92,34 @@ def evaluate_model(
Args: Args:
estimator: estimator to evaluate. estimator: estimator to evaluate.
batch_size: size of a mini-batch.
speech_labels: a string specifying all the character in the vocabulary. speech_labels: a string specifying all the character in the vocabulary.
targets: a list of list of integers for the featurized transcript. entries: a list of data entries (audio_file, file_size, transcript) for the
given dataset.
input_fn_eval: data input function for evaluation. input_fn_eval: data input function for evaluation.
Returns: Returns:
Evaluation result containing 'wer' and 'cer' as two metrics. Evaluation result containing 'wer' and 'cer' as two metrics.
""" """
# Get predictions # Get predictions
predictions = estimator.predict( predictions = estimator.predict(input_fn=input_fn_eval)
input_fn=input_fn_eval, yield_single_examples=False)
y_preds = [] # Get probabilities of each predicted class
input_lengths = [] probs = [pred["probabilities"] for pred in predictions]
for p in predictions:
y_preds.append(p["y_pred"])
input_lengths.append(p["ctc_input_length"])
num_of_examples = len(targets) num_of_examples = len(probs)
total_wer, total_cer = 0, 0 targets = [entry[2] for entry in entries] # The ground truth transcript
greedy_decoder = decoder.GreedyDecoder(speech_labels)
for i in range(len(y_preds)):
# Compute the CER and WER for the current batch,
# and aggregate to total_cer, total_wer.
y_pred_tensor = tf.convert_to_tensor(y_preds[i])
batch_targets = targets[i * batch_size : (i + 1) * batch_size]
seq_len = tf.squeeze(input_lengths[i], axis=1)
# Perform decoding
_, decoded_output = greedy_decoder.decode(
y_pred_tensor, seq_len)
total_wer, total_cer = 0, 0
greedy_decoder = decoder.DeepSpeechDecoder(speech_labels)
for i in range(num_of_examples):
# Decode string.
decoded_str = greedy_decoder.decode(probs[i])
# Compute CER. # Compute CER.
batch_cer = greedy_decoder.batch_cer(decoded_output, batch_targets) total_cer += greedy_decoder.cer(decoded_str, targets[i]) / float(
total_cer += batch_cer len(targets[i]))
# Compute WER. # Compute WER.
batch_wer = greedy_decoder.batch_wer(decoded_output, batch_targets) total_wer += greedy_decoder.wer(decoded_str, targets[i]) / float(
total_wer += batch_wer len(targets[i].split()))
# Get mean value # Get mean value
total_cer /= num_of_examples total_cer /= num_of_examples
...@@ -103,45 +135,76 @@ def evaluate_model( ...@@ -103,45 +135,76 @@ def evaluate_model(
return eval_results return eval_results
def convert_keras_to_estimator(keras_model, num_gpus): def model_fn(features, labels, mode, params):
"""Configure and convert keras model to Estimator. """Define model function for deep speech model.
Args: Args:
keras_model: A Keras model object. features: a dictionary of input_data features. It includes the data
num_gpus: An integer, the number of GPUs. input_length, label_length and the spectrogram features.
labels: a list of labels for the input data.
mode: current estimator mode; should be one of
`tf.estimator.ModeKeys.TRAIN`, `EVALUATE`, `PREDICT`.
params: a dict of hyper parameters to be passed to model_fn.
Returns: Returns:
estimator: The converted Estimator. EstimatorSpec parameterized according to the input params and the
current mode.
""" """
# keras optimizer is not compatible with distribution strategy. num_classes = params["num_classes"]
# Use tf optimizer instead input_length = features["input_length"]
optimizer = tf.train.MomentumOptimizer( label_length = features["label_length"]
learning_rate=flags_obj.learning_rate, momentum=flags_obj.momentum, features = features["features"]
use_nesterov=True)
# Create DeepSpeech2 model.
# ctc_loss is wrapped as a Lambda layer in the model. model = deep_speech_model.DeepSpeech2(
keras_model.compile( flags_obj.rnn_hidden_layers, flags_obj.rnn_type,
optimizer=optimizer, loss={"ctc_loss": lambda y_true, y_pred: y_pred}) flags_obj.is_bidirectional, flags_obj.rnn_hidden_size,
num_classes, flags_obj.use_bias)
distribution_strategy = distribution_utils.get_distribution_strategy(
num_gpus) if mode == tf.estimator.ModeKeys.PREDICT:
run_config = tf.estimator.RunConfig( logits = model(features, training=False)
train_distribute=distribution_strategy) predictions = {
"classes": tf.argmax(logits, axis=2),
estimator = tf.keras.estimator.model_to_estimator( "probabilities": tf.nn.softmax(logits),
keras_model=keras_model, model_dir=flags_obj.model_dir, config=run_config) "logits": logits
}
return estimator return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions)
# In training mode.
logits = model(features, training=True)
probs = tf.nn.softmax(logits)
ctc_input_length = compute_length_after_conv(
tf.shape(features)[1], tf.shape(probs)[1], input_length)
# Compute CTC loss
loss = tf.reduce_mean(ctc_loss(
label_length, ctc_input_length, labels, probs))
optimizer = tf.train.AdamOptimizer(learning_rate=flags_obj.learning_rate)
global_step = tf.train.get_or_create_global_step()
minimize_op = optimizer.minimize(loss, global_step=global_step)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# Create the train_op that groups both minimize_ops and update_ops
train_op = tf.group(minimize_op, update_ops)
return tf.estimator.EstimatorSpec(
mode=mode,
loss=loss,
train_op=train_op)
def generate_dataset(data_dir): def generate_dataset(data_dir):
"""Generate a speech dataset.""" """Generate a speech dataset."""
audio_conf = dataset.AudioConfig( audio_conf = dataset.AudioConfig(sample_rate=flags_obj.sample_rate,
flags_obj.sample_rate, flags_obj.frame_length, flags_obj.frame_step) window_ms=flags_obj.window_ms,
stride_ms=flags_obj.stride_ms,
normalize=True)
train_data_conf = dataset.DatasetConfig( train_data_conf = dataset.DatasetConfig(
audio_conf, audio_conf,
data_dir, data_dir,
flags_obj.vocabulary_file, flags_obj.vocabulary_file,
flags_obj.sortagrad
) )
speech_dataset = dataset.DeepSpeechDataset(train_data_conf) speech_dataset = dataset.DeepSpeechDataset(train_data_conf)
return speech_dataset return speech_dataset
...@@ -150,30 +213,27 @@ def generate_dataset(data_dir): ...@@ -150,30 +213,27 @@ def generate_dataset(data_dir):
def run_deep_speech(_): def run_deep_speech(_):
"""Run deep speech training and eval loop.""" """Run deep speech training and eval loop."""
# Data preprocessing # Data preprocessing
# The file name of training and test dataset
tf.logging.info("Data preprocessing...") tf.logging.info("Data preprocessing...")
train_speech_dataset = generate_dataset(flags_obj.train_data_dir) train_speech_dataset = generate_dataset(flags_obj.train_data_dir)
eval_speech_dataset = generate_dataset(flags_obj.eval_data_dir) eval_speech_dataset = generate_dataset(flags_obj.eval_data_dir)
# Number of label classes. Label string is "[a-z]' -" # Number of label classes. Label string is "[a-z]' -"
num_classes = len(train_speech_dataset.speech_labels) num_classes = len(train_speech_dataset.speech_labels)
# Input shape of each data example: # Use distribution strategy for multi-gpu training
# [time_steps (T), feature_bins(F), channel(C)]
# Channel is set as 1 by default.
input_shape = (None, train_speech_dataset.num_feature_bins, 1)
# Create deep speech model and convert it to Estimator
tf.logging.info("Creating Estimator from Keras model...")
keras_model = deep_speech_model.DeepSpeech(
input_shape, flags_obj.rnn_hidden_layers, flags_obj.rnn_type,
flags_obj.is_bidirectional, flags_obj.rnn_hidden_size,
flags_obj.rnn_activation, num_classes, flags_obj.use_bias)
# Convert to estimator
num_gpus = flags_core.get_num_gpus(flags_obj) num_gpus = flags_core.get_num_gpus(flags_obj)
estimator = convert_keras_to_estimator(keras_model, num_gpus) distribution_strategy = distribution_utils.get_distribution_strategy(num_gpus)
run_config = tf.estimator.RunConfig(
train_distribute=distribution_strategy)
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=flags_obj.model_dir,
config=run_config,
params={
"num_classes": num_classes,
}
)
# Benchmark logging # Benchmark logging
run_params = { run_params = {
...@@ -181,7 +241,6 @@ def run_deep_speech(_): ...@@ -181,7 +241,6 @@ def run_deep_speech(_):
"train_epochs": flags_obj.train_epochs, "train_epochs": flags_obj.train_epochs,
"rnn_hidden_size": flags_obj.rnn_hidden_size, "rnn_hidden_size": flags_obj.rnn_hidden_size,
"rnn_hidden_layers": flags_obj.rnn_hidden_layers, "rnn_hidden_layers": flags_obj.rnn_hidden_layers,
"rnn_activation": flags_obj.rnn_activation,
"rnn_type": flags_obj.rnn_type, "rnn_type": flags_obj.rnn_type,
"is_bidirectional": flags_obj.is_bidirectional, "is_bidirectional": flags_obj.is_bidirectional,
"use_bias": flags_obj.use_bias "use_bias": flags_obj.use_bias
...@@ -194,6 +253,7 @@ def run_deep_speech(_): ...@@ -194,6 +253,7 @@ def run_deep_speech(_):
train_hooks = hooks_helper.get_train_hooks( train_hooks = hooks_helper.get_train_hooks(
flags_obj.hooks, flags_obj.hooks,
model_dir=flags_obj.model_dir,
batch_size=flags_obj.batch_size) batch_size=flags_obj.batch_size)
per_device_batch_size = distribution_utils.per_device_batch_size( per_device_batch_size = distribution_utils.per_device_batch_size(
...@@ -213,14 +273,19 @@ def run_deep_speech(_): ...@@ -213,14 +273,19 @@ def run_deep_speech(_):
tf.logging.info("Starting a training cycle: %d/%d", tf.logging.info("Starting a training cycle: %d/%d",
cycle_index + 1, total_training_cycle) cycle_index + 1, total_training_cycle)
# Perform batch_wise dataset shuffling
train_speech_dataset.entries = dataset.batch_wise_dataset_shuffle(
train_speech_dataset.entries, cycle_index, flags_obj.sortagrad,
flags_obj.batch_size)
estimator.train(input_fn=input_fn_train, hooks=train_hooks) estimator.train(input_fn=input_fn_train, hooks=train_hooks)
# Evaluation # Evaluation
tf.logging.info("Starting to evaluate...") tf.logging.info("Starting to evaluate...")
eval_results = evaluate_model( eval_results = evaluate_model(
estimator, flags_obj.batch_size, eval_speech_dataset.speech_labels, estimator, eval_speech_dataset.speech_labels,
eval_speech_dataset.labels, input_fn_eval) eval_speech_dataset.entries, input_fn_eval)
# Log the WER and CER results. # Log the WER and CER results.
benchmark_logger.log_evaluation_result(eval_results) benchmark_logger.log_evaluation_result(eval_results)
...@@ -233,9 +298,6 @@ def run_deep_speech(_): ...@@ -233,9 +298,6 @@ def run_deep_speech(_):
flags_obj.wer_threshold, eval_results[_WER_KEY]): flags_obj.wer_threshold, eval_results[_WER_KEY]):
break break
# Clear the session explicitly to avoid session delete error
tf.keras.backend.clear_session()
def define_deep_speech_flags(): def define_deep_speech_flags():
"""Add flags for run_deep_speech.""" """Add flags for run_deep_speech."""
...@@ -257,8 +319,8 @@ def define_deep_speech_flags(): ...@@ -257,8 +319,8 @@ def define_deep_speech_flags():
flags_core.set_defaults( flags_core.set_defaults(
model_dir="/tmp/deep_speech_model/", model_dir="/tmp/deep_speech_model/",
export_dir="/tmp/deep_speech_saved_model/", export_dir="/tmp/deep_speech_saved_model/",
train_epochs=2, train_epochs=200,
batch_size=4, batch_size=128,
hooks="") hooks="")
# Deep speech flags # Deep speech flags
...@@ -272,16 +334,22 @@ def define_deep_speech_flags(): ...@@ -272,16 +334,22 @@ def define_deep_speech_flags():
default="/tmp/librispeech_data/test-clean/LibriSpeech/test-clean-20.csv", default="/tmp/librispeech_data/test-clean/LibriSpeech/test-clean-20.csv",
help=flags_core.help_wrap("The csv file path of evaluation dataset.")) help=flags_core.help_wrap("The csv file path of evaluation dataset."))
flags.DEFINE_bool(
name="sortagrad", default=True,
help=flags_core.help_wrap(
"If true, sort examples by audio length and perform no "
"batch_wise shuffling for the first epoch."))
flags.DEFINE_integer( flags.DEFINE_integer(
name="sample_rate", default=16000, name="sample_rate", default=16000,
help=flags_core.help_wrap("The sample rate for audio.")) help=flags_core.help_wrap("The sample rate for audio."))
flags.DEFINE_integer( flags.DEFINE_integer(
name="frame_length", default=25, name="window_ms", default=20,
help=flags_core.help_wrap("The frame length for spectrogram.")) help=flags_core.help_wrap("The frame length for spectrogram."))
flags.DEFINE_integer( flags.DEFINE_integer(
name="frame_step", default=10, name="stride_ms", default=10,
help=flags_core.help_wrap("The frame step.")) help=flags_core.help_wrap("The frame step."))
flags.DEFINE_string( flags.DEFINE_string(
...@@ -290,11 +358,11 @@ def define_deep_speech_flags(): ...@@ -290,11 +358,11 @@ def define_deep_speech_flags():
# RNN related flags # RNN related flags
flags.DEFINE_integer( flags.DEFINE_integer(
name="rnn_hidden_size", default=256, name="rnn_hidden_size", default=800,
help=flags_core.help_wrap("The hidden size of RNNs.")) help=flags_core.help_wrap("The hidden size of RNNs."))
flags.DEFINE_integer( flags.DEFINE_integer(
name="rnn_hidden_layers", default=3, name="rnn_hidden_layers", default=5,
help=flags_core.help_wrap("The number of RNN layers.")) help=flags_core.help_wrap("The number of RNN layers."))
flags.DEFINE_bool( flags.DEFINE_bool(
...@@ -311,20 +379,11 @@ def define_deep_speech_flags(): ...@@ -311,20 +379,11 @@ def define_deep_speech_flags():
case_sensitive=False, case_sensitive=False,
help=flags_core.help_wrap("Type of RNN cell.")) help=flags_core.help_wrap("Type of RNN cell."))
flags.DEFINE_enum(
name="rnn_activation", default="tanh",
enum_values=["tanh", "relu"], case_sensitive=False,
help=flags_core.help_wrap("Type of the activation within RNN."))
# Training related flags # Training related flags
flags.DEFINE_float( flags.DEFINE_float(
name="learning_rate", default=0.0003, name="learning_rate", default=5e-4,
help=flags_core.help_wrap("The initial learning rate.")) help=flags_core.help_wrap("The initial learning rate."))
flags.DEFINE_float(
name="momentum", default=0.9,
help=flags_core.help_wrap("Momentum to accelerate SGD optimizer."))
# Evaluation metrics threshold # Evaluation metrics threshold
flags.DEFINE_float( flags.DEFINE_float(
name="wer_threshold", default=None, name="wer_threshold", default=None,
...@@ -345,3 +404,4 @@ if __name__ == "__main__": ...@@ -345,3 +404,4 @@ if __name__ == "__main__":
define_deep_speech_flags() define_deep_speech_flags()
flags_obj = flags.FLAGS flags_obj = flags.FLAGS
absl_app.run(main) absl_app.run(main)
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Network structure for DeepSpeech model.""" """Network structure for DeepSpeech2 model."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
...@@ -20,175 +20,166 @@ from __future__ import print_function ...@@ -20,175 +20,166 @@ from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf import tensorflow as tf
# Supported rnn cells # Supported rnn cells.
SUPPORTED_RNNS = { SUPPORTED_RNNS = {
"lstm": tf.keras.layers.LSTM, "lstm": tf.nn.rnn_cell.BasicLSTMCell,
"rnn": tf.keras.layers.SimpleRNN, "rnn": tf.nn.rnn_cell.RNNCell,
"gru": tf.keras.layers.GRU, "gru": tf.nn.rnn_cell.GRUCell,
} }
# Parameters for batch normalization # Parameters for batch normalization.
_MOMENTUM = 0.1 _BATCH_NORM_EPSILON = 1e-5
_EPSILON = 1e-05 _BATCH_NORM_DECAY = 0.997
# Filters of convolution layer
_CONV_FILTERS = 32
def _conv_bn_layer(cnn_input, filters, kernel_size, strides, layer_id):
"""2D convolution + batch normalization layer. def batch_norm(inputs, training):
"""Batch normalization layer.
Note that the momentum to use will affect validation accuracy over time.
Batch norm has different behaviors during training/evaluation. With a large
momentum, the model takes longer to get a near-accurate estimation of the
moving mean/variance over the entire training dataset, which means we need
more iterations to see good evaluation results. If the training data is evenly
distributed over the feature space, we can also try setting a smaller momentum
(such as 0.1) to get good evaluation result sooner.
Args:
inputs: input data for batch norm layer.
training: a boolean to indicate if it is in training stage.
Returns:
tensor output from batch norm layer.
"""
return tf.layers.batch_normalization(
inputs=inputs, momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON,
fused=True, training=training)
def _conv_bn_layer(inputs, padding, filters, kernel_size, strides, layer_id,
training):
"""Defines 2D constitutional + batch normalization layer.
Args: Args:
cnn_input: input data for convolution layer. inputs: input data for convolution layer.
padding: padding to be applied before convolution layer.
filters: an integer, number of output filters in the convolution. filters: an integer, number of output filters in the convolution.
kernel_size: a tuple specifying the height and width of the 2D convolution kernel_size: a tuple specifying the height and width of the 2D convolution
window. window.
strides: a tuple specifying the stride length of the convolution. strides: a tuple specifying the stride length of the convolution.
layer_id: an integer specifying the layer index. layer_id: an integer specifying the layer index.
training: a boolean to indicate which stage we are in (training/eval).
Returns: Returns:
tensor output from the current layer. tensor output from the current layer.
""" """
output = tf.keras.layers.Conv2D( # Perform symmetric padding on the feature dimension of time_step
filters=filters, kernel_size=kernel_size, strides=strides, padding="same", # This step is required to avoid issues when RNN output sequence is shorter
activation="linear", name="cnn_{}".format(layer_id))(cnn_input) # than the label length.
output = tf.keras.layers.BatchNormalization( inputs = tf.pad(
momentum=_MOMENTUM, epsilon=_EPSILON)(output) inputs,
return output [[0, 0], [padding[0], padding[0]], [padding[1], padding[1]], [0, 0]])
inputs = tf.layers.conv2d(
inputs=inputs, filters=filters, kernel_size=kernel_size, strides=strides,
def _rnn_layer(input_data, rnn_cell, rnn_hidden_size, layer_id, rnn_activation, padding="valid", use_bias=False, activation=tf.nn.relu6,
is_batch_norm, is_bidirectional): name="cnn_{}".format(layer_id))
return batch_norm(inputs, training)
def _rnn_layer(inputs, rnn_cell, rnn_hidden_size, layer_id, is_batch_norm,
is_bidirectional, training):
"""Defines a batch normalization + rnn layer. """Defines a batch normalization + rnn layer.
Args: Args:
input_data: input tensors for the current layer. inputs: input tensors for the current layer.
rnn_cell: RNN cell instance to use. rnn_cell: RNN cell instance to use.
rnn_hidden_size: an integer for the dimensionality of the rnn output space. rnn_hidden_size: an integer for the dimensionality of the rnn output space.
layer_id: an integer for the index of current layer. layer_id: an integer for the index of current layer.
rnn_activation: activation function to use.
is_batch_norm: a boolean specifying whether to perform batch normalization is_batch_norm: a boolean specifying whether to perform batch normalization
on input states. on input states.
is_bidirectional: a boolean specifying whether the rnn layer is is_bidirectional: a boolean specifying whether the rnn layer is
bi-directional. bi-directional.
training: a boolean to indicate which stage we are in (training/eval).
Returns: Returns:
tensor output for the current layer. tensor output for the current layer.
""" """
if is_batch_norm: if is_batch_norm:
input_data = tf.keras.layers.BatchNormalization( inputs = batch_norm(inputs, training)
momentum=_MOMENTUM, epsilon=_EPSILON)(input_data)
rnn_layer = rnn_cell(
rnn_hidden_size, activation=rnn_activation, return_sequences=True,
name="rnn_{}".format(layer_id))
if is_bidirectional:
rnn_layer = tf.keras.layers.Bidirectional(rnn_layer, merge_mode="sum")
return rnn_layer(input_data)
def _ctc_lambda_func(args): # Construct forward/backward RNN cells.
"""Compute ctc loss.""" fw_cell = rnn_cell(num_units=rnn_hidden_size,
# py2 needs explicit tf import for keras Lambda layer name="rnn_fw_{}".format(layer_id))
import tensorflow as tf bw_cell = rnn_cell(num_units=rnn_hidden_size,
name="rnn_bw_{}".format(layer_id))
y_pred, labels, input_length, label_length = args if is_bidirectional:
return tf.keras.backend.ctc_batch_cost( outputs, _ = tf.nn.bidirectional_dynamic_rnn(
labels, y_pred, input_length, label_length) cell_fw=fw_cell, cell_bw=bw_cell, inputs=inputs, dtype=tf.float32,
swap_memory=True)
rnn_outputs = tf.concat(outputs, -1)
def _calc_ctc_input_length(args): else:
"""Compute the actual input length after convolution for ctc_loss function. rnn_outputs = tf.nn.dynamic_rnn(
fw_cell, inputs, dtype=tf.float32, swap_memory=True)
Basically, we need to know the scaled input_length after conv layers.
new_input_length = old_input_length * ctc_time_steps / max_time_steps
Args:
args: the input args to compute ctc input length.
Returns:
ctc_input_length, which is required for ctc loss calculation.
"""
# py2 needs explicit tf import for keras Lambda layer
import tensorflow as tf
input_length, input_data, y_pred = args return rnn_outputs
max_time_steps = tf.shape(input_data)[1]
ctc_time_steps = tf.shape(y_pred)[1]
ctc_input_length = tf.multiply(
tf.to_float(input_length), tf.to_float(ctc_time_steps))
ctc_input_length = tf.to_int32(tf.floordiv(
ctc_input_length, tf.to_float(max_time_steps)))
return ctc_input_length
class DeepSpeech(tf.keras.models.Model): class DeepSpeech2(object):
"""DeepSpeech model.""" """Define DeepSpeech2 model."""
def __init__(self, input_shape, num_rnn_layers, rnn_type, is_bidirectional, def __init__(self, num_rnn_layers, rnn_type, is_bidirectional,
rnn_hidden_size, rnn_activation, num_classes, use_bias): rnn_hidden_size, num_classes, use_bias):
"""Initialize DeepSpeech model. """Initialize DeepSpeech2 model.
Args: Args:
input_shape: an tuple to indicate the dimension of input dataset. It has
the format of [time_steps(T), feature_bins(F), channel(1)]
num_rnn_layers: an integer, the number of rnn layers. By default, it's 5. num_rnn_layers: an integer, the number of rnn layers. By default, it's 5.
rnn_type: a string, one of the supported rnn cells: gru, rnn and lstm. rnn_type: a string, one of the supported rnn cells: gru, rnn and lstm.
is_bidirectional: a boolean to indicate if the rnn layer is bidirectional. is_bidirectional: a boolean to indicate if the rnn layer is bidirectional.
rnn_hidden_size: an integer for the number of hidden states in each unit. rnn_hidden_size: an integer for the number of hidden states in each unit.
rnn_activation: a string to indicate rnn activation function. It can be
one of tanh and relu.
num_classes: an integer, the number of output classes/labels. num_classes: an integer, the number of output classes/labels.
use_bias: a boolean specifying whether to use bias in the last fc layer. use_bias: a boolean specifying whether to use bias in the last fc layer.
""" """
# Input variables self.num_rnn_layers = num_rnn_layers
input_data = tf.keras.layers.Input( self.rnn_type = rnn_type
shape=input_shape, name="features") self.is_bidirectional = is_bidirectional
self.rnn_hidden_size = rnn_hidden_size
# Two cnn layers self.num_classes = num_classes
conv_layer_1 = _conv_bn_layer( self.use_bias = use_bias
input_data, filters=32, kernel_size=(41, 11), strides=(2, 2),
layer_id=1) def __call__(self, inputs, training):
# Two cnn layers.
conv_layer_2 = _conv_bn_layer( inputs = _conv_bn_layer(
conv_layer_1, filters=32, kernel_size=(21, 11), strides=(2, 1), inputs, padding=(20, 5), filters=_CONV_FILTERS, kernel_size=(41, 11),
layer_id=2) strides=(2, 2), layer_id=1, training=training)
inputs = _conv_bn_layer(
inputs, padding=(10, 5), filters=_CONV_FILTERS, kernel_size=(21, 11),
strides=(2, 1), layer_id=2, training=training)
# output of conv_layer2 with the shape of # output of conv_layer2 with the shape of
# [batch_size (N), times (T), features (F), channels (C)] # [batch_size (N), times (T), features (F), channels (C)].
# Convert the conv output to rnn input.
batch_size = tf.shape(inputs)[0]
feat_size = inputs.get_shape().as_list()[2]
inputs = tf.reshape(
inputs,
[batch_size, -1, feat_size * _CONV_FILTERS])
# RNN layers. # RNN layers.
# Convert the conv output to rnn input rnn_cell = SUPPORTED_RNNS[self.rnn_type]
rnn_input = tf.keras.layers.TimeDistributed(tf.keras.layers.Flatten())( for layer_counter in xrange(self.num_rnn_layers):
conv_layer_2) # No batch normalization on the first layer.
rnn_cell = SUPPORTED_RNNS[rnn_type]
for layer_counter in xrange(num_rnn_layers):
# No batch normalization on the first layer
is_batch_norm = (layer_counter != 0) is_batch_norm = (layer_counter != 0)
rnn_input = _rnn_layer( inputs = _rnn_layer(
rnn_input, rnn_cell, rnn_hidden_size, layer_counter + 1, inputs, rnn_cell, self.rnn_hidden_size, layer_counter + 1,
rnn_activation, is_batch_norm, is_bidirectional) is_batch_norm, self.is_bidirectional, training)
# FC layer with batch norm # FC layer with batch norm.
fc_input = tf.keras.layers.BatchNormalization( inputs = batch_norm(inputs, training)
momentum=_MOMENTUM, epsilon=_EPSILON)(rnn_input) logits = tf.layers.dense(inputs, self.num_classes, use_bias=self.use_bias)
y_pred = tf.keras.layers.Dense(num_classes, activation="softmax", return logits
use_bias=use_bias, name="y_pred")(fc_input)
# For ctc loss
labels = tf.keras.layers.Input(name="labels", shape=[None,], dtype="int32")
label_length = tf.keras.layers.Input(
name="label_length", shape=[1], dtype="int32")
input_length = tf.keras.layers.Input(
name="input_length", shape=[1], dtype="int32")
ctc_input_length = tf.keras.layers.Lambda(
_calc_ctc_input_length, output_shape=(1,), name="ctc_input_length")(
[input_length, input_data, y_pred])
# Keras doesn't currently support loss funcs with extra parameters
# so CTC loss is implemented in a lambda layer
ctc_loss = tf.keras.layers.Lambda(
_ctc_lambda_func, output_shape=(1,), name="ctc_loss")(
[y_pred, labels, ctc_input_length, label_length])
super(DeepSpeech, self).__init__(
inputs=[input_data, labels, input_length, label_length],
outputs=[ctc_input_length, ctc_loss, y_pred])
nltk>=3.3
soundfile>=0.10.2
sox>=1.3.3