Unverified commit 1635e561 authored by Yanhui Liang, committed by GitHub

Add eval and parallel dataset (#4651)

parent c8c45fdb
@@ -17,14 +17,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import multiprocessing
import numpy as np
import scipy.io.wavfile as wavfile
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
# pylint: disable=g-bad-import-order
from data.featurizer import AudioFeaturizer
from data.featurizer import TextFeaturizer
import data.featurizer as featurizer # pylint: disable=g-bad-import-order
class AudioConfig(object):
@@ -44,7 +45,7 @@ class AudioConfig(object):
frame_length: an integer for the length of a spectrogram frame, in ms.
frame_step: an integer for the frame stride, in ms.
fft_length: an integer for the number of fft bins.
normalize: a boolean for whether apply normalization on the audio tensor.
normalize: a boolean for whether to apply normalization on the audio feature.
spect_type: a string for the type of spectrogram to be extracted.
"""
@@ -78,90 +79,122 @@ class DatasetConfig(object):
self.vocab_file_path = vocab_file_path
def _normalize_audio_feature(audio_feature):
"""Perform mean and variance normalization on the spectrogram feature.
Args:
audio_feature: a numpy array for the spectrogram feature.
Returns:
a numpy array of the normalized spectrogram.
"""
mean = np.mean(audio_feature, axis=0)
var = np.var(audio_feature, axis=0)
normalized = (audio_feature - mean) / (np.sqrt(var) + 1e-6)
return normalized
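As a quick numpy-only sanity check of the normalization above (a minimal sketch, not part of the change): after subtracting the per-bin mean and dividing by the per-bin standard deviation, every frequency bin has roughly zero mean and unit variance.
import numpy as np

feature = np.random.rand(100, 257).astype(np.float32)  # [time_steps, num_bins]
normalized = (feature - np.mean(feature, axis=0)) / (
    np.sqrt(np.var(feature, axis=0)) + 1e-6)
print(np.allclose(np.mean(normalized, axis=0), 0.0, atol=1e-4))  # True
print(np.allclose(np.var(normalized, axis=0), 1.0, atol=1e-2))   # True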
def _preprocess_audio(
audio_file_path, audio_sample_rate, audio_featurizer, normalize):
"""Load the audio file in memory and compute spectrogram feature."""
tf.logging.info(
"Extracting spectrogram feature for {}".format(audio_file_path))
sample_rate, data = wavfile.read(audio_file_path)
assert sample_rate == audio_sample_rate
if data.dtype not in [np.float32, np.float64]:
data = data.astype(np.float32) / np.iinfo(data.dtype).max
feature = featurizer.compute_spectrogram_feature(
data, audio_featurizer.frame_length, audio_featurizer.frame_step,
audio_featurizer.fft_length)
if normalize:
feature = _normalize_audio_feature(feature)
return feature
def _preprocess_transcript(transcript, token_to_index):
"""Process transcript as label features."""
return featurizer.compute_label_feature(transcript, token_to_index)
def _preprocess_data(dataset_config, audio_featurizer, token_to_index):
"""Generate a list of waveform, transcript pair.
Each dataset file contains three columns: "wav_filename", "wav_filesize",
and "transcript". This function parses the csv file and stores each example
by the increasing order of audio length (indicated by wav_filesize).
AS the waveforms are ordered in increasing length, audio samples in a
mini-batch have similar length.
Args:
dataset_config: an instance of DatasetConfig.
audio_featurizer: an instance of AudioFeaturizer.
token_to_index: the mapping from character to its index.
Returns:
features and labels arrays processed from the audio/text input.
"""
file_path = dataset_config.data_path
sample_rate = dataset_config.audio_config.sample_rate
normalize = dataset_config.audio_config.normalize
with tf.gfile.Open(file_path, "r") as f:
lines = f.read().splitlines()
lines = [line.split("\t") for line in lines]
# Skip the csv header.
lines = lines[1:]
# Sort input data by the length of waveform.
lines.sort(key=lambda item: int(item[1]))
# Use multiprocessing for feature/label extraction
num_cores = multiprocessing.cpu_count()
pool = multiprocessing.Pool(processes=num_cores)
features = pool.map(
functools.partial(
_preprocess_audio, audio_sample_rate=sample_rate,
audio_featurizer=audio_featurizer, normalize=normalize),
[line[0] for line in lines])
labels = pool.map(
functools.partial(
_preprocess_transcript, token_to_index=token_to_index),
[line[2] for line in lines])
pool.terminate()
return features, labels
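To make the ordering step concrete, here is a toy illustration (the csv rows are hypothetical) of how sorting on the wav_filesize column places clips of similar duration next to each other, so mini-batches drawn in order need little padding:
# Each row is (wav_filename, wav_filesize, transcript), as in the data csv.
rows = [
    ("b.wav", "32000", "hello world"),
    ("a.wav", "16000", "hi"),
    ("c.wav", "64000", "good morning everyone"),
]
rows.sort(key=lambda item: int(item[1]))  # same key as in _preprocess_data
print([r[0] for r in rows])  # ['a.wav', 'b.wav', 'c.wav']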
class DeepSpeechDataset(object):
"""Dataset class for training/evaluation of DeepSpeech model."""
def __init__(self, dataset_config):
"""Initialize the class.
Each dataset file contains three columns: "wav_filename", "wav_filesize",
and "transcript". This function parses the csv file and stores each example
in increasing order of audio length (indicated by wav_filesize).
"""Initialize the DeepSpeechDataset class.
Args:
dataset_config: DatasetConfig object.
"""
self.config = dataset_config
# Instantiate audio feature extractor.
self.audio_featurizer = AudioFeaturizer(
self.audio_featurizer = featurizer.AudioFeaturizer(
sample_rate=self.config.audio_config.sample_rate,
frame_length=self.config.audio_config.frame_length,
frame_step=self.config.audio_config.frame_step,
fft_length=self.config.audio_config.fft_length,
spect_type=self.config.audio_config.spect_type)
fft_length=self.config.audio_config.fft_length)
# Instantiate text feature extractor.
self.text_featurizer = TextFeaturizer(
self.text_featurizer = featurizer.TextFeaturizer(
vocab_file=self.config.vocab_file_path)
self.speech_labels = self.text_featurizer.speech_labels
self.features, self.labels = self._preprocess_data(self.config.data_path)
self.features, self.labels = _preprocess_data(
self.config,
self.audio_featurizer,
self.text_featurizer.token_to_idx
)
self.num_feature_bins = (
self.features[0].shape[1] if len(self.features) else None)
def _preprocess_data(self, file_path):
"""Generate a list of waveform, transcript pair.
Note that the waveforms are ordered in increasing length, so that audio
samples in a mini-batch have similar length.
Args:
file_path: a string specifying the csv file path for a data set.
Returns:
features and labels array processed from the audio/text input.
"""
with tf.gfile.Open(file_path, "r") as f:
lines = f.read().splitlines()
lines = [line.split("\t") for line in lines]
# Skip the csv header.
lines = lines[1:]
# Sort input data by the length of waveform.
lines.sort(key=lambda item: int(item[1]))
features = [self._preprocess_audio(line[0]) for line in lines]
labels = [self._preprocess_transcript(line[2]) for line in lines]
return features, labels
def _normalize_audio_tensor(self, audio_tensor):
"""Perform mean and variance normalization on the spectrogram tensor.
Args:
audio_tensor: a tensor for the spectrogram feature.
Returns:
a tensor for the normalized spectrogram.
"""
mean, var = tf.nn.moments(audio_tensor, axes=[0])
normalized = (audio_tensor - mean) / (tf.sqrt(var) + 1e-6)
return normalized
def _preprocess_audio(self, audio_file_path):
"""Load the audio file in memory."""
tf.logging.info(
"Extracting spectrogram feature for {}".format(audio_file_path))
sample_rate, data = wavfile.read(audio_file_path)
assert sample_rate == self.config.audio_config.sample_rate
if data.dtype not in [np.float32, np.float64]:
data = data.astype(np.float32) / np.iinfo(data.dtype).max
feature = self.audio_featurizer.featurize(data)
if self.config.audio_config.normalize:
feature = self._normalize_audio_tensor(feature)
return tf.Session().run(
feature) # return a numpy array rather than a tensor
def _preprocess_transcript(self, transcript):
return self.text_featurizer.featurize(transcript)
def input_fn(batch_size, deep_speech_dataset, repeat=1):
"""Input function for model training and evaluation.
@@ -18,9 +18,21 @@ from __future__ import division
from __future__ import print_function
import codecs
import functools
import numpy as np
import tensorflow as tf
from scipy import signal
def compute_spectrogram_feature(waveform, frame_length, frame_step, fft_length):
"""Compute the spectrograms for the input waveform."""
_, _, stft = signal.stft(
waveform,
nperseg=frame_length,
# frame_step is the hop size; scipy.signal.stft expects the overlap,
# i.e. frame_length - frame_step, so passing frame_step directly
# would yield the wrong stride.
noverlap=frame_length - frame_step,
nfft=fft_length)
# Perform transpose to set its shape as [time_steps, feature_num_bins]
spectrogram = np.transpose(np.absolute(stft), (1, 0))
return spectrogram
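A rough usage sketch of compute_spectrogram_feature on a synthetic tone (assumes the noverlap fix above and the 16 kHz defaults; the exact frame count depends on scipy's boundary padding):
import numpy as np
from scipy import signal

sample_rate = 16000
t = np.arange(sample_rate) / sample_rate  # one second of samples
waveform = np.sin(2 * np.pi * 440.0 * t)  # 440 Hz test tone
frame_length, frame_step, fft_length = 400, 160, 512  # 25 ms / 10 ms at 16 kHz
_, _, stft = signal.stft(
    waveform, nperseg=frame_length,
    noverlap=frame_length - frame_step, nfft=fft_length)
spectrogram = np.transpose(np.absolute(stft), (1, 0))
print(spectrogram.shape)  # (time_steps, fft_length // 2 + 1), here (101, 257)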
class AudioFeaturizer(object):
@@ -30,10 +42,7 @@ class AudioFeaturizer(object):
sample_rate=16000,
frame_length=25,
frame_step=10,
fft_length=None,
window_fn=functools.partial(
tf.contrib.signal.hann_window, periodic=True),
spect_type="linear"):
fft_length=None):
"""Initialize the audio featurizer class according to the configs.
Args:
@@ -41,53 +50,18 @@ class AudioFeaturizer(object):
frame_length: an integer for the length of a spectrogram frame, in ms.
frame_step: an integer for the frame stride, in ms.
fft_length: an integer for the number of fft bins.
window_fn: windowing function.
spect_type: a string for the type of spectrogram to be extracted.
Currently only support 'linear', otherwise will raise a value error.
Raises:
ValueError: In case of invalid arguments for `spect_type`.
"""
if spect_type != "linear":
raise ValueError("Unsupported spectrogram type: %s" % spect_type)
self.window_fn = window_fn
self.frame_length = int(sample_rate * frame_length / 1e3)
self.frame_step = int(sample_rate * frame_step / 1e3)
self.fft_length = fft_length if fft_length else int(2**(np.ceil(
np.log2(self.frame_length))))
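For the defaults above (16 kHz audio, 25 ms frames, 10 ms stride), the millisecond-to-sample conversion and the power-of-two fft_length work out as follows (a worked arithmetic sketch):
import numpy as np

sample_rate, frame_length_ms, frame_step_ms = 16000, 25, 10
frame_length = int(sample_rate * frame_length_ms / 1e3)  # 400 samples
frame_step = int(sample_rate * frame_step_ms / 1e3)      # 160 samples
fft_length = int(2 ** np.ceil(np.log2(frame_length)))    # 2**9 = 512
print(frame_length, frame_step, fft_length)  # 400 160 512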
def featurize(self, waveform):
"""Extract spectrogram feature tensors from the waveform."""
return self._compute_linear_spectrogram(waveform)
def _compute_linear_spectrogram(self, waveform):
"""Compute the linear-scale, magnitude spectrograms for the input waveform.
Args:
waveform: a float32 audio tensor.
Returns:
a float32 tensor with shape [len, num_bins]
"""
# `stfts` is a complex64 Tensor representing the Short-time Fourier
# Transform of each signal in `signals`. Its shape is
# [?, fft_unique_bins] where fft_unique_bins = fft_length // 2 + 1.
stfts = tf.contrib.signal.stft(
waveform,
frame_length=self.frame_length,
frame_step=self.frame_step,
fft_length=self.fft_length,
window_fn=self.window_fn,
pad_end=True)
# An energy spectrogram is the magnitude of the complex-valued STFT.
# A float32 Tensor of shape [?, 257].
magnitude_spectrograms = tf.abs(stfts)
return magnitude_spectrograms
def _compute_mel_filterbank_features(self, waveform):
"""Compute the mel filterbank features."""
raise NotImplementedError("MFCC feature extraction not supported yet.")
def compute_label_feature(text, token_to_idx):
"""Convert string to a list of integers."""
tokens = list(text.strip().lower())
feats = [token_to_idx[token] for token in tokens]
return feats
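A toy usage of compute_label_feature (the four-symbol mapping below is hypothetical; the real one comes from TextFeaturizer.token_to_idx):
token_to_idx = {"a": 0, "b": 1, "c": 2, " ": 3}
print(compute_label_feature("Ab c", token_to_idx))  # [0, 1, 3, 2]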
class TextFeaturizer(object):
@@ -114,9 +88,3 @@ class TextFeaturizer(object):
self.idx_to_token[idx] = line
self.speech_labels += line
idx += 1
def featurize(self, text):
"""Convert string to a list of integers."""
tokens = list(text.strip().lower())
feats = [self.token_to_idx[token] for token in tokens]
return feats
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Deep speech decoder."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from nltk.metrics import distance
from six.moves import xrange
import tensorflow as tf
class DeepSpeechDecoder(object):
"""Basic decoder class from which all other decoders inherit.
Implements several helper functions. Subclasses should implement the decode()
method.
"""
def __init__(self, labels, blank_index=28, space_index=27):
"""Decoder initialization.
Arguments:
labels (string): mapping from integers to characters.
blank_index (int, optional): index for the blank '_' character.
Defaults to 28.
space_index (int, optional): index for the space ' ' character.
Defaults to 27.
"""
# e.g. labels = "[a-z]' _"
self.labels = labels
self.int_to_char = dict(enumerate(labels))
self.blank_index = blank_index
self.space_index = space_index
def convert_to_strings(self, sequences, sizes=None):
"""Given a list of numeric sequences, returns the corresponding strings."""
strings = []
for x in xrange(len(sequences)):
seq_len = sizes[x] if sizes is not None else len(sequences[x])
string = self._convert_to_string(sequences[x], seq_len)
strings.append(string)
return strings
def _convert_to_string(self, sequence, sizes):
return ''.join([self.int_to_char[sequence[i]] for i in range(sizes)])
def process_strings(self, sequences, remove_repetitions=False):
"""Process strings.
Given a list of strings, removes blank symbols and replaces the space
symbol with a space character. Option to remove repetitions (e.g. 'abbca' -> 'abca').
Arguments:
sequences: a list of decoded strings.
remove_repetitions (boolean, optional): If true, repeating characters
are removed. Defaults to False.
Returns:
A list of processed strings.
"""
processed_strings = []
for sequence in sequences:
string = self.process_string(remove_repetitions, sequence).strip()
processed_strings.append(string)
return processed_strings
def process_string(self, remove_repetitions, sequence):
"""Process each given sequence."""
seq_string = ''
for i, char in enumerate(sequence):
if char != self.int_to_char[self.blank_index]:
# if this char is a repetition and remove_repetitions=true,
# skip.
if remove_repetitions and i != 0 and char == sequence[i - 1]:
pass
elif char == self.labels[self.space_index]:
seq_string += ' '
else:
seq_string += char
return seq_string
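A worked example of the collapse rule in process_string, assuming the label set sketched in the class docstring (letters, apostrophe, space at index 27, blank '_' at index 28). Blanks keep genuinely doubled letters from being merged when repetitions are removed:
labels = "abcdefghijklmnopqrstuvwxyz' _"  # hypothetical label set
dec = DeepSpeechDecoder(labels, blank_index=28, space_index=27)
print(dec.process_string(True, "hhe_el_llo"))  # 'heello'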
def wer(self, output, target):
"""Computes the Word Error Rate (WER).
WER is defined as the edit distance between the two provided sentences after
tokenizing to words.
Args:
output: string of the decoded output.
target: a string for the true transcript.
Returns:
A float number for the WER of the current sentence pair.
"""
# Map each word to a new char.
words = set(output.split() + target.split())
word2char = dict(zip(words, range(len(words))))
new_output = [chr(word2char[w]) for w in output.split()]
new_target = [chr(word2char[w]) for w in target.split()]
return distance.edit_distance(''.join(new_output), ''.join(new_target))
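The word-to-character mapping above reduces word-level edit distance to nltk's character-level edit_distance. A worked example (requires nltk):
from nltk.metrics import distance

output, target = "the cat sat", "the cat sat down"
words = set(output.split() + target.split())
word2char = dict(zip(words, range(len(words))))
new_output = ''.join(chr(word2char[w]) for w in output.split())
new_target = ''.join(chr(word2char[w]) for w in target.split())
print(distance.edit_distance(new_output, new_target))  # 1 (one word inserted)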
def cer(self, output, target):
"""Computes the Character Error Rate (CER).
CER is defined as the edit distance between the given strings.
Args:
output: a string of the decoded output.
target: a string for the ground truth transcript.
Returns:
A float number denoting the CER for the current sentence pair.
"""
return distance.edit_distance(output, target)
def batch_wer(self, decoded_output, targets):
"""Compute the aggregate WER for each batch.
Args:
decoded_output: 2d array of integers for the decoded output of a batch.
targets: 2d array of integers for the labels of a batch.
Returns:
A float number for the aggregated WER for the current batch output.
"""
# Convert numeric representation to string.
decoded_strings = self.convert_to_strings(decoded_output)
decoded_strings = self.process_strings(
decoded_strings, remove_repetitions=True)
target_strings = self.convert_to_strings(targets)
target_strings = self.process_strings(
target_strings, remove_repetitions=True)
wer = 0
for i in xrange(len(decoded_strings)):
wer += self.wer(decoded_strings[i], target_strings[i]) / float(
len(target_strings[i].split()))
return wer
def batch_cer(self, decoded_output, targets):
"""Compute the aggregate CER for each batch.
Args:
decoded_output: 2d array of integers for the decoded output of a batch.
targets: 2d array of integers for the labels of a batch.
Returns:
A float number for the aggregated CER for the current batch output.
"""
# Convert numeric representation to string.
decoded_strings = self.convert_to_strings(decoded_output)
decoded_strings = self.process_strings(
decoded_strings, remove_repetitions=True)
target_strings = self.convert_to_strings(targets)
target_strings = self.process_strings(
target_strings, remove_repetitions=True)
cer = 0
for i in xrange(len(decoded_strings)):
cer += self.cer(decoded_strings[i], target_strings[i]) / float(
len(target_strings[i]))
return cer
def decode(self, sequences, sizes=None):
"""Perform sequence decoding.
Given a matrix of character probabilities, returns the decoder's best guess
of the transcription.
Arguments:
sequences: 2D array of character probabilities, where sequences[c, t] is
the probability of character c at time t.
sizes(optional): Size of each sequence in the mini-batch.
Returns:
string: sequence of the model's best guess for the transcription.
"""
strings = self.convert_to_strings(sequences, sizes)
return self.process_strings(strings, remove_repetitions=True)
class GreedyDecoder(DeepSpeechDecoder):
"""Greedy decoder."""
def decode(self, logits, seq_len):
# Reshape to [max_time, batch_size, num_classes]
logits = tf.transpose(logits, (1, 0, 2))
decoded, _ = tf.nn.ctc_greedy_decoder(logits, seq_len)
decoded_dense = tf.Session().run(tf.sparse_to_dense(
decoded[0].indices, decoded[0].dense_shape, decoded[0].values))
result = self.convert_to_strings(decoded_dense)
return self.process_strings(result, remove_repetitions=True), decoded_dense
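A rough end-to-end sketch of GreedyDecoder on random logits (TF1-style, matching the tf.Session usage in decode; the label set is the hypothetical one from the earlier example, with the blank at num_classes - 1 as tf.nn.ctc_greedy_decoder expects):
import numpy as np
import tensorflow as tf

labels = "abcdefghijklmnopqrstuvwxyz' _"  # hypothetical label set
batch_size, max_time, num_classes = 2, 50, len(labels)  # blank at index 28
logits = tf.constant(
    np.random.rand(batch_size, max_time, num_classes), tf.float32)
seq_len = tf.constant([50, 40], tf.int32)
greedy = GreedyDecoder(labels)
strings, _ = greedy.decode(logits, seq_len)
print(strings)  # best-guess transcripts; padded positions decode to index 0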
@@ -25,15 +25,82 @@ import tensorflow as tf
# pylint: enable=g-bad-import-order
import data.dataset as dataset
import decoder
import deep_speech_model
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers
# Default vocabulary file
_VOCABULARY_FILE = os.path.join(
os.path.dirname(__file__), "data/vocabulary.txt")
# Evaluation metrics
_WER_KEY = "WER"
_CER_KEY = "CER"
def evaluate_model(
estimator, batch_size, speech_labels, targets, input_fn_eval):
"""Evaluate the model performance using WER anc CER as metrics.
WER: Word Error Rate
CER: Character Error Rate
Args:
estimator: estimator to evaluate.
batch_size: size of a mini-batch.
speech_labels: a string specifying all the characters in the vocabulary.
targets: a list of lists of integers for the featurized transcripts.
input_fn_eval: data input function for evaluation.
Returns:
Evaluation results containing 'WER' and 'CER' as two metrics.
"""
# Get predictions
predictions = estimator.predict(
input_fn=input_fn_eval, yield_single_examples=False)
y_preds = []
input_lengths = []
for p in predictions:
y_preds.append(p["y_pred"])
input_lengths.append(p["ctc_input_length"])
num_of_examples = len(targets)
total_wer, total_cer = 0, 0
greedy_decoder = decoder.GreedyDecoder(speech_labels)
for i in range(len(y_preds)):
# Compute the CER and WER for the current batch,
# and aggregate to total_cer, total_wer.
y_pred_tensor = tf.convert_to_tensor(y_preds[i])
batch_targets = targets[i * batch_size : (i + 1) * batch_size]
seq_len = tf.squeeze(input_lengths[i], axis=1)
# Perform decoding
_, decoded_output = greedy_decoder.decode(
y_pred_tensor, seq_len)
# Compute CER.
batch_cer = greedy_decoder.batch_cer(decoded_output, batch_targets)
total_cer += batch_cer
# Compute WER.
batch_wer = greedy_decoder.batch_wer(decoded_output, batch_targets)
total_wer += batch_wer
# Get mean value
total_cer /= num_of_examples
total_wer /= num_of_examples
global_step = estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP)
eval_results = {
_WER_KEY: total_wer,
_CER_KEY: total_cer,
tf.GraphKeys.GLOBAL_STEP: global_step,
}
return eval_results
def convert_keras_to_estimator(keras_model, num_gpus):
@@ -136,7 +203,7 @@ def run_deep_speech(_):
return dataset.input_fn(
per_device_batch_size, train_speech_dataset)
def input_fn_eval():  # pylint: disable=unused-variable
def input_fn_eval():
return dataset.input_fn(
per_device_batch_size, eval_speech_dataset)
@@ -148,22 +215,23 @@ def run_deep_speech(_):
estimator.train(input_fn=input_fn_train, hooks=train_hooks)
# Evaluate (TODO)
# tf.logging.info("Starting to evaluate.")
# Evaluation
tf.logging.info("Starting to evaluate...")
eval_results = evaluate_model(
estimator, flags_obj.batch_size, eval_speech_dataset.speech_labels,
eval_speech_dataset.labels, input_fn_eval)
# eval_results = evaluate_model(
# estimator, keras_model, data_set.speech_labels, [], input_fn_eval)
# Log the WER and CER results.
benchmark_logger.log_evaluation_result(eval_results)
tf.logging.info(
"Iteration {}: WER = {:.2f}, CER = {:.2f}".format(
cycle_index + 1, eval_results[_WER_KEY], eval_results[_CER_KEY]))
# benchmark_logger.log_evaluation_result(eval_results)
# If some evaluation threshold is met
# Log the HR and NDCG results.
# wer = eval_results[_WER_KEY]
# cer = eval_results[_CER_KEY]
# tf.logging.info(
# "Iteration {}: WER = {:.2f}, CER = {:.2f}".format(
# cycle_index + 1, wer, cer))
# if model_helpers.past_stop_threshold(FLAGS.wer_threshold, wer):
# break
if model_helpers.past_stop_threshold(
flags_obj.wer_threshold, eval_results[_WER_KEY]):
break
# Clear the session explicitly to avoid a session deletion error
tf.keras.backend.clear_session()
@@ -189,8 +257,8 @@ def define_deep_speech_flags():
flags_core.set_defaults(
model_dir="/tmp/deep_speech_model/",
export_dir="/tmp/deep_speech_saved_model/",
train_epochs=10,
batch_size=32,
train_epochs=2,
batch_size=4,
hooks="")
# Deep speech flags