Commit d90f5580 authored by Haoliang Zhang's avatar Haoliang Zhang Committed by Yanhui Liang

Update deep speech model with pure tensorflow API implementation (#4730)

* update

* update

* update

* update

* update
parent 37ba2304
# DeepSpeech2 Model
## Overview
This is an implementation of the [DeepSpeech2](https://arxiv.org/pdf/1512.02595.pdf) model. The current implementation is based on the code from the authors' [DeepSpeech code](https://github.com/PaddlePaddle/DeepSpeech) and the implementation in the [MLPerf Repo](https://github.com/mlperf/reference/tree/master/speech_recognition).
DeepSpeech2 is an end-to-end deep neural network for automatic speech
recognition (ASR). It consists of 2 convolutional layers, 5 bidirectional RNN
layers and a fully connected layer. The feature in use is the linear spectrogram
extracted from the audio input. The network uses Connectionist Temporal Classification [CTC](https://www.cs.toronto.edu/~graves/icml_2006.pdf) as the loss function.
## Dataset
The [OpenSLR LibriSpeech Corpus](http://www.openslr.org/12/) is used for model training and evaluation.
The training data is a combination of train-clean-100 and train-clean-360 (~130k
examples in total). The validation set is dev-clean, which has 2.7k examples.
The download script will preprocess the data into three columns: wav_filename,
wav_filesize, transcript. data/dataset.py will parse the csv file and build a
tf.data.Dataset object to feed data. Within each epoch (except for the
first if sortagrad is enabled), the training data will be shuffled batch-wise.
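The batch-wise shuffle keeps each length-sorted batch intact and only permutes the order of the batches, so examples within a mini-batch stay similar in length. A minimal sketch of the idea (illustrative only, not the repo's `data/dataset.py`):

```python
import random

def batch_wise_shuffle(entries, batch_size, seed=None):
    """Shuffle at batch granularity: each batch keeps its internal
    length-sorted order, but the order of batches is randomized."""
    rng = random.Random(seed)
    num_full = len(entries) // batch_size
    batches = [entries[i * batch_size:(i + 1) * batch_size]
               for i in range(num_full)]
    rng.shuffle(batches)
    shuffled = [e for batch in batches for e in batch]
    # Examples that do not fill a final batch are appended at the end.
    shuffled.extend(entries[num_full * batch_size:])
    return shuffled

entries = [("%03d.wav" % i, i * 100, "text") for i in range(10)]
shuffled = batch_wise_shuffle(entries, batch_size=4, seed=0)
```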
## Running Code
### Configure Python path
Add the top-level /models folder to the Python path with the command:
```
export PYTHONPATH="$PYTHONPATH:/path/to/models"
```
### Install dependencies
Install the shared dependencies before running the code. Issue the following command:
```
pip3 install -r requirements.txt
```
or
```
pip install -r requirements.txt
```
### Download and preprocess dataset
To download the dataset, issue the following command:
```
python data/download.py
```
Arguments:
* `--data_dir`: Directory where the dataset is downloaded and the preprocessed data is saved. By default, it is `/tmp/librispeech_data`.
Use the `--help` or `-h` flag to get a full list of possible arguments.
### Train and evaluate model
To train and evaluate the model, issue the following command:
```
python deep_speech.py
```
Arguments:
* `--model_dir`: Directory to save model training checkpoints. By default, it is `/tmp/deep_speech_model/`.
* `--train_data_dir`: Directory of the training dataset.
* `--eval_data_dir`: Directory of the evaluation dataset.
* `--num_gpus`: Number of GPUs to use (specify -1 if you want to use all available GPUs).
There are other arguments about DeepSpeech2 model and training/evaluation process. Use the `--help` or `-h` flag to get a full list of possible arguments with detailed descriptions.
@@ -17,13 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import multiprocessing
import math
import random
# pylint: disable=g-bad-import-order
import numpy as np
import scipy.io.wavfile as wavfile
from six.moves import xrange # pylint: disable=redefined-builtin
import soundfile
import tensorflow as tf
# pylint: enable=g-bad-import-order
import data.featurizer as featurizer # pylint: disable=g-bad-import-order
@@ -33,40 +34,37 @@ class AudioConfig(object):
def __init__(self,
sample_rate,
frame_length,
frame_step,
fft_length=None,
normalize=False,
spect_type="linear"):
window_ms,
stride_ms,
normalize=False):
"""Initialize the AudioConfig class.
Args:
sample_rate: an integer denoting the sample rate of the input waveform.
frame_length: an integer for the length of a spectrogram frame, in ms.
frame_step: an integer for the frame stride, in ms.
fft_length: an integer for the number of fft bins.
window_ms: an integer for the length of a spectrogram frame, in ms.
stride_ms: an integer for the frame stride, in ms.
normalize: a boolean for whether to apply normalization on the audio feature.
spect_type: a string for the type of spectrogram to be extracted.
"""
self.sample_rate = sample_rate
self.frame_length = frame_length
self.frame_step = frame_step
self.fft_length = fft_length
self.window_ms = window_ms
self.stride_ms = stride_ms
self.normalize = normalize
self.spect_type = spect_type
class DatasetConfig(object):
"""Config class for generating the DeepSpeechDataset."""
def __init__(self, audio_config, data_path, vocab_file_path):
def __init__(self, audio_config, data_path, vocab_file_path, sortagrad):
"""Initialize the configs for deep speech dataset.
Args:
audio_config: AudioConfig object specifying the audio-related configs.
data_path: a string denoting the full path of a manifest file.
vocab_file_path: a string specifying the vocabulary file path.
sortagrad: a boolean, if set to true, audio sequences will be fed by
increasing length in the first training epoch, which will
expedite network convergence.
Raises:
RuntimeError: file path does not exist.
@@ -77,6 +75,7 @@ class DatasetConfig(object):
assert tf.gfile.Exists(vocab_file_path)
self.data_path = data_path
self.vocab_file_path = vocab_file_path
self.sortagrad = sortagrad
def _normalize_audio_feature(audio_feature):
@@ -95,30 +94,23 @@ def _normalize_audio_feature(audio_feature):
return normalized
def _preprocess_audio(
audio_file_path, audio_sample_rate, audio_featurizer, normalize):
"""Load the audio file in memory and compute spectrogram feature."""
tf.logging.info(
"Extracting spectrogram feature for {}".format(audio_file_path))
sample_rate, data = wavfile.read(audio_file_path)
assert sample_rate == audio_sample_rate
if data.dtype not in [np.float32, np.float64]:
data = data.astype(np.float32) / np.iinfo(data.dtype).max
def _preprocess_audio(audio_file_path, audio_featurizer, normalize):
"""Load the audio file and compute spectrogram feature."""
data, _ = soundfile.read(audio_file_path)
feature = featurizer.compute_spectrogram_feature(
data, audio_featurizer.frame_length, audio_featurizer.frame_step,
audio_featurizer.fft_length)
data, audio_featurizer.sample_rate, audio_featurizer.stride_ms,
audio_featurizer.window_ms)
# Feature normalization
if normalize:
feature = _normalize_audio_feature(feature)
return feature
def _preprocess_transcript(transcript, token_to_index):
"""Process transcript as label features."""
return featurizer.compute_label_feature(transcript, token_to_index)
# Adding Channel dimension for conv2D input.
feature = np.expand_dims(feature, axis=2)
return feature
def _preprocess_data(dataset_config, audio_featurizer, token_to_index):
"""Generate a list of waveform, transcript pair.
def _preprocess_data(file_path):
"""Generate a list of tuples (wav_filename, wav_filesize, transcript).
Each dataset file contains three columns: "wav_filename", "wav_filesize",
and "transcript". This function parses the csv file and stores each example
@@ -127,42 +119,23 @@ def _preprocess_data(dataset_config, audio_featurizer, token_to_index):
mini-batch have similar length.
Args:
dataset_config: an instance of DatasetConfig.
audio_featurizer: an instance of AudioFeaturizer.
token_to_index: the mapping from character to its index
file_path: a string specifying the csv file path for a dataset.
Returns:
features and labels array processed from the audio/text input.
A list of tuples (wav_filename, wav_filesize, transcript) sorted by
file_size.
"""
file_path = dataset_config.data_path
sample_rate = dataset_config.audio_config.sample_rate
normalize = dataset_config.audio_config.normalize
tf.logging.info("Loading data set {}".format(file_path))
with tf.gfile.Open(file_path, "r") as f:
lines = f.read().splitlines()
lines = [line.split("\t") for line in lines]
# Skip the csv header.
# Skip the csv header in lines[0].
lines = lines[1:]
# Sort input data by the length of waveform.
# The metadata file is tab separated.
lines = [line.split("\t", 2) for line in lines]
# Sort input data by the length of audio sequence.
lines.sort(key=lambda item: int(item[1]))
# Use multiprocessing for feature/label extraction
num_cores = multiprocessing.cpu_count()
pool = multiprocessing.Pool(processes=num_cores)
features = pool.map(
functools.partial(
_preprocess_audio, audio_sample_rate=sample_rate,
audio_featurizer=audio_featurizer, normalize=normalize),
[line[0] for line in lines])
labels = pool.map(
functools.partial(
_preprocess_transcript, token_to_index=token_to_index),
[line[2] for line in lines])
pool.terminate()
return features, labels
return [tuple(line) for line in lines]
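The manifest parsing can be illustrated end to end with an in-memory file (the contents below are hypothetical; real manifests are produced by `data/download.py`):

```python
import io

# Tab-separated manifest with a header row and the three columns
# described above: wav_filename, wav_filesize, transcript.
manifest = io.StringIO(
    "wav_filename\twav_filesize\ttranscript\n"
    "b.wav\t3200\thello world\n"
    "a.wav\t1600\thi\n"
)
lines = manifest.read().splitlines()[1:]            # skip the csv header
entries = [tuple(line.split("\t", 2)) for line in lines]
entries.sort(key=lambda item: int(item[1]))         # sort by file size
```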
class DeepSpeechDataset(object):
@@ -178,22 +151,52 @@ class DeepSpeechDataset(object):
# Instantiate audio feature extractor.
self.audio_featurizer = featurizer.AudioFeaturizer(
sample_rate=self.config.audio_config.sample_rate,
frame_length=self.config.audio_config.frame_length,
frame_step=self.config.audio_config.frame_step,
fft_length=self.config.audio_config.fft_length)
window_ms=self.config.audio_config.window_ms,
stride_ms=self.config.audio_config.stride_ms)
# Instantiate text feature extractor.
self.text_featurizer = featurizer.TextFeaturizer(
vocab_file=self.config.vocab_file_path)
self.speech_labels = self.text_featurizer.speech_labels
self.features, self.labels = _preprocess_data(
self.config,
self.audio_featurizer,
self.text_featurizer.token_to_idx
)
self.entries = _preprocess_data(self.config.data_path)
# The generated spectrogram will have 161 feature bins.
self.num_feature_bins = 161
self.num_feature_bins = (
self.features[0].shape[1] if len(self.features) else None)
def batch_wise_dataset_shuffle(entries, epoch_index, sortagrad, batch_size):
"""Batch-wise shuffling of the data entries.
Each data entry is in the format of (audio_file, file_size, transcript).
If epoch_index is 0 and sortagrad is true, we don't perform shuffling and
return entries in sorted file_size order. Otherwise, we do batch-wise shuffling.
Args:
entries: a list of data entries.
epoch_index: an integer denoting the epoch index.
sortagrad: a boolean to control whether to sort the audio in the first
training epoch.
batch_size: an integer for the batch size.
Returns:
The shuffled data entries.
"""
shuffled_entries = []
if epoch_index == 0 and sortagrad:
# No need to shuffle.
shuffled_entries = entries
else:
# Shuffle entries batch-wise.
max_buckets = int(math.floor(len(entries) / batch_size))
total_buckets = [i for i in xrange(max_buckets)]
random.shuffle(total_buckets)
shuffled_entries = []
for i in total_buckets:
shuffled_entries.extend(entries[i * batch_size : (i + 1) * batch_size])
# If the last batch doesn't contain enough batch_size examples,
# just append it to the shuffled_entries.
shuffled_entries.extend(entries[max_buckets * batch_size:])
return shuffled_entries
def input_fn(batch_size, deep_speech_dataset, repeat=1):
@@ -207,49 +210,66 @@ def input_fn(batch_size, deep_speech_dataset, repeat=1):
Returns:
a tf.data.Dataset object for model to consume.
"""
features = deep_speech_dataset.features
labels = deep_speech_dataset.labels
# Dataset properties
data_entries = deep_speech_dataset.entries
num_feature_bins = deep_speech_dataset.num_feature_bins
audio_featurizer = deep_speech_dataset.audio_featurizer
feature_normalize = deep_speech_dataset.config.audio_config.normalize
text_featurizer = deep_speech_dataset.text_featurizer
def _gen_data():
for i in xrange(len(features)):
feature = np.expand_dims(features[i], axis=2)
input_length = [features[i].shape[0]]
label_length = [len(labels[i])]
yield {
"features": feature,
"labels": labels[i],
"input_length": input_length,
"label_length": label_length
}
"""Dataset generator function."""
for audio_file, _, transcript in data_entries:
features = _preprocess_audio(
audio_file, audio_featurizer, feature_normalize)
labels = featurizer.compute_label_feature(
transcript, text_featurizer.token_to_index)
input_length = [features.shape[0]]
label_length = [len(labels)]
# Yield a tuple of (features, labels) where features is a dict containing
# all info about the actual data features.
yield (
{
"features": features,
"input_length": input_length,
"label_length": label_length
},
labels)
dataset = tf.data.Dataset.from_generator(
_gen_data,
output_types={
"features": tf.float32,
"labels": tf.int32,
"input_length": tf.int32,
"label_length": tf.int32
},
output_shapes={
"features": tf.TensorShape([None, num_feature_bins, 1]),
"labels": tf.TensorShape([None]),
"input_length": tf.TensorShape([1]),
"label_length": tf.TensorShape([1])
})
output_types=(
{
"features": tf.float32,
"input_length": tf.int32,
"label_length": tf.int32
},
tf.int32),
output_shapes=(
{
"features": tf.TensorShape([None, num_feature_bins, 1]),
"input_length": tf.TensorShape([1]),
"label_length": tf.TensorShape([1])
},
tf.TensorShape([None]))
)
# Repeat and batch the dataset
dataset = dataset.repeat(repeat)
# Padding the features to its max length dimensions.
dataset = dataset.padded_batch(
batch_size=batch_size,
padded_shapes={
"features": tf.TensorShape([None, num_feature_bins, 1]),
"labels": tf.TensorShape([None]),
"input_length": tf.TensorShape([1]),
"label_length": tf.TensorShape([1])
})
padded_shapes=(
{
"features": tf.TensorShape([None, num_feature_bins, 1]),
"input_length": tf.TensorShape([1]),
"label_length": tf.TensorShape([1])
},
tf.TensorShape([None]))
)
# Prefetch to improve speed of input pipeline.
dataset = dataset.prefetch(1)
dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
return dataset
@@ -19,20 +19,51 @@ from __future__ import print_function
import codecs
import numpy as np
from scipy import signal
def compute_spectrogram_feature(waveform, frame_length, frame_step, fft_length):
"""Compute the spectrograms for the input waveform."""
_, _, stft = signal.stft(
waveform,
nperseg=frame_length,
noverlap=frame_step,
nfft=fft_length)
def compute_spectrogram_feature(samples, sample_rate, stride_ms=10.0,
window_ms=20.0, max_freq=None, eps=1e-14):
"""Compute the spectrograms for the input samples(waveforms).
# Perform transpose to set its shape as [time_steps, feature_num_bins]
spectrogram = np.transpose(np.absolute(stft), (1, 0))
return spectrogram
For more about spectrogram computation, please refer to:
https://en.wikipedia.org/wiki/Short-time_Fourier_transform.
"""
if max_freq is None:
max_freq = sample_rate / 2
if max_freq > sample_rate / 2:
raise ValueError("max_freq must not be greater than half of sample rate.")
if stride_ms > window_ms:
raise ValueError("Stride size must not be greater than window size.")
stride_size = int(0.001 * sample_rate * stride_ms)
window_size = int(0.001 * sample_rate * window_ms)
# Extract strided windows
truncate_size = (len(samples) - window_size) % stride_size
samples = samples[:len(samples) - truncate_size]
nshape = (window_size, (len(samples) - window_size) // stride_size + 1)
nstrides = (samples.strides[0], samples.strides[0] * stride_size)
windows = np.lib.stride_tricks.as_strided(
samples, shape=nshape, strides=nstrides)
assert np.all(
windows[:, 1] == samples[stride_size:(stride_size + window_size)])
# Window weighting, squared Fast Fourier Transform (fft), scaling
weighting = np.hanning(window_size)[:, None]
fft = np.fft.rfft(windows * weighting, axis=0)
fft = np.absolute(fft)
fft = fft**2
scale = np.sum(weighting**2) * sample_rate
fft[1:-1, :] *= (2.0 / scale)
fft[(0, -1), :] /= scale
# Prepare fft frequency list
freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
# Compute spectrogram feature
ind = np.where(freqs <= max_freq)[0][-1] + 1
specgram = np.log(fft[:ind, :] + eps)
return np.transpose(specgram, (1, 0))
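A quick shape check on the defaults above, in plain arithmetic: a 20 ms window at a 16 kHz sample rate covers 320 samples, and a real FFT over 320 samples yields 320 // 2 + 1 = 161 frequency bins, which is where the dataset's `num_feature_bins = 161` comes from.

```python
sample_rate = 16000
window_ms = 20.0
stride_ms = 10.0

window_size = int(0.001 * sample_rate * window_ms)  # 320 samples per frame
stride_size = int(0.001 * sample_rate * stride_ms)  # 160 samples per hop
num_feature_bins = window_size // 2 + 1             # rfft bin count
```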
class AudioFeaturizer(object):
@@ -40,21 +71,18 @@ class AudioFeaturizer(object):
def __init__(self,
sample_rate=16000,
frame_length=25,
frame_step=10,
fft_length=None):
window_ms=20.0,
stride_ms=10.0):
"""Initialize the audio featurizer class according to the configs.
Args:
sample_rate: an integer specifying the sample rate of the input waveform.
frame_length: an integer for the length of a spectrogram frame, in ms.
frame_step: an integer for the frame stride, in ms.
fft_length: an integer for the number of fft bins.
window_ms: an integer for the length of a spectrogram frame, in ms.
stride_ms: an integer for the frame stride, in ms.
"""
self.frame_length = int(sample_rate * frame_length / 1e3)
self.frame_step = int(sample_rate * frame_step / 1e3)
self.fft_length = fft_length if fft_length else int(2**(np.ceil(
np.log2(self.frame_length))))
self.sample_rate = sample_rate
self.window_ms = window_ms
self.stride_ms = stride_ms
def compute_label_feature(text, token_to_idx):
@@ -75,16 +103,16 @@ class TextFeaturizer(object):
lines = []
with codecs.open(vocab_file, "r", "utf-8") as fin:
lines.extend(fin.readlines())
self.token_to_idx = {}
self.idx_to_token = {}
self.token_to_index = {}
self.index_to_token = {}
self.speech_labels = ""
idx = 0
index = 0
for line in lines:
line = line[:-1] # Strip the '\n' char.
if line.startswith("#"):
# Skip from reading comment line.
continue
self.token_to_idx[line] = idx
self.idx_to_token[idx] = line
self.token_to_index[line] = index
self.index_to_token[index] = line
self.speech_labels += line
idx += 1
index += 1
# List of alphabets (utf-8 encoded). Note that '#' starts a comment line, which
# will be ignored by the parser.
# begin of vocabulary
a
b
c
@@ -28,6 +29,5 @@ x
y
z
'
-
# end of vocabulary
@@ -18,189 +18,78 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
from nltk.metrics import distance
from six.moves import xrange
import tensorflow as tf
import numpy as np
class DeepSpeechDecoder(object):
"""Basic decoder class from which all other decoders inherit.
Implements several helper functions. Subclasses should implement the decode()
method.
"""
"""Greedy decoder implementation for Deep Speech model."""
def __init__(self, labels, blank_index=28, space_index=27):
def __init__(self, labels, blank_index=28):
"""Decoder initialization.
Arguments:
labels (string): mapping from integers to characters.
blank_index (int, optional): index for the blank '_' character.
Defaults to 0.
space_index (int, optional): index for the space ' ' character.
labels: a string specifying the speech labels for the decoder to use.
blank_index: an integer specifying index for the blank character.
Defaults to 28.
"""
# e.g. labels = "[a-z]' _"
self.labels = labels
self.int_to_char = dict([(i, c) for (i, c) in enumerate(labels)])
self.blank_index = blank_index
self.space_index = space_index
def convert_to_strings(self, sequences, sizes=None):
"""Given a list of numeric sequences, returns the corresponding strings."""
strings = []
for x in xrange(len(sequences)):
seq_len = sizes[x] if sizes is not None else len(sequences[x])
string = self._convert_to_string(sequences[x], seq_len)
strings.append(string)
return strings
def _convert_to_string(self, sequence, sizes):
return ''.join([self.int_to_char[sequence[i]] for i in range(sizes)])
def process_strings(self, sequences, remove_repetitions=False):
"""Process strings.
self.int_to_char = dict([(i, c) for (i, c) in enumerate(labels)])
Given a list of strings, removes blanks and replace space character with
space. Option to remove repetitions (e.g. 'abbca' -> 'abca').
def convert_to_string(self, sequence):
"""Convert a sequence of indexes into corresponding string."""
return ''.join([self.int_to_char[i] for i in sequence])
Arguments:
sequences: list of 1-d array of integers
remove_repetitions (boolean, optional): If true, repeating characters
are removed. Defaults to False.
Returns:
The processed string.
"""
processed_strings = []
for sequence in sequences:
string = self.process_string(remove_repetitions, sequence).strip()
processed_strings.append(string)
return processed_strings
def process_string(self, remove_repetitions, sequence):
"""Process each given sequence."""
seq_string = ''
for i, char in enumerate(sequence):
if char != self.int_to_char[self.blank_index]:
# if this char is a repetition and remove_repetitions=true,
# skip.
if remove_repetitions and i != 0 and char == sequence[i - 1]:
pass
elif char == self.labels[self.space_index]:
seq_string += ' '
else:
seq_string += char
return seq_string
def wer(self, output, target):
def wer(self, decode, target):
"""Computes the Word Error Rate (WER).
WER is defined as the edit distance between the two provided sentences after
tokenizing to words.
Args:
output: string of the decoded output.
target: a string for the true transcript.
decode: string of the decoded output.
target: a string for the ground truth label.
Returns:
A float number for the WER of the current sentence pair.
A float number for the WER of the current decode-target pair.
"""
# Map each word to a new char.
words = set(output.split() + target.split())
words = set(decode.split() + target.split())
word2char = dict(zip(words, range(len(words))))
new_output = [chr(word2char[w]) for w in output.split()]
new_decode = [chr(word2char[w]) for w in decode.split()]
new_target = [chr(word2char[w]) for w in target.split()]
return distance.edit_distance(''.join(new_output), ''.join(new_target))
return distance.edit_distance(''.join(new_decode), ''.join(new_target))
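The word-to-char mapping reduces word-level edit distance to character-level edit distance: each distinct word becomes a unique character, so one character edit corresponds to one word edit. A self-contained sketch, where `edit_distance` is a stand-in for `nltk.metrics.distance.edit_distance`:

```python
def edit_distance(a, b):
    """Levenshtein distance via a single-row dynamic program."""
    dp = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        prev, dp[0] = dp[0], i
        for j, cb in enumerate(b, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1,          # delete ca
                                     dp[j - 1] + 1,      # insert cb
                                     prev + (ca != cb))  # substitute
    return dp[-1]

def wer(decode, target):
    """Word Error Rate via the word-to-character mapping trick."""
    words = sorted(set(decode.split() + target.split()))
    word2char = {w: chr(i) for i, w in enumerate(words)}
    new_decode = ''.join(word2char[w] for w in decode.split())
    new_target = ''.join(word2char[w] for w in target.split())
    return edit_distance(new_decode, new_target)
```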
def cer(self, output, target):
def cer(self, decode, target):
"""Computes the Character Error Rate (CER).
CER is defined as the edit distance between the given strings.
CER is defined as the edit distance between the two given strings.
Args:
output: a string of the decoded output.
target: a string for the ground truth transcript.
decode: a string of the decoded output.
target: a string for the ground truth label.
Returns:
A float number denoting the CER for the current sentence pair.
"""
return distance.edit_distance(output, target)
def batch_wer(self, decoded_output, targets):
"""Compute the aggregate WER for each batch.
Args:
decoded_output: 2d array of integers for the decoded output of a batch.
targets: 2d array of integers for the labels of a batch.
Returns:
A float number for the aggregated WER for the current batch output.
"""
# Convert numeric representation to string.
decoded_strings = self.convert_to_strings(decoded_output)
decoded_strings = self.process_strings(
decoded_strings, remove_repetitions=True)
target_strings = self.convert_to_strings(targets)
target_strings = self.process_strings(
target_strings, remove_repetitions=True)
wer = 0
for i in xrange(len(decoded_strings)):
wer += self.wer(decoded_strings[i], target_strings[i]) / float(
len(target_strings[i].split()))
return wer
def batch_cer(self, decoded_output, targets):
"""Compute the aggregate CER for each batch.
Args:
decoded_output: 2d array of integers for the decoded output of a batch.
targets: 2d array of integers for the labels of a batch.
Returns:
A float number for the aggregated CER for the current batch output.
"""
# Convert numeric representation to string.
decoded_strings = self.convert_to_strings(decoded_output)
decoded_strings = self.process_strings(
decoded_strings, remove_repetitions=True)
target_strings = self.convert_to_strings(targets)
target_strings = self.process_strings(
target_strings, remove_repetitions=True)
cer = 0
for i in xrange(len(decoded_strings)):
cer += self.cer(decoded_strings[i], target_strings[i]) / float(
len(target_strings[i]))
return cer
def decode(self, sequences, sizes=None):
"""Perform sequence decoding.
Given a matrix of character probabilities, returns the decoder's best guess
of the transcription.
Arguments:
sequences: 2D array of character probabilities, where sequences[c, t] is
the probability of character c at time t.
sizes(optional): Size of each sequence in the mini-batch.
Returns:
string: sequence of the model's best guess for the transcription.
"""
strings = self.convert_to_strings(sequences, sizes)
return self.process_strings(strings, remove_repetitions=True)
class GreedyDecoder(DeepSpeechDecoder):
"""Greedy decoder."""
def decode(self, logits, seq_len):
# Reshape to [max_time, batch_size, num_classes]
logits = tf.transpose(logits, (1, 0, 2))
decoded, _ = tf.nn.ctc_greedy_decoder(logits, seq_len)
decoded_dense = tf.Session().run(tf.sparse_to_dense(
decoded[0].indices, decoded[0].dense_shape, decoded[0].values))
result = self.convert_to_strings(decoded_dense)
return self.process_strings(result, remove_repetitions=True), decoded_dense
return distance.edit_distance(decode, target)
def decode(self, logits):
"""Decode the best guess from logits using greedy algorithm."""
# Choose the class with maximum probability.
best = list(np.argmax(logits, axis=1))
# Merge repeated chars.
merge = [k for k, _ in itertools.groupby(best)]
# Remove the blank index in the decoded sequence.
merge_remove_blank = []
for k in merge:
if k != self.blank_index:
merge_remove_blank.append(k)
return self.convert_to_string(merge_remove_blank)
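The greedy collapse above (argmax per timestep, merge consecutive repeats, drop blanks) on a toy four-class example; the alphabet and probabilities here are made up for illustration:

```python
import itertools

labels = "abc_"      # toy alphabet; index 3 ('_') is the CTC blank
blank_index = 3
# Per-timestep class probabilities: [time_steps][num_classes].
probs = [
    [0.90, 0.05, 0.03, 0.02],  # argmax -> 'a'
    [0.80, 0.10, 0.05, 0.05],  # argmax -> 'a' (repeat, merged away)
    [0.10, 0.10, 0.10, 0.70],  # argmax -> blank (dropped)
    [0.10, 0.70, 0.10, 0.10],  # argmax -> 'b'
]
# Step 1: pick the most probable class per timestep (np.argmax above).
best = [max(range(len(row)), key=row.__getitem__) for row in probs]
# Step 2: merge consecutive repeats, then remove the blank index.
merged = [k for k, _ in itertools.groupby(best)]
decoded = ''.join(labels[k] for k in merged if k != blank_index)
```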
@@ -41,8 +41,50 @@ _WER_KEY = "WER"
_CER_KEY = "CER"
def evaluate_model(
estimator, batch_size, speech_labels, targets, input_fn_eval):
def compute_length_after_conv(max_time_steps, ctc_time_steps, input_length):
"""Computes the time_steps/ctc_input_length after convolution.
Suppose that the original feature contains two parts:
1) Real spectrogram signals, spanning input_length steps.
2) Padded part with all 0s.
The total length of those two parts is denoted as max_time_steps, which is
the padded length of the current batch. After convolution layers, the time
steps of a spectrogram feature will be decreased. As we know the percentage
of its original length within the entire length, we can compute the time steps
for the signal after convolution (denoted as ctc_input_length) as follows:
ctc_input_length = (input_length / max_time_steps) * ctc_time_steps.
This length is then fed into the CTC loss function to compute loss.
Args:
max_time_steps: max_time_steps for the batch, after padding.
ctc_time_steps: number of timesteps after convolution.
input_length: actual length of the original spectrogram, without padding.
Returns:
the ctc_input_length after convolution layer.
"""
ctc_input_length = tf.to_float(tf.multiply(
input_length, ctc_time_steps))
return tf.to_int32(tf.floordiv(
ctc_input_length, tf.to_float(max_time_steps)))
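A numeric sketch of this proportion in plain Python (the TF version does the same with `tf.floordiv` on float tensors): if 200 of 300 padded time steps carry real signal and the conv stack outputs 100 time steps, the CTC loss sees floor(200 * 100 / 300) = 66 valid steps.

```python
max_time_steps = 300   # padded length of the batch
input_length = 200     # real (unpadded) spectrogram length
ctc_time_steps = 100   # time steps after the convolution layers

# Same fraction of valid signal, rescaled to the conv output resolution.
ctc_input_length = (input_length * ctc_time_steps) // max_time_steps
```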
def ctc_loss(label_length, ctc_input_length, labels, logits):
"""Computes the ctc loss for the current batch of predictions."""
label_length = tf.to_int32(tf.squeeze(label_length))
ctc_input_length = tf.to_int32(tf.squeeze(ctc_input_length))
sparse_labels = tf.to_int32(
tf.keras.backend.ctc_label_dense_to_sparse(labels, label_length))
y_pred = tf.log(tf.transpose(
logits, perm=[1, 0, 2]) + tf.keras.backend.epsilon())
return tf.expand_dims(
tf.nn.ctc_loss(labels=sparse_labels, inputs=y_pred,
sequence_length=ctc_input_length),
axis=1)
def evaluate_model(estimator, speech_labels, entries, input_fn_eval):
"""Evaluate the model performance using WER anc CER as metrics.
WER: Word Error Rate
@@ -50,44 +92,34 @@ def evaluate_model(
Args:
estimator: estimator to evaluate.
batch_size: size of a mini-batch.
speech_labels: a string specifying all the characters in the vocabulary.
targets: a list of list of integers for the featurized transcript.
entries: a list of data entries (audio_file, file_size, transcript) for the
given dataset.
input_fn_eval: data input function for evaluation.
Returns:
Evaluation result containing 'wer' and 'cer' as two metrics.
"""
# Get predictions
predictions = estimator.predict(
input_fn=input_fn_eval, yield_single_examples=False)
predictions = estimator.predict(input_fn=input_fn_eval)
y_preds = []
input_lengths = []
for p in predictions:
y_preds.append(p["y_pred"])
input_lengths.append(p["ctc_input_length"])
# Get probabilities of each predicted class
probs = [pred["probabilities"] for pred in predictions]
num_of_examples = len(targets)
total_wer, total_cer = 0, 0
greedy_decoder = decoder.GreedyDecoder(speech_labels)
for i in range(len(y_preds)):
# Compute the CER and WER for the current batch,
# and aggregate to total_cer, total_wer.
y_pred_tensor = tf.convert_to_tensor(y_preds[i])
batch_targets = targets[i * batch_size : (i + 1) * batch_size]
seq_len = tf.squeeze(input_lengths[i], axis=1)
# Perform decoding
_, decoded_output = greedy_decoder.decode(
y_pred_tensor, seq_len)
num_of_examples = len(probs)
targets = [entry[2] for entry in entries] # The ground truth transcript
total_wer, total_cer = 0, 0
greedy_decoder = decoder.DeepSpeechDecoder(speech_labels)
for i in range(num_of_examples):
# Decode string.
decoded_str = greedy_decoder.decode(probs[i])
# Compute CER.
batch_cer = greedy_decoder.batch_cer(decoded_output, batch_targets)
total_cer += batch_cer
total_cer += greedy_decoder.cer(decoded_str, targets[i]) / float(
len(targets[i]))
# Compute WER.
batch_wer = greedy_decoder.batch_wer(decoded_output, batch_targets)
total_wer += batch_wer
total_wer += greedy_decoder.wer(decoded_str, targets[i]) / float(
len(targets[i].split()))
# Get mean value
total_cer /= num_of_examples
@@ -103,45 +135,76 @@ def evaluate_model(
return eval_results
def convert_keras_to_estimator(keras_model, num_gpus):
"""Configure and convert keras model to Estimator.
def model_fn(features, labels, mode, params):
"""Define model function for deep speech model.
Args:
keras_model: A Keras model object.
num_gpus: An integer, the number of GPUs.
features: a dictionary of input_data features. It includes the data
input_length, label_length and the spectrogram features.
labels: a list of labels for the input data.
mode: current estimator mode; should be one of
`tf.estimator.ModeKeys.TRAIN`, `EVALUATE`, `PREDICT`.
params: a dict of hyper parameters to be passed to model_fn.
Returns:
estimator: The converted Estimator.
EstimatorSpec parameterized according to the input params and the
current mode.
"""
# keras optimizer is not compatible with distribution strategy.
# Use tf optimizer instead
optimizer = tf.train.MomentumOptimizer(
learning_rate=flags_obj.learning_rate, momentum=flags_obj.momentum,
use_nesterov=True)
# ctc_loss is wrapped as a Lambda layer in the model.
keras_model.compile(
optimizer=optimizer, loss={"ctc_loss": lambda y_true, y_pred: y_pred})
distribution_strategy = distribution_utils.get_distribution_strategy(
num_gpus)
run_config = tf.estimator.RunConfig(
train_distribute=distribution_strategy)
estimator = tf.keras.estimator.model_to_estimator(
keras_model=keras_model, model_dir=flags_obj.model_dir, config=run_config)
return estimator
num_classes = params["num_classes"]
input_length = features["input_length"]
label_length = features["label_length"]
features = features["features"]
# Create DeepSpeech2 model.
model = deep_speech_model.DeepSpeech2(
flags_obj.rnn_hidden_layers, flags_obj.rnn_type,
flags_obj.is_bidirectional, flags_obj.rnn_hidden_size,
num_classes, flags_obj.use_bias)
if mode == tf.estimator.ModeKeys.PREDICT:
logits = model(features, training=False)
predictions = {
"classes": tf.argmax(logits, axis=2),
"probabilities": tf.nn.softmax(logits),
"logits": logits
}
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions)
# In training mode.
logits = model(features, training=True)
probs = tf.nn.softmax(logits)
ctc_input_length = compute_length_after_conv(
tf.shape(features)[1], tf.shape(probs)[1], input_length)
# Compute CTC loss
loss = tf.reduce_mean(ctc_loss(
label_length, ctc_input_length, labels, probs))
optimizer = tf.train.AdamOptimizer(learning_rate=flags_obj.learning_rate)
global_step = tf.train.get_or_create_global_step()
minimize_op = optimizer.minimize(loss, global_step=global_step)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# Create the train_op that groups both minimize_ops and update_ops
train_op = tf.group(minimize_op, update_ops)
return tf.estimator.EstimatorSpec(
mode=mode,
loss=loss,
train_op=train_op)
def generate_dataset(data_dir):
"""Generate a speech dataset."""
audio_conf = dataset.AudioConfig(
flags_obj.sample_rate, flags_obj.frame_length, flags_obj.frame_step)
audio_conf = dataset.AudioConfig(sample_rate=flags_obj.sample_rate,
window_ms=flags_obj.window_ms,
stride_ms=flags_obj.stride_ms,
normalize=True)
train_data_conf = dataset.DatasetConfig(
audio_conf,
data_dir,
flags_obj.vocabulary_file,
flags_obj.sortagrad
)
speech_dataset = dataset.DeepSpeechDataset(train_data_conf)
return speech_dataset
@@ -150,30 +213,27 @@ def generate_dataset(data_dir):
def run_deep_speech(_):
"""Run deep speech training and eval loop."""
# Data preprocessing: generate the training and evaluation datasets.
tf.logging.info("Data preprocessing...")
train_speech_dataset = generate_dataset(flags_obj.train_data_dir)
eval_speech_dataset = generate_dataset(flags_obj.eval_data_dir)
# Number of label classes. Label string is "[a-z]' -"
num_classes = len(train_speech_dataset.speech_labels)
# Input shape of each data example:
# [time_steps (T), feature_bins(F), channel(C)]
# Channel is set as 1 by default.
input_shape = (None, train_speech_dataset.num_feature_bins, 1)
# Create deep speech model and convert it to Estimator
tf.logging.info("Creating Estimator from Keras model...")
keras_model = deep_speech_model.DeepSpeech(
input_shape, flags_obj.rnn_hidden_layers, flags_obj.rnn_type,
flags_obj.is_bidirectional, flags_obj.rnn_hidden_size,
flags_obj.rnn_activation, num_classes, flags_obj.use_bias)
# Convert to estimator
# Use distribution strategy for multi-gpu training
num_gpus = flags_core.get_num_gpus(flags_obj)
estimator = convert_keras_to_estimator(keras_model, num_gpus)
distribution_strategy = distribution_utils.get_distribution_strategy(num_gpus)
run_config = tf.estimator.RunConfig(
train_distribute=distribution_strategy)
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=flags_obj.model_dir,
config=run_config,
params={
"num_classes": num_classes,
}
)
# Benchmark logging
run_params = {
@@ -181,7 +241,6 @@ def run_deep_speech(_):
"train_epochs": flags_obj.train_epochs,
"rnn_hidden_size": flags_obj.rnn_hidden_size,
"rnn_hidden_layers": flags_obj.rnn_hidden_layers,
"rnn_activation": flags_obj.rnn_activation,
"rnn_type": flags_obj.rnn_type,
"is_bidirectional": flags_obj.is_bidirectional,
"use_bias": flags_obj.use_bias
@@ -194,6 +253,7 @@ def run_deep_speech(_):
train_hooks = hooks_helper.get_train_hooks(
flags_obj.hooks,
model_dir=flags_obj.model_dir,
batch_size=flags_obj.batch_size)
per_device_batch_size = distribution_utils.per_device_batch_size(
@@ -213,14 +273,19 @@ def run_deep_speech(_):
tf.logging.info("Starting a training cycle: %d/%d",
cycle_index + 1, total_training_cycle)
# Perform batch_wise dataset shuffling
train_speech_dataset.entries = dataset.batch_wise_dataset_shuffle(
train_speech_dataset.entries, cycle_index, flags_obj.sortagrad,
flags_obj.batch_size)
estimator.train(input_fn=input_fn_train, hooks=train_hooks)
# Evaluation
tf.logging.info("Starting to evaluate...")
eval_results = evaluate_model(
estimator, flags_obj.batch_size, eval_speech_dataset.speech_labels,
eval_speech_dataset.labels, input_fn_eval)
estimator, eval_speech_dataset.speech_labels,
eval_speech_dataset.entries, input_fn_eval)
# Log the WER and CER results.
benchmark_logger.log_evaluation_result(eval_results)
@@ -233,9 +298,6 @@ def run_deep_speech(_):
flags_obj.wer_threshold, eval_results[_WER_KEY]):
break
# Clear the session explicitly to avoid a session deletion error.
tf.keras.backend.clear_session()
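The batch-wise shuffle used in the training loop keeps the examples inside a batch together (so sortagrad's length-sorted batches stay homogeneous) and only permutes batch order. A hedged sketch of that behavior — the function name, seed handling, and signature are illustrative, not `dataset.batch_wise_dataset_shuffle` itself:

```python
import random

def batch_wise_shuffle_sketch(entries, epoch_index, sortagrad, batch_size, seed=0):
    """Shuffle batch order while keeping the examples inside a batch together."""
    if sortagrad and epoch_index == 0:
        # First epoch with sortagrad: keep the length-sorted order untouched.
        return list(entries)
    # Chunk the entries into batches, permute the batches, then flatten.
    batches = [entries[i:i + batch_size]
               for i in range(0, len(entries), batch_size)]
    random.Random(seed + epoch_index).shuffle(batches)
    return [entry for batch in batches for entry in batch]

# Six entries, batch size 2: pairs (0,1), (2,3), (4,5) always move as units.
shuffled = batch_wise_shuffle_sketch(list(range(6)), epoch_index=1,
                                     sortagrad=True, batch_size=2)
```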
def define_deep_speech_flags():
"""Add flags for run_deep_speech."""
@@ -257,8 +319,8 @@ def define_deep_speech_flags():
flags_core.set_defaults(
model_dir="/tmp/deep_speech_model/",
export_dir="/tmp/deep_speech_saved_model/",
train_epochs=2,
batch_size=4,
train_epochs=200,
batch_size=128,
hooks="")
# Deep speech flags
@@ -272,16 +334,22 @@ def define_deep_speech_flags():
default="/tmp/librispeech_data/test-clean/LibriSpeech/test-clean-20.csv",
help=flags_core.help_wrap("The csv file path of evaluation dataset."))
flags.DEFINE_bool(
name="sortagrad", default=True,
help=flags_core.help_wrap(
"If true, sort examples by audio length and perform no "
"batch_wise shuffling for the first epoch."))
flags.DEFINE_integer(
name="sample_rate", default=16000,
help=flags_core.help_wrap("The sample rate for audio."))
flags.DEFINE_integer(
name="frame_length", default=25,
name="window_ms", default=20,
help=flags_core.help_wrap("The frame length in milliseconds for spectrogram."))
flags.DEFINE_integer(
name="frame_step", default=10,
name="stride_ms", default=10,
help=flags_core.help_wrap("The frame step in milliseconds."))
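At the default 16 kHz sample rate, the `window_ms`/`stride_ms` defaults translate to sample counts as follows. A quick sanity-check sketch (the helper name is made up; the actual featurizer lives in data/dataset.py):

```python
def ms_to_samples(sample_rate, duration_ms):
    # Convert a duration in milliseconds to a whole number of audio samples.
    return int(sample_rate * duration_ms / 1000)

# Defaults: 16 kHz audio, 20 ms window, 10 ms stride.
window = ms_to_samples(16000, 20)  # 320 samples per frame
stride = ms_to_samples(16000, 10)  # 160 samples between frames
print(window, stride)  # 320 160
```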
flags.DEFINE_string(
@@ -290,11 +358,11 @@ def define_deep_speech_flags():
# RNN related flags
flags.DEFINE_integer(
name="rnn_hidden_size", default=256,
name="rnn_hidden_size", default=800,
help=flags_core.help_wrap("The hidden size of RNNs."))
flags.DEFINE_integer(
name="rnn_hidden_layers", default=3,
name="rnn_hidden_layers", default=5,
help=flags_core.help_wrap("The number of RNN layers."))
flags.DEFINE_bool(
@@ -311,20 +379,11 @@ def define_deep_speech_flags():
case_sensitive=False,
help=flags_core.help_wrap("Type of RNN cell."))
flags.DEFINE_enum(
name="rnn_activation", default="tanh",
enum_values=["tanh", "relu"], case_sensitive=False,
help=flags_core.help_wrap("Type of the activation within RNN."))
# Training related flags
flags.DEFINE_float(
name="learning_rate", default=0.0003,
name="learning_rate", default=5e-4,
help=flags_core.help_wrap("The initial learning rate."))
flags.DEFINE_float(
name="momentum", default=0.9,
help=flags_core.help_wrap("Momentum to accelerate SGD optimizer."))
# Evaluation metrics threshold
flags.DEFINE_float(
name="wer_threshold", default=None,
@@ -345,3 +404,4 @@ if __name__ == "__main__":
define_deep_speech_flags()
flags_obj = flags.FLAGS
absl_app.run(main)
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Network structure for DeepSpeech model."""
"""Network structure for DeepSpeech2 model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -20,175 +20,166 @@ from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
# Supported rnn cells
# Supported rnn cells.
SUPPORTED_RNNS = {
"lstm": tf.keras.layers.LSTM,
"rnn": tf.keras.layers.SimpleRNN,
"gru": tf.keras.layers.GRU,
"lstm": tf.nn.rnn_cell.BasicLSTMCell,
"rnn": tf.nn.rnn_cell.BasicRNNCell,  # RNNCell is abstract; use the concrete basic cell.
"gru": tf.nn.rnn_cell.GRUCell,
}
# Parameters for batch normalization
_MOMENTUM = 0.1
_EPSILON = 1e-05
# Parameters for batch normalization.
_BATCH_NORM_EPSILON = 1e-5
_BATCH_NORM_DECAY = 0.997
# Filters of convolution layer
_CONV_FILTERS = 32
def _conv_bn_layer(cnn_input, filters, kernel_size, strides, layer_id):
"""2D convolution + batch normalization layer.
def batch_norm(inputs, training):
"""Batch normalization layer.
Note that the momentum to use will affect validation accuracy over time.
Batch norm has different behaviors during training/evaluation. With a large
momentum, the model takes longer to get a near-accurate estimation of the
moving mean/variance over the entire training dataset, which means we need
more iterations to see good evaluation results. If the training data is evenly
distributed over the feature space, we can also try setting a smaller momentum
(such as 0.1) to get good evaluation results sooner.
Args:
inputs: input data for batch norm layer.
training: a boolean to indicate if it is in training stage.
Returns:
tensor output from batch norm layer.
"""
return tf.layers.batch_normalization(
inputs=inputs, momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON,
fused=True, training=training)
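The choice of `_BATCH_NORM_DECAY = 0.997` makes the moving statistics converge slowly, as the docstring notes. A small sketch quantifying that: with a geometric moving average the initial error shrinks by the decay factor per update (standard batch-norm bookkeeping, not this file's internals):

```python
import math

def moving_average_steps(decay, tolerance):
    # With moving_stat = decay * moving_stat + (1 - decay) * batch_stat,
    # the initial error shrinks geometrically: error_n = decay ** n.
    # Return the first n with decay ** n <= tolerance.
    return math.ceil(math.log(tolerance) / math.log(decay))

# Reaching 1% error takes ~1533 updates at decay 0.997 but only ~44 at 0.9.
print(moving_average_steps(0.997, 0.01), moving_average_steps(0.9, 0.01))
```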
def _conv_bn_layer(inputs, padding, filters, kernel_size, strides, layer_id,
training):
"""Defines 2D convolutional + batch normalization layer.
Args:
cnn_input: input data for convolution layer.
inputs: input data for convolution layer.
padding: padding to be applied before convolution layer.
filters: an integer, number of output filters in the convolution.
kernel_size: a tuple specifying the height and width of the 2D convolution
window.
strides: a tuple specifying the stride length of the convolution.
layer_id: an integer specifying the layer index.
training: a boolean to indicate which stage we are in (training/eval).
Returns:
tensor output from the current layer.
"""
output = tf.keras.layers.Conv2D(
filters=filters, kernel_size=kernel_size, strides=strides, padding="same",
activation="linear", name="cnn_{}".format(layer_id))(cnn_input)
output = tf.keras.layers.BatchNormalization(
momentum=_MOMENTUM, epsilon=_EPSILON)(output)
return output
def _rnn_layer(input_data, rnn_cell, rnn_hidden_size, layer_id, rnn_activation,
is_batch_norm, is_bidirectional):
# Perform symmetric padding on the time and feature dimensions.
# This step is required to avoid issues when the RNN output sequence is
# shorter than the label length.
inputs = tf.pad(
inputs,
[[0, 0], [padding[0], padding[0]], [padding[1], padding[1]], [0, 0]])
inputs = tf.layers.conv2d(
inputs=inputs, filters=filters, kernel_size=kernel_size, strides=strides,
padding="valid", use_bias=False, activation=tf.nn.relu6,
name="cnn_{}".format(layer_id))
return batch_norm(inputs, training)
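The symmetric padding of (20, 5) and (10, 5) keeps the post-convolution sequence long enough for CTC. A sketch of the "valid"-convolution length arithmetic along the time axis for the two layers above (kernel, stride, and padding values taken from this file; the 1000-frame input length is illustrative):

```python
def conv_output_steps(time_steps, kernel, stride, pad):
    # "valid" convolution after padding `pad` frames on each side:
    #   floor((T + 2 * pad - kernel) / stride) + 1
    return (time_steps + 2 * pad - kernel) // stride + 1

# Layer 1: kernel 41, stride 2, symmetric padding 20 along time.
t1 = conv_output_steps(1000, kernel=41, stride=2, pad=20)  # 500
# Layer 2: kernel 21, stride 2, symmetric padding 10 along time.
t2 = conv_output_steps(t1, kernel=21, stride=2, pad=10)    # 250
print(t1, t2)
```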
def _rnn_layer(inputs, rnn_cell, rnn_hidden_size, layer_id, is_batch_norm,
is_bidirectional, training):
"""Defines a batch normalization + rnn layer.
Args:
input_data: input tensors for the current layer.
inputs: input tensors for the current layer.
rnn_cell: RNN cell instance to use.
rnn_hidden_size: an integer for the dimensionality of the rnn output space.
layer_id: an integer for the index of current layer.
rnn_activation: activation function to use.
is_batch_norm: a boolean specifying whether to perform batch normalization
on input states.
is_bidirectional: a boolean specifying whether the rnn layer is
bi-directional.
training: a boolean to indicate which stage we are in (training/eval).
Returns:
tensor output for the current layer.
"""
if is_batch_norm:
input_data = tf.keras.layers.BatchNormalization(
momentum=_MOMENTUM, epsilon=_EPSILON)(input_data)
rnn_layer = rnn_cell(
rnn_hidden_size, activation=rnn_activation, return_sequences=True,
name="rnn_{}".format(layer_id))
if is_bidirectional:
rnn_layer = tf.keras.layers.Bidirectional(rnn_layer, merge_mode="sum")
return rnn_layer(input_data)
inputs = batch_norm(inputs, training)
def _ctc_lambda_func(args):
"""Compute ctc loss."""
# py2 needs explicit tf import for keras Lambda layer
import tensorflow as tf
# Construct forward/backward RNN cells.
fw_cell = rnn_cell(num_units=rnn_hidden_size,
name="rnn_fw_{}".format(layer_id))
bw_cell = rnn_cell(num_units=rnn_hidden_size,
name="rnn_bw_{}".format(layer_id))
y_pred, labels, input_length, label_length = args
return tf.keras.backend.ctc_batch_cost(
labels, y_pred, input_length, label_length)
def _calc_ctc_input_length(args):
"""Compute the actual input length after convolution for ctc_loss function.
Basically, we need to know the scaled input_length after conv layers.
new_input_length = old_input_length * ctc_time_steps / max_time_steps
Args:
args: the input args to compute ctc input length.
Returns:
ctc_input_length, which is required for ctc loss calculation.
"""
# py2 needs explicit tf import for keras Lambda layer
import tensorflow as tf
if is_bidirectional:
outputs, _ = tf.nn.bidirectional_dynamic_rnn(
cell_fw=fw_cell, cell_bw=bw_cell, inputs=inputs, dtype=tf.float32,
swap_memory=True)
rnn_outputs = tf.concat(outputs, -1)
else:
rnn_outputs, _ = tf.nn.dynamic_rnn(
fw_cell, inputs, dtype=tf.float32, swap_memory=True)
input_length, input_data, y_pred = args
max_time_steps = tf.shape(input_data)[1]
ctc_time_steps = tf.shape(y_pred)[1]
ctc_input_length = tf.multiply(
tf.to_float(input_length), tf.to_float(ctc_time_steps))
ctc_input_length = tf.to_int32(tf.floordiv(
ctc_input_length, tf.to_float(max_time_steps)))
return ctc_input_length
return rnn_outputs
class DeepSpeech(tf.keras.models.Model):
"""DeepSpeech model."""
class DeepSpeech2(object):
"""Define DeepSpeech2 model."""
def __init__(self, input_shape, num_rnn_layers, rnn_type, is_bidirectional,
rnn_hidden_size, rnn_activation, num_classes, use_bias):
"""Initialize DeepSpeech model.
def __init__(self, num_rnn_layers, rnn_type, is_bidirectional,
rnn_hidden_size, num_classes, use_bias):
"""Initialize DeepSpeech2 model.
Args:
input_shape: a tuple indicating the dimension of the input dataset. It has
the format of [time_steps(T), feature_bins(F), channel(1)].
num_rnn_layers: an integer, the number of rnn layers. By default, it's 5.
rnn_type: a string, one of the supported rnn cells: gru, rnn and lstm.
is_bidirectional: a boolean to indicate if the rnn layer is bidirectional.
rnn_hidden_size: an integer for the number of hidden states in each unit.
rnn_activation: a string to indicate rnn activation function. It can be
one of tanh and relu.
num_classes: an integer, the number of output classes/labels.
use_bias: a boolean specifying whether to use bias in the last fc layer.
"""
# Input variables
input_data = tf.keras.layers.Input(
shape=input_shape, name="features")
# Two cnn layers
conv_layer_1 = _conv_bn_layer(
input_data, filters=32, kernel_size=(41, 11), strides=(2, 2),
layer_id=1)
conv_layer_2 = _conv_bn_layer(
conv_layer_1, filters=32, kernel_size=(21, 11), strides=(2, 1),
layer_id=2)
self.num_rnn_layers = num_rnn_layers
self.rnn_type = rnn_type
self.is_bidirectional = is_bidirectional
self.rnn_hidden_size = rnn_hidden_size
self.num_classes = num_classes
self.use_bias = use_bias
def __call__(self, inputs, training):
# Two cnn layers.
inputs = _conv_bn_layer(
inputs, padding=(20, 5), filters=_CONV_FILTERS, kernel_size=(41, 11),
strides=(2, 2), layer_id=1, training=training)
inputs = _conv_bn_layer(
inputs, padding=(10, 5), filters=_CONV_FILTERS, kernel_size=(21, 11),
strides=(2, 1), layer_id=2, training=training)
# output of conv_layer2 with the shape of
# [batch_size (N), times (T), features (F), channels (C)]
# [batch_size (N), times (T), features (F), channels (C)].
# Convert the conv output to rnn input.
batch_size = tf.shape(inputs)[0]
feat_size = inputs.get_shape().as_list()[2]
inputs = tf.reshape(
inputs,
[batch_size, -1, feat_size * _CONV_FILTERS])
# RNN layers.
# Convert the conv output to rnn input
rnn_input = tf.keras.layers.TimeDistributed(tf.keras.layers.Flatten())(
conv_layer_2)
rnn_cell = SUPPORTED_RNNS[rnn_type]
for layer_counter in xrange(num_rnn_layers):
# No batch normalization on the first layer
rnn_cell = SUPPORTED_RNNS[self.rnn_type]
for layer_counter in xrange(self.num_rnn_layers):
# No batch normalization on the first layer.
is_batch_norm = (layer_counter != 0)
rnn_input = _rnn_layer(
rnn_input, rnn_cell, rnn_hidden_size, layer_counter + 1,
rnn_activation, is_batch_norm, is_bidirectional)
# FC layer with batch norm
fc_input = tf.keras.layers.BatchNormalization(
momentum=_MOMENTUM, epsilon=_EPSILON)(rnn_input)
y_pred = tf.keras.layers.Dense(num_classes, activation="softmax",
use_bias=use_bias, name="y_pred")(fc_input)
# For ctc loss
labels = tf.keras.layers.Input(name="labels", shape=[None,], dtype="int32")
label_length = tf.keras.layers.Input(
name="label_length", shape=[1], dtype="int32")
input_length = tf.keras.layers.Input(
name="input_length", shape=[1], dtype="int32")
ctc_input_length = tf.keras.layers.Lambda(
_calc_ctc_input_length, output_shape=(1,), name="ctc_input_length")(
[input_length, input_data, y_pred])
# Keras doesn't currently support loss funcs with extra parameters
# so CTC loss is implemented in a lambda layer
ctc_loss = tf.keras.layers.Lambda(
_ctc_lambda_func, output_shape=(1,), name="ctc_loss")(
[y_pred, labels, ctc_input_length, label_length])
super(DeepSpeech, self).__init__(
inputs=[input_data, labels, input_length, label_length],
outputs=[ctc_input_length, ctc_loss, y_pred])
inputs = _rnn_layer(
inputs, rnn_cell, self.rnn_hidden_size, layer_counter + 1,
is_batch_norm, self.is_bidirectional, training)
# FC layer with batch norm.
inputs = batch_norm(inputs, training)
logits = tf.layers.dense(inputs, self.num_classes, use_bias=self.use_bias)
return logits
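Between the conv stack and the RNN stack, `__call__` reshapes `[N, T, F, C]` into `[N, T, F * C]`. A numpy sketch of that conversion with illustrative shapes (the feature size of 41 is an assumption for illustration, not a value computed from the real pipeline; 32 matches `_CONV_FILTERS`):

```python
import numpy as np

# Conv output layout: [batch, time, features, channels]. Merge the last two
# axes so each time step becomes one flat feature vector for the RNN stack.
conv_out = np.zeros((2, 50, 41, 32), dtype=np.float32)
rnn_in = conv_out.reshape(conv_out.shape[0], -1,
                          conv_out.shape[2] * conv_out.shape[3])
print(rnn_in.shape)  # (2, 50, 1312)
```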
nltk>=3.3
soundfile>=0.10.2
sox>=1.3.3