"video_to_video/git@developer.sourcefind.cn:modelzoo/star.git" did not exist on "1f5da52002b7ab0f2fcf95ce4c2ce88ea0995680"
Unverified commit 2dc6b914, authored by Yanhui Liang, committed by GitHub

Add deep speech model and run loop (#4632)

* Add deep speech model and run loop

* fix lints and add init

* Add dataset and address comments
Parent commit: fb3fba0d
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generate tf.data.Dataset object for deep speech training/evaluation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import scipy.io.wavfile as wavfile
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
# pylint: disable=g-bad-import-order
from data.featurizer import AudioFeaturizer
from data.featurizer import TextFeaturizer
class AudioConfig(object):
"""Configs for spectrogram extraction from audio."""
def __init__(self,
sample_rate,
frame_length,
frame_step,
fft_length=None,
normalize=False,
spect_type="linear"):
"""Initialize the AudioConfig class.
Args:
sample_rate: an integer denoting the sample rate of the input waveform.
frame_length: an integer for the length of a spectrogram frame, in ms.
frame_step: an integer for the frame stride, in ms.
fft_length: an integer for the number of fft bins.
normalize: a boolean for whether to apply normalization to the audio tensor.
spect_type: a string for the type of spectrogram to be extracted.
"""
self.sample_rate = sample_rate
self.frame_length = frame_length
self.frame_step = frame_step
self.fft_length = fft_length
self.normalize = normalize
self.spect_type = spect_type
class DatasetConfig(object):
"""Config class for generating the DeepSpeechDataset."""
def __init__(self, audio_config, data_path, vocab_file_path):
"""Initialize the configs for deep speech dataset.
Args:
audio_config: AudioConfig object specifying the audio-related configs.
data_path: a string denoting the full path of a manifest file.
vocab_file_path: a string specifying the vocabulary file path.
Raises:
AssertionError: raised if data_path or vocab_file_path does not exist.
"""
self.audio_config = audio_config
assert tf.gfile.Exists(data_path)
assert tf.gfile.Exists(vocab_file_path)
self.data_path = data_path
self.vocab_file_path = vocab_file_path
class DeepSpeechDataset(object):
"""Dataset class for training/evaluation of DeepSpeech model."""
def __init__(self, dataset_config):
"""Initialize the class.
Each dataset file contains three columns: "wav_filename", "wav_filesize",
and "transcript". This function parses the csv file and stores each example
by the increasing order of audio length (indicated by wav_filesize).
Args:
dataset_config: DatasetConfig object.
"""
self.config = dataset_config
# Instantiate audio feature extractor.
self.audio_featurizer = AudioFeaturizer(
sample_rate=self.config.audio_config.sample_rate,
frame_length=self.config.audio_config.frame_length,
frame_step=self.config.audio_config.frame_step,
fft_length=self.config.audio_config.fft_length,
spect_type=self.config.audio_config.spect_type)
# Instantiate text feature extractor.
self.text_featurizer = TextFeaturizer(
vocab_file=self.config.vocab_file_path)
self.speech_labels = self.text_featurizer.speech_labels
self.features, self.labels = self._preprocess_data(self.config.data_path)
self.num_feature_bins = (
self.features[0].shape[1] if len(self.features) else None)
def _preprocess_data(self, file_path):
"""Generate a list of waveform, transcript pair.
Note that the waveforms are ordered in increasing length, so that audio
samples in a mini-batch have similar length.
Args:
file_path: a string specifying the csv file path for a data set.
Returns:
features and labels arrays processed from the audio/text input.
"""
with tf.gfile.Open(file_path, "r") as f:
lines = f.read().splitlines()
lines = [line.split("\t") for line in lines]
# Skip the csv header.
lines = lines[1:]
# Sort input data by the length of waveform.
lines.sort(key=lambda item: int(item[1]))
features = [self._preprocess_audio(line[0]) for line in lines]
labels = [self._preprocess_transcript(line[2]) for line in lines]
return features, labels
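# Note: the manifest csv parsed above is tab-separated with a header row. A
# hypothetical data line (the path and size below are illustrative only)
# would look like:
#   /tmp/librispeech_data/xxx.wav<TAB>63440<TAB>the transcript of the utterance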
def _normalize_audio_tensor(self, audio_tensor):
"""Perform mean and variance normalization on the spectrogram tensor.
Args:
audio_tensor: a tensor for the spectrogram feature.
Returns:
a tensor for the normalized spectrogram.
"""
mean, var = tf.nn.moments(audio_tensor, axes=[0])
normalized = (audio_tensor - mean) / (tf.sqrt(var) + 1e-6)
return normalized
def _preprocess_audio(self, audio_file_path):
"""Load the audio file in memory."""
tf.logging.info(
"Extracting spectrogram feature for {}".format(audio_file_path))
sample_rate, data = wavfile.read(audio_file_path)
assert sample_rate == self.config.audio_config.sample_rate
if data.dtype not in [np.float32, np.float64]:
data = data.astype(np.float32) / np.iinfo(data.dtype).max
feature = self.audio_featurizer.featurize(data)
if self.config.audio_config.normalize:
feature = self._normalize_audio_tensor(feature)
return tf.Session().run(feature)  # Return a numpy array rather than a tensor.
def _preprocess_transcript(self, transcript):
return self.text_featurizer.featurize(transcript)
def input_fn(batch_size, deep_speech_dataset, repeat=1):
"""Input function for model training and evaluation.
Args:
batch_size: an integer denoting the size of a batch.
deep_speech_dataset: DeepSpeechDataset object.
repeat: an integer for how many times to repeat the dataset.
Returns:
a tf.data.Dataset object for model to consume.
"""
features = deep_speech_dataset.features
labels = deep_speech_dataset.labels
num_feature_bins = deep_speech_dataset.num_feature_bins
def _gen_data():
for i in xrange(len(features)):
feature = np.expand_dims(features[i], axis=2)
input_length = [features[i].shape[0]]
label_length = [len(labels[i])]
yield {
"features": feature,
"labels": labels[i],
"input_length": input_length,
"label_length": label_length
}
dataset = tf.data.Dataset.from_generator(
_gen_data,
output_types={
"features": tf.float32,
"labels": tf.int32,
"input_length": tf.int32,
"label_length": tf.int32
},
output_shapes={
"features": tf.TensorShape([None, num_feature_bins, 1]),
"labels": tf.TensorShape([None]),
"input_length": tf.TensorShape([1]),
"label_length": tf.TensorShape([1])
})
# Repeat and batch the dataset
dataset = dataset.repeat(repeat)
# Pad the features to their max-length dimensions.
dataset = dataset.padded_batch(
batch_size=batch_size,
padded_shapes={
"features": tf.TensorShape([None, num_feature_bins, 1]),
"labels": tf.TensorShape([None]),
"input_length": tf.TensorShape([1]),
"label_length": tf.TensorShape([1])
})
# Prefetch to improve speed of input pipeline.
dataset = dataset.prefetch(1)
return dataset
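# Example usage (a minimal sketch, not part of this module): assuming a
# tab-separated manifest csv and the vocabulary file added in this commit,
# the pipeline above can be exercised as follows:
#
#   audio_conf = AudioConfig(sample_rate=16000, frame_length=25, frame_step=10)
#   data_conf = DatasetConfig(audio_conf, "train.csv", "data/vocabulary.txt")
#   speech_dataset = DeepSpeechDataset(data_conf)
#   tf_dataset = input_fn(batch_size=32, deep_speech_dataset=speech_dataset)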
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility class for extracting features from the text and audio input."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import codecs
import functools
import numpy as np
import tensorflow as tf
class AudioFeaturizer(object):
"""Class to extract spectrogram features from the audio input."""
def __init__(self,
sample_rate=16000,
frame_length=25,
frame_step=10,
fft_length=None,
window_fn=functools.partial(
tf.contrib.signal.hann_window, periodic=True),
spect_type="linear"):
"""Initialize the audio featurizer class according to the configs.
Args:
sample_rate: an integer specifying the sample rate of the input waveform.
frame_length: an integer for the length of a spectrogram frame, in ms.
frame_step: an integer for the frame stride, in ms.
fft_length: an integer for the number of fft bins.
window_fn: windowing function.
spect_type: a string for the type of spectrogram to be extracted.
Currently only 'linear' is supported; other values raise a ValueError.
Raises:
ValueError: In case of invalid arguments for `spect_type`.
"""
if spect_type != "linear":
raise ValueError("Unsupported spectrogram type: %s" % spect_type)
self.window_fn = window_fn
self.frame_length = int(sample_rate * frame_length / 1e3)
self.frame_step = int(sample_rate * frame_step / 1e3)
self.fft_length = fft_length if fft_length else int(2**(np.ceil(
np.log2(self.frame_length))))
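# With the defaults above (sample_rate=16000, frame_length=25 ms,
# frame_step=10 ms), frames are 400 samples long with a 160-sample stride, and
# fft_length falls back to the next power of two, 512, so the stft below
# yields 512 // 2 + 1 = 257 frequency bins per frame.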
def featurize(self, waveform):
"""Extract spectrogram feature tensors from the waveform."""
return self._compute_linear_spectrogram(waveform)
def _compute_linear_spectrogram(self, waveform):
"""Compute the linear-scale, magnitude spectrograms for the input waveform.
Args:
waveform: a float32 audio tensor.
Returns:
a float32 tensor with shape [len, num_bins].
"""
# `stfts` is a complex64 Tensor representing the Short-time Fourier
# Transform of each signal in `signals`. Its shape is
# [?, fft_unique_bins] where fft_unique_bins = fft_length // 2 + 1.
stfts = tf.contrib.signal.stft(
waveform,
frame_length=self.frame_length,
frame_step=self.frame_step,
fft_length=self.fft_length,
window_fn=self.window_fn,
pad_end=True)
# An energy spectrogram is the magnitude of the complex-valued STFT.
# A float32 Tensor of shape [?, fft_length // 2 + 1] (257 with the default fft_length of 512).
magnitude_spectrograms = tf.abs(stfts)
return magnitude_spectrograms
def _compute_mel_filterbank_features(self, waveform):
"""Compute the mel filterbank features."""
raise NotImplementedError("MFCC feature extraction not supported yet.")
class TextFeaturizer(object):
"""Extract text feature based on char-level granularity.
By looking up the vocabulary table, each input string (one line of transcript)
will be converted to a sequence of integer indexes.
"""
def __init__(self, vocab_file):
lines = []
with codecs.open(vocab_file, "r", "utf-8") as fin:
lines.extend(fin.readlines())
self.token_to_idx = {}
self.idx_to_token = {}
self.speech_labels = ""
idx = 0
for line in lines:
line = line[:-1] # Strip the '\n' char.
if line.startswith("#"):
# Skip comment lines.
continue
self.token_to_idx[line] = idx
self.idx_to_token[idx] = line
self.speech_labels += line
idx += 1
def featurize(self, text):
"""Convert string to a list of integers."""
tokens = list(text.strip().lower())
feats = [self.token_to_idx[token] for token in tokens]
return feats
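# A minimal usage sketch (the vocabulary path is illustrative; the featurizers
# are normally driven by DeepSpeechDataset):
#
#   audio_featurizer = AudioFeaturizer(sample_rate=16000)
#   spectrogram = audio_featurizer.featurize(waveform)  # float32 waveform tensor
#   text_featurizer = TextFeaturizer("data/vocabulary.txt")
#   token_ids = text_featurizer.featurize("hello")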
# List of alphabet characters (utf-8 encoded). Note that '#' starts a comment
# line, which will be ignored by the parser.
# begin of vocabulary
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
'
-
# end of vocabulary
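# Example (assuming the entries above form the complete vocabulary, indexed in
# order from 0): 'a' -> 0, 'b' -> 1, ..., 'z' -> 25, so
# TextFeaturizer.featurize("cab") returns [2, 0, 1].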
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Main entry to train and evaluate DeepSpeech model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# pylint: disable=g-bad-import-order
from absl import app as absl_app
from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order
import data.dataset as dataset
import deep_speech_model
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.logs import logger
from official.utils.misc import distribution_utils
# Default vocabulary file
_VOCABULARY_FILE = os.path.join(
os.path.dirname(__file__), "data/vocabulary.txt")
def convert_keras_to_estimator(keras_model, num_gpus):
"""Configure and convert keras model to Estimator.
Args:
keras_model: A Keras model object.
num_gpus: An integer, the number of GPUs.
Returns:
estimator: The converted Estimator.
"""
# keras optimizer is not compatible with distribution strategy.
# Use tf optimizer instead
optimizer = tf.train.MomentumOptimizer(
learning_rate=flags_obj.learning_rate, momentum=flags_obj.momentum,
use_nesterov=True)
# ctc_loss is wrapped as a Lambda layer in the model.
keras_model.compile(
optimizer=optimizer, loss={"ctc_loss": lambda y_true, y_pred: y_pred})
distribution_strategy = distribution_utils.get_distribution_strategy(
num_gpus)
run_config = tf.estimator.RunConfig(
train_distribute=distribution_strategy)
estimator = tf.keras.estimator.model_to_estimator(
keras_model=keras_model, model_dir=flags_obj.model_dir, config=run_config)
return estimator
def generate_dataset(data_dir):
"""Generate a speech dataset."""
audio_conf = dataset.AudioConfig(
flags_obj.sample_rate, flags_obj.frame_length, flags_obj.frame_step)
train_data_conf = dataset.DatasetConfig(
audio_conf,
data_dir,
flags_obj.vocabulary_file,
)
speech_dataset = dataset.DeepSpeechDataset(train_data_conf)
return speech_dataset
def run_deep_speech(_):
"""Run deep speech training and eval loop."""
# Data preprocessing: build the training and evaluation datasets from their
# csv manifest files.
tf.logging.info("Data preprocessing...")
train_speech_dataset = generate_dataset(flags_obj.train_data_dir)
eval_speech_dataset = generate_dataset(flags_obj.eval_data_dir)
# Number of label classes. Label string is "[a-z]' -"
num_classes = len(train_speech_dataset.speech_labels)
# Input shape of each data example:
# [time_steps (T), feature_bins(F), channel(C)]
# Channel is set as 1 by default.
input_shape = (None, train_speech_dataset.num_feature_bins, 1)
# Create deep speech model and convert it to Estimator
tf.logging.info("Creating Estimator from Keras model...")
keras_model = deep_speech_model.DeepSpeech(
input_shape, flags_obj.rnn_hidden_layers, flags_obj.rnn_type,
flags_obj.is_bidirectional, flags_obj.rnn_hidden_size,
flags_obj.rnn_activation, num_classes, flags_obj.use_bias)
# Convert to estimator
num_gpus = flags_core.get_num_gpus(flags_obj)
estimator = convert_keras_to_estimator(keras_model, num_gpus)
# Benchmark logging
run_params = {
"batch_size": flags_obj.batch_size,
"train_epochs": flags_obj.train_epochs,
"rnn_hidden_size": flags_obj.rnn_hidden_size,
"rnn_hidden_layers": flags_obj.rnn_hidden_layers,
"rnn_activation": flags_obj.rnn_activation,
"rnn_type": flags_obj.rnn_type,
"is_bidirectional": flags_obj.is_bidirectional,
"use_bias": flags_obj.use_bias
}
dataset_name = "LibriSpeech"
benchmark_logger = logger.get_benchmark_logger()
benchmark_logger.log_run_info("deep_speech", dataset_name, run_params,
test_id=flags_obj.benchmark_test_id)
train_hooks = hooks_helper.get_train_hooks(
flags_obj.hooks,
batch_size=flags_obj.batch_size)
per_device_batch_size = distribution_utils.per_device_batch_size(
flags_obj.batch_size, num_gpus)
def input_fn_train():
return dataset.input_fn(
per_device_batch_size, train_speech_dataset)
def input_fn_eval(): # pylint: disable=unused-variable
return dataset.input_fn(
per_device_batch_size, eval_speech_dataset)
total_training_cycle = (flags_obj.train_epochs //
flags_obj.epochs_between_evals)
for cycle_index in range(total_training_cycle):
tf.logging.info("Starting a training cycle: %d/%d",
cycle_index + 1, total_training_cycle)
estimator.train(input_fn=input_fn_train, hooks=train_hooks)
# Evaluate (TODO)
# tf.logging.info("Starting to evaluate.")
# eval_results = evaluate_model(
# estimator, keras_model, data_set.speech_labels, [], input_fn_eval)
# benchmark_logger.log_evaluation_result(eval_results)
# If some evaluation threshold is met
# Log the WER and CER results.
# wer = eval_results[_WER_KEY]
# cer = eval_results[_CER_KEY]
# tf.logging.info(
# "Iteration {}: WER = {:.2f}, CER = {:.2f}".format(
# cycle_index + 1, wer, cer))
# if model_helpers.past_stop_threshold(FLAGS.wer_threshold, wer):
# break
# Clear the session explicitly to avoid session delete error
tf.keras.backend.clear_session()
def define_deep_speech_flags():
"""Add flags for run_deep_speech."""
# Add common flags
flags_core.define_base(
data_dir=False # we use train_data_dir and eval_data_dir instead
)
flags_core.define_performance(
num_parallel_calls=False,
inter_op=False,
intra_op=False,
synthetic_data=False,
max_train_steps=False,
dtype=False
)
flags_core.define_benchmark()
flags.adopt_module_key_flags(flags_core)
flags_core.set_defaults(
model_dir="/tmp/deep_speech_model/",
export_dir="/tmp/deep_speech_saved_model/",
train_epochs=10,
batch_size=32,
hooks="")
# Deep speech flags
flags.DEFINE_string(
name="train_data_dir",
default="/tmp/librispeech_data/test-clean/LibriSpeech/test-clean-20.csv",
help=flags_core.help_wrap("The csv file path of the training dataset."))
flags.DEFINE_string(
name="eval_data_dir",
default="/tmp/librispeech_data/test-clean/LibriSpeech/test-clean-20.csv",
help=flags_core.help_wrap("The csv file path of the evaluation dataset."))
flags.DEFINE_integer(
name="sample_rate", default=16000,
help=flags_core.help_wrap("The sample rate for audio."))
flags.DEFINE_integer(
name="frame_length", default=25,
help=flags_core.help_wrap("The frame length in ms for spectrogram extraction."))
flags.DEFINE_integer(
name="frame_step", default=10,
help=flags_core.help_wrap("The frame step (stride) in ms."))
flags.DEFINE_string(
name="vocabulary_file", default=_VOCABULARY_FILE,
help=flags_core.help_wrap("The file path of vocabulary file."))
# RNN related flags
flags.DEFINE_integer(
name="rnn_hidden_size", default=256,
help=flags_core.help_wrap("The hidden size of RNNs."))
flags.DEFINE_integer(
name="rnn_hidden_layers", default=3,
help=flags_core.help_wrap("The number of RNN layers."))
flags.DEFINE_bool(
name="use_bias", default=True,
help=flags_core.help_wrap("Whether to use bias in the last fully-connected layer."))
flags.DEFINE_bool(
name="is_bidirectional", default=True,
help=flags_core.help_wrap("Whether the rnn layers are bidirectional."))
flags.DEFINE_enum(
name="rnn_type", default="gru",
enum_values=deep_speech_model.SUPPORTED_RNNS.keys(),
case_sensitive=False,
help=flags_core.help_wrap("Type of RNN cell."))
flags.DEFINE_enum(
name="rnn_activation", default="tanh",
enum_values=["tanh", "relu"], case_sensitive=False,
help=flags_core.help_wrap("Type of the activation within RNN."))
# Training related flags
flags.DEFINE_float(
name="learning_rate", default=0.0003,
help=flags_core.help_wrap("The initial learning rate."))
flags.DEFINE_float(
name="momentum", default=0.9,
help=flags_core.help_wrap("Momentum to accelerate SGD optimizer."))
# Evaluation metrics threshold
flags.DEFINE_float(
name="wer_threshold", default=None,
help=flags_core.help_wrap(
"If passed, training will stop when the evaluation metric WER is "
"greater than or equal to wer_threshold. For libri speech dataset "
"the desired wer_threshold is 0.23 which is the result achieved by "
"MLPerf implementation."))
def main(_):
with logger.benchmark_context(flags_obj):
run_deep_speech(flags_obj)
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
define_deep_speech_flags()
flags_obj = flags.FLAGS
absl_app.run(main)
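# Example invocation (a hypothetical sketch; the script name and data paths are
# assumptions, only the flags are defined above):
#   python deep_speech.py \
#     --train_data_dir=/tmp/librispeech_data/train.csv \
#     --eval_data_dir=/tmp/librispeech_data/eval.csv \
#     --model_dir=/tmp/deep_speech_model/ \
#     --train_epochs=10 --batch_size=32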
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Network structure for DeepSpeech model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
# Supported rnn cells
SUPPORTED_RNNS = {
"lstm": tf.keras.layers.LSTM,
"rnn": tf.keras.layers.SimpleRNN,
"gru": tf.keras.layers.GRU,
}
# Parameters for batch normalization
_MOMENTUM = 0.1
_EPSILON = 1e-05
def _conv_bn_layer(cnn_input, filters, kernel_size, strides, layer_id):
"""2D convolution + batch normalization layer.
Args:
cnn_input: input data for convolution layer.
filters: an integer, number of output filters in the convolution.
kernel_size: a tuple specifying the height and width of the 2D convolution
window.
strides: a tuple specifying the stride length of the convolution.
layer_id: an integer specifying the layer index.
Returns:
tensor output from the current layer.
"""
output = tf.keras.layers.Conv2D(
filters=filters, kernel_size=kernel_size, strides=strides, padding="same",
activation="linear", name="cnn_{}".format(layer_id))(cnn_input)
output = tf.keras.layers.BatchNormalization(
momentum=_MOMENTUM, epsilon=_EPSILON)(output)
return output
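# Shape note (a sketch, assuming the "same" padding used above): an input of
# shape [batch, T, F, 1] passed through the first conv layer below
# (filters=32, strides=(2, 2)) comes out as [batch, ceil(T / 2), ceil(F / 2), 32].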
def _rnn_layer(input_data, rnn_cell, rnn_hidden_size, layer_id, rnn_activation,
is_batch_norm, is_bidirectional):
"""Defines a batch normalization + rnn layer.
Args:
input_data: input tensors for the current layer.
rnn_cell: RNN cell instance to use.
rnn_hidden_size: an integer for the dimensionality of the rnn output space.
layer_id: an integer for the index of current layer.
rnn_activation: activation function to use.
is_batch_norm: a boolean specifying whether to perform batch normalization
on input states.
is_bidirectional: a boolean specifying whether the rnn layer is
bi-directional.
Returns:
tensor output for the current layer.
"""
if is_batch_norm:
input_data = tf.keras.layers.BatchNormalization(
momentum=_MOMENTUM, epsilon=_EPSILON)(input_data)
rnn_layer = rnn_cell(
rnn_hidden_size, activation=rnn_activation, return_sequences=True,
name="rnn_{}".format(layer_id))
if is_bidirectional:
rnn_layer = tf.keras.layers.Bidirectional(rnn_layer, merge_mode="sum")
return rnn_layer(input_data)
def _ctc_lambda_func(args):
"""Compute ctc loss."""
# py2 needs explicit tf import for keras Lambda layer
import tensorflow as tf
y_pred, labels, input_length, label_length = args
return tf.keras.backend.ctc_batch_cost(
labels, y_pred, input_length, label_length)
def _calc_ctc_input_length(args):
"""Compute the actual input length after convolution for ctc_loss function.
Basically, we need to know the scaled input_length after conv layers.
new_input_length = old_input_length * ctc_time_steps / max_time_steps
Args:
args: the input args to compute ctc input length.
Returns:
ctc_input_length, which is required for ctc loss calculation.
"""
# py2 needs explicit tf import for keras Lambda layer
import tensorflow as tf
input_length, input_data, y_pred = args
max_time_steps = tf.shape(input_data)[1]
ctc_time_steps = tf.shape(y_pred)[1]
ctc_input_length = tf.multiply(
tf.to_float(input_length), tf.to_float(ctc_time_steps))
ctc_input_length = tf.to_int32(tf.floordiv(
ctc_input_length, tf.to_float(max_time_steps)))
return ctc_input_length
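# A worked example of the scaling above (illustrative numbers): if an utterance
# has input_length = 200 unpadded frames inside a padded batch with
# max_time_steps = 400, and the conv stack leaves ctc_time_steps = 100 output
# steps, then ctc_input_length = floor(200 * 100 / 400) = 50.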
class DeepSpeech(tf.keras.models.Model):
"""DeepSpeech model."""
def __init__(self, input_shape, num_rnn_layers, rnn_type, is_bidirectional,
rnn_hidden_size, rnn_activation, num_classes, use_bias):
"""Initialize DeepSpeech model.
Args:
input_shape: a tuple indicating the dimensions of the input data, in the
format of [time_steps (T), feature_bins (F), channel (1)].
num_rnn_layers: an integer, the number of rnn layers (the training script
defaults to 3).
rnn_type: a string, one of the supported rnn cells: gru, rnn and lstm.
is_bidirectional: a boolean to indicate if the rnn layer is bidirectional.
rnn_hidden_size: an integer for the dimensionality of the rnn output space.
rnn_activation: a string to indicate rnn activation function. It can be
one of tanh and relu.
num_classes: an integer, the number of output classes/labels.
use_bias: a boolean specifying whether to use bias in the last fc layer.
"""
# Input variables
input_data = tf.keras.layers.Input(
shape=input_shape, name="features")
# Two cnn layers
conv_layer_1 = _conv_bn_layer(
input_data, filters=32, kernel_size=(41, 11), strides=(2, 2),
layer_id=1)
conv_layer_2 = _conv_bn_layer(
conv_layer_1, filters=32, kernel_size=(21, 11), strides=(2, 1),
layer_id=2)
# The output of conv_layer_2 has the shape of
# [batch_size (N), time_steps (T), feature_bins (F), channels (C)].
# RNN layers.
# Convert the conv output to rnn input
rnn_input = tf.keras.layers.TimeDistributed(tf.keras.layers.Flatten())(
conv_layer_2)
rnn_cell = SUPPORTED_RNNS[rnn_type]
for layer_counter in xrange(num_rnn_layers):
# No batch normalization on the first layer
is_batch_norm = (layer_counter != 0)
rnn_input = _rnn_layer(
rnn_input, rnn_cell, rnn_hidden_size, layer_counter + 1,
rnn_activation, is_batch_norm, is_bidirectional)
# FC layer with batch norm
fc_input = tf.keras.layers.BatchNormalization(
momentum=_MOMENTUM, epsilon=_EPSILON)(rnn_input)
y_pred = tf.keras.layers.Dense(num_classes, activation="softmax",
use_bias=use_bias, name="y_pred")(fc_input)
# For ctc loss
labels = tf.keras.layers.Input(name="labels", shape=[None,], dtype="int32")
label_length = tf.keras.layers.Input(
name="label_length", shape=[1], dtype="int32")
input_length = tf.keras.layers.Input(
name="input_length", shape=[1], dtype="int32")
ctc_input_length = tf.keras.layers.Lambda(
_calc_ctc_input_length, output_shape=(1,), name="ctc_input_length")(
[input_length, input_data, y_pred])
# Keras doesn't currently support loss functions with extra parameters,
# so the CTC loss is implemented in a Lambda layer.
ctc_loss = tf.keras.layers.Lambda(
_ctc_lambda_func, output_shape=(1,), name="ctc_loss")(
[y_pred, labels, ctc_input_length, label_length])
super(DeepSpeech, self).__init__(
inputs=[input_data, labels, input_length, label_length],
outputs=[ctc_input_length, ctc_loss, y_pred])
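# A minimal construction sketch, using the flag defaults from the training
# entry point and the 257 feature bins / 28 speech labels implied by the
# default audio config and the vocabulary file above:
#
#   model = DeepSpeech(
#       input_shape=(None, 257, 1), num_rnn_layers=3, rnn_type="gru",
#       is_bidirectional=True, rnn_hidden_size=256, rnn_activation="tanh",
#       num_classes=28, use_bias=True)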