Unverified commit 3311242a, authored by moneypi, committed by GitHub

Update to TF 2.x for deep_speech (#8696)

parent ccf7da9d
@@ -71,8 +71,8 @@ class DatasetConfig(object):
     """
     self.audio_config = audio_config
-    assert tf.gfile.Exists(data_path)
-    assert tf.gfile.Exists(vocab_file_path)
+    assert tf.io.gfile.exists(data_path)
+    assert tf.io.gfile.exists(vocab_file_path)
     self.data_path = data_path
     self.vocab_file_path = vocab_file_path
     self.sortagrad = sortagrad
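
For reviewers who haven't done a TF1-to-TF2 migration: `tf.gfile` moved to `tf.io.gfile` with lowercase method names (`Exists`→`exists`, `MakeDirs`→`makedirs`, `Open`→`GFile`, `Walk`→`walk`, `Remove`→`remove`), which is the only change in this hunk and in several below. A minimal sketch of the new spelling (the helper names here are illustrative, not from this PR):

```python
import tensorflow as tf

def ensure_dir(path):
  """TF2 spelling of the TF1 tf.gfile.Exists / tf.gfile.MakeDirs pattern."""
  if not tf.io.gfile.exists(path):
    tf.io.gfile.makedirs(path)

def read_lines(path):
  """tf.io.gfile.GFile replaces tf.gfile.Open; it handles local, GCS, and HDFS paths."""
  with tf.io.gfile.GFile(path, "r") as f:
    return f.read().splitlines()
```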
@@ -125,8 +125,8 @@ def _preprocess_data(file_path):
     A list of tuples (wav_filename, wav_filesize, transcript) sorted by
     file_size.
   """
-  tf.logging.info("Loading data set {}".format(file_path))
-  with tf.gfile.Open(file_path, "r") as f:
+  tf.compat.v1.logging.info("Loading data set {}".format(file_path))
+  with tf.io.gfile.GFile(file_path, "r") as f:
     lines = f.read().splitlines()
   # Skip the csv header in lines[0].
   lines = lines[1:]
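
The `tf.logging` module no longer exists in TF 2.x; this change keeps the TF1 logger via the `tf.compat.v1.logging` shim rather than switching to `absl.logging`. A minimal sketch of the replacement pattern used throughout the diff (the path string is illustrative):

```python
import tensorflow as tf

# tf.compat.v1.logging is a drop-in shim for TF1's tf.logging.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
tf.compat.v1.logging.info("Loading data set %s", "/tmp/train.csv")
```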
@@ -59,13 +59,13 @@ def download_and_extract(directory, url):
     url: the url to download the data file.
   """
-  if not tf.gfile.Exists(directory):
-    tf.gfile.MakeDirs(directory)
+  if not tf.io.gfile.exists(directory):
+    tf.io.gfile.makedirs(directory)
   _, tar_filepath = tempfile.mkstemp(suffix=".tar.gz")
   try:
-    tf.logging.info("Downloading %s to %s" % (url, tar_filepath))
+    tf.compat.v1.logging.info("Downloading %s to %s" % (url, tar_filepath))

     def _progress(count, block_size, total_size):
       sys.stdout.write("\r>> Downloading {} {:.1f}%".format(
@@ -75,12 +75,12 @@ def download_and_extract(directory, url):
     urllib.request.urlretrieve(url, tar_filepath, _progress)
     print()
     statinfo = os.stat(tar_filepath)
-    tf.logging.info(
+    tf.compat.v1.logging.info(
         "Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))
     with tarfile.open(tar_filepath, "r") as tar:
       tar.extractall(directory)
   finally:
-    tf.gfile.Remove(tar_filepath)
+    tf.io.gfile.remove(tar_filepath)


 def convert_audio_and_split_transcript(input_dir, source_name, target_name,
@@ -112,18 +112,18 @@ def convert_audio_and_split_transcript(input_dir, source_name, target_name,
     output_file: the name of the newly generated csv file. e.g. test-clean.csv
   """
-  tf.logging.info("Preprocessing audio and transcript for %s" % source_name)
+  tf.compat.v1.logging.info("Preprocessing audio and transcript for %s" % source_name)
   source_dir = os.path.join(input_dir, source_name)
   target_dir = os.path.join(input_dir, target_name)
-  if not tf.gfile.Exists(target_dir):
-    tf.gfile.MakeDirs(target_dir)
+  if not tf.io.gfile.exists(target_dir):
+    tf.io.gfile.makedirs(target_dir)

   files = []
   tfm = Transformer()
   # Convert all FLAC files into WAV format. At the same time, generate the csv
   # file.
-  for root, _, filenames in tf.gfile.Walk(source_dir):
+  for root, _, filenames in tf.io.gfile.walk(source_dir):
     for filename in fnmatch.filter(filenames, "*.trans.txt"):
       trans_file = os.path.join(root, filename)
       with codecs.open(trans_file, "r", "utf-8") as fin:
@@ -137,7 +137,7 @@ def convert_audio_and_split_transcript(input_dir, source_name, target_name,
           # Convert FLAC to WAV.
           flac_file = os.path.join(root, seqid + ".flac")
           wav_file = os.path.join(target_dir, seqid + ".wav")
-          if not tf.gfile.Exists(wav_file):
+          if not tf.io.gfile.exists(wav_file):
             tfm.build(flac_file, wav_file)
           wav_filesize = os.path.getsize(wav_file)
@@ -149,7 +149,7 @@ def convert_audio_and_split_transcript(input_dir, source_name, target_name,
   df = pandas.DataFrame(
       data=files, columns=["wav_filename", "wav_filesize", "transcript"])
   df.to_csv(csv_file_path, index=False, sep="\t")
-  tf.logging.info("Successfully generated csv file {}".format(csv_file_path))
+  tf.compat.v1.logging.info("Successfully generated csv file {}".format(csv_file_path))


 def download_and_process_datasets(directory, datasets):
@@ -160,10 +160,10 @@ def download_and_process_datasets(directory, datasets):
     datasets: list of dataset names that will be downloaded and processed.
   """
-  tf.logging.info("Preparing LibriSpeech dataset: {}".format(
+  tf.compat.v1.logging.info("Preparing LibriSpeech dataset: {}".format(
       ",".join(datasets)))
   for dataset in datasets:
-    tf.logging.info("Preparing dataset %s", dataset)
+    tf.compat.v1.logging.info("Preparing dataset %s", dataset)
     dataset_dir = os.path.join(directory, dataset)
     download_and_extract(dataset_dir, LIBRI_SPEECH_URLS[dataset])
     convert_audio_and_split_transcript(
@@ -185,8 +185,8 @@ def define_data_download_flags():

 def main(_):
-  if not tf.gfile.Exists(FLAGS.data_dir):
-    tf.gfile.MakeDirs(FLAGS.data_dir)
+  if not tf.io.gfile.exists(FLAGS.data_dir):
+    tf.io.gfile.makedirs(FLAGS.data_dir)
   if FLAGS.train_only:
     download_and_process_datasets(
@@ -202,7 +202,7 @@ def main(_):

 if __name__ == "__main__":
-  tf.logging.set_verbosity(tf.logging.INFO)
+  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
   define_data_download_flags()
   FLAGS = absl_flags.FLAGS
   absl_app.run(main)
@@ -61,25 +61,10 @@ def compute_length_after_conv(max_time_steps, ctc_time_steps, input_length):
   Returns:
     the ctc_input_length after convolution layer.
   """
-  ctc_input_length = tf.to_float(tf.multiply(
-      input_length, ctc_time_steps))
-  return tf.to_int32(tf.floordiv(
-      ctc_input_length, tf.to_float(max_time_steps)))
-
-
-def ctc_loss(label_length, ctc_input_length, labels, logits):
-  """Computes the ctc loss for the current batch of predictions."""
-  label_length = tf.to_int32(tf.squeeze(label_length))
-  ctc_input_length = tf.to_int32(tf.squeeze(ctc_input_length))
-  sparse_labels = tf.to_int32(
-      tf.keras.backend.ctc_label_dense_to_sparse(labels, label_length))
-  y_pred = tf.log(tf.transpose(
-      logits, perm=[1, 0, 2]) + tf.keras.backend.epsilon())
-
-  return tf.expand_dims(
-      tf.nn.ctc_loss(labels=sparse_labels, inputs=y_pred,
-                     sequence_length=ctc_input_length),
-      axis=1)
+  ctc_input_length = tf.cast(tf.multiply(
+      input_length, ctc_time_steps), dtype=tf.float32)
+  return tf.cast(tf.math.floordiv(
+      ctc_input_length, tf.cast(max_time_steps, dtype=tf.float32)), dtype=tf.int32)


 def evaluate_model(estimator, speech_labels, entries, input_fn_eval):
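
`tf.to_float` and `tf.to_int32` were removed in TF 2.x; `tf.cast(x, dtype)` is the one-to-one replacement, as the rewritten `compute_length_after_conv` body above shows. A quick sketch with toy numbers (200 padded input frames downsampled to 50 CTC steps, 180 of the frames valid):

```python
import tensorflow as tf

max_time_steps = tf.constant(200)   # padded feature length
ctc_time_steps = tf.constant(50)    # time steps after the conv stack
input_length = tf.constant(180)     # unpadded feature length

# TF1: tf.to_float(x)  ->  TF2: tf.cast(x, tf.float32)
# TF1: tf.to_int32(x)  ->  TF2: tf.cast(x, tf.int32)
ctc_input_length = tf.cast(input_length * ctc_time_steps, tf.float32)
valid_ctc_steps = tf.cast(
    tf.math.floordiv(ctc_input_length, tf.cast(max_time_steps, tf.float32)),
    tf.int32)
print(int(valid_ctc_steps))  # 45 = floor(180 * 50 / 200)
```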
@@ -123,11 +108,11 @@ def evaluate_model(estimator, speech_labels, entries, input_fn_eval):
   total_cer /= num_of_examples
   total_wer /= num_of_examples

-  global_step = estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP)
+  global_step = estimator.get_variable_value(tf.compat.v1.GraphKeys.GLOBAL_STEP)
   eval_results = {
       _WER_KEY: total_wer,
       _CER_KEY: total_cer,
-      tf.GraphKeys.GLOBAL_STEP: global_step,
+      tf.compat.v1.GraphKeys.GLOBAL_STEP: global_step,
   }

   return eval_results
@@ -163,7 +148,7 @@ def model_fn(features, labels, mode, params):
     logits = model(features, training=False)
     predictions = {
         "classes": tf.argmax(logits, axis=2),
-        "probabilities": tf.nn.softmax(logits),
+        "probabilities": logits,
         "logits": logits
     }
     return tf.estimator.EstimatorSpec(
@@ -172,17 +157,16 @@ def model_fn(features, labels, mode, params):
   # In training mode.
   logits = model(features, training=True)
-  probs = tf.nn.softmax(logits)
   ctc_input_length = compute_length_after_conv(
-      tf.shape(features)[1], tf.shape(probs)[1], input_length)
+      tf.shape(features)[1], tf.shape(logits)[1], input_length)
   # Compute CTC loss
-  loss = tf.reduce_mean(ctc_loss(
-      label_length, ctc_input_length, labels, probs))
+  loss = tf.reduce_mean(tf.keras.backend.ctc_batch_cost(
+      labels, logits, ctc_input_length, label_length))

-  optimizer = tf.train.AdamOptimizer(learning_rate=flags_obj.learning_rate)
-  global_step = tf.train.get_or_create_global_step()
+  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=flags_obj.learning_rate)
+  global_step = tf.compat.v1.train.get_or_create_global_step()
   minimize_op = optimizer.minimize(loss, global_step=global_step)
-  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+  update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
   # Create the train_op that groups both minimize_ops and update_ops
   train_op = tf.group(minimize_op, update_ops)
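
The hand-rolled `ctc_loss` helper (deleted above) is replaced by `tf.keras.backend.ctc_batch_cost`, which performs the dense-to-sparse label conversion and the log transform internally. Note that it expects post-softmax probabilities, which is why the separate `tf.nn.softmax` call is gone: the model's final Dense layer now applies softmax itself (see the model change below). A shape-checking sketch with made-up sizes:

```python
import tensorflow as tf

batch, time_steps, num_classes, max_label_len = 2, 50, 29, 20
y_pred = tf.nn.softmax(tf.random.normal([batch, time_steps, num_classes]))
labels = tf.random.uniform([batch, max_label_len], minval=1,
                           maxval=num_classes - 1, dtype=tf.int32)
input_length = tf.fill([batch, 1], time_steps)     # CTC steps per example
label_length = tf.fill([batch, 1], max_label_len)  # transcript length per example

# Returns shape (batch, 1); the training code above takes tf.reduce_mean of it.
loss = tf.keras.backend.ctc_batch_cost(labels, y_pred, input_length, label_length)
print(loss.shape)  # (2, 1)
```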
@@ -239,9 +223,9 @@ def per_device_batch_size(batch_size, num_gpus):

 def run_deep_speech(_):
   """Run deep speech training and eval loop."""
-  tf.set_random_seed(flags_obj.seed)
+  tf.compat.v1.set_random_seed(flags_obj.seed)
   # Data preprocessing
-  tf.logging.info("Data preprocessing...")
+  tf.compat.v1.logging.info("Data preprocessing...")
   train_speech_dataset = generate_dataset(flags_obj.train_data_dir)
   eval_speech_dataset = generate_dataset(flags_obj.eval_data_dir)
@@ -287,7 +271,7 @@ def run_deep_speech(_):
   total_training_cycle = (flags_obj.train_epochs //
                           flags_obj.epochs_between_evals)
   for cycle_index in range(total_training_cycle):
-    tf.logging.info("Starting a training cycle: %d/%d",
+    tf.compat.v1.logging.info("Starting a training cycle: %d/%d",
                     cycle_index + 1, total_training_cycle)

     # Perform batch_wise dataset shuffling
@@ -298,7 +282,7 @@ def run_deep_speech(_):
     estimator.train(input_fn=input_fn_train)

     # Evaluation
-    tf.logging.info("Starting to evaluate...")
+    tf.compat.v1.logging.info("Starting to evaluate...")
     eval_results = evaluate_model(
         estimator, eval_speech_dataset.speech_labels,
@@ -306,7 +290,7 @@ def run_deep_speech(_):
     # Log the WER and CER results.
     benchmark_logger.log_evaluation_result(eval_results)
-    tf.logging.info(
+    tf.compat.v1.logging.info(
         "Iteration {}: WER = {:.2f}, CER = {:.2f}".format(
             cycle_index + 1, eval_results[_WER_KEY], eval_results[_CER_KEY]))
@@ -425,7 +409,7 @@ def main(_):

 if __name__ == "__main__":
-  tf.logging.set_verbosity(tf.logging.INFO)
+  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
   define_deep_speech_flags()
   flags_obj = flags.FLAGS
   absl_app.run(main)
@@ -22,9 +22,9 @@ import tensorflow as tf

 # Supported rnn cells.
 SUPPORTED_RNNS = {
-    "lstm": tf.contrib.rnn.BasicLSTMCell,
-    "rnn": tf.contrib.rnn.RNNCell,
-    "gru": tf.contrib.rnn.GRUCell,
+    "lstm": tf.keras.layers.LSTMCell,
+    "rnn": tf.keras.layers.SimpleRNNCell,
+    "gru": tf.keras.layers.GRUCell,
 }

 # Parameters for batch normalization.
@@ -53,9 +53,8 @@ def batch_norm(inputs, training):
   Returns:
     tensor output from batch norm layer.
   """
-  return tf.layers.batch_normalization(
-      inputs=inputs, momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON,
-      fused=True, training=training)
+  return tf.keras.layers.BatchNormalization(
+      momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON)(inputs, training=training)


 def _conv_bn_layer(inputs, padding, filters, kernel_size, strides, layer_id,
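
`tf.layers.batch_normalization` is gone in TF 2.x; the Keras replacement is a layer object that is constructed and then called on a tensor, with `training` passed at call time, hence the chained `(...)` call above. A small sketch of the pattern (the momentum and epsilon values are placeholders standing in for `_BATCH_NORM_DECAY` and `_BATCH_NORM_EPSILON`):

```python
import tensorflow as tf

inputs = tf.random.normal([8, 100, 40])  # (batch, time, features)

# TF1: tf.layers.batch_normalization(inputs=x, training=..., fused=True)
# TF2: build the layer, then call it; `training` is a call-time argument.
bn = tf.keras.layers.BatchNormalization(momentum=0.997, epsilon=1e-5)
outputs = bn(inputs, training=True)
print(outputs.shape)  # (8, 100, 40)
```

One consequence worth noting: each call to `batch_norm()` now constructs a fresh `BatchNormalization` layer with its own variables, which appears to match the per-call-site variables `tf.layers.batch_normalization` created inside a `model_fn` that is built once per graph.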
@@ -81,10 +80,10 @@ def _conv_bn_layer(inputs, padding, filters, kernel_size, strides, layer_id,
   inputs = tf.pad(
       inputs,
       [[0, 0], [padding[0], padding[0]], [padding[1], padding[1]], [0, 0]])
-  inputs = tf.layers.conv2d(
-      inputs=inputs, filters=filters, kernel_size=kernel_size, strides=strides,
+  inputs = tf.keras.layers.Conv2D(
+      filters=filters, kernel_size=kernel_size, strides=strides,
       padding="valid", use_bias=False, activation=tf.nn.relu6,
-      name="cnn_{}".format(layer_id))
+      name="cnn_{}".format(layer_id))(inputs)
   return batch_norm(inputs, training)
@@ -109,24 +108,16 @@ def _rnn_layer(inputs, rnn_cell, rnn_hidden_size, layer_id, is_batch_norm,
   if is_batch_norm:
     inputs = batch_norm(inputs, training)

-  # Construct forward/backward RNN cells.
-  fw_cell = rnn_cell(num_units=rnn_hidden_size,
-                     name="rnn_fw_{}".format(layer_id))
-  bw_cell = rnn_cell(num_units=rnn_hidden_size,
-                     name="rnn_bw_{}".format(layer_id))
-
   if is_bidirectional:
-    outputs, _ = tf.nn.bidirectional_dynamic_rnn(
-        cell_fw=fw_cell, cell_bw=bw_cell, inputs=inputs, dtype=tf.float32,
-        swap_memory=True)
-    rnn_outputs = tf.concat(outputs, -1)
+    rnn_outputs = tf.keras.layers.Bidirectional(
+        tf.keras.layers.RNN(rnn_cell(rnn_hidden_size),
+                            return_sequences=True))(inputs)
   else:
-    rnn_outputs = tf.nn.dynamic_rnn(
-        fw_cell, inputs, dtype=tf.float32, swap_memory=True)
+    rnn_outputs = tf.keras.layers.RNN(
+        rnn_cell(rnn_hidden_size), return_sequences=True)(inputs)

   return rnn_outputs


 class DeepSpeech2(object):
   """Define DeepSpeech2 model."""
@@ -179,7 +170,8 @@ class DeepSpeech2(object):
     # FC layer with batch norm.
     inputs = batch_norm(inputs, training)
-    logits = tf.layers.dense(inputs, self.num_classes, use_bias=self.use_bias)
+    logits = tf.keras.layers.Dense(
+        self.num_classes, use_bias=self.use_bias, activation="softmax")(inputs)

     return logits
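
The final layer now applies softmax itself (`activation="softmax"`), so the tensor named `logits` actually carries per-class probabilities. That is consistent with the two `deep_speech.py` changes above: the eval branch exposes the model output directly as `"probabilities"`, and training feeds it straight into `ctc_batch_cost`, which expects post-softmax values. A tiny sketch confirming the output is normalized (sizes are illustrative):

```python
import tensorflow as tf

inputs = tf.random.normal([2, 50, 512])  # (batch, time, features)
num_classes = 29                         # e.g. 28 characters + CTC blank

probs = tf.keras.layers.Dense(num_classes, activation="softmax")(inputs)
# Every timestep's class distribution sums to 1, so no further
# tf.nn.softmax is needed downstream.
print(tf.reduce_sum(probs, axis=-1))  # ~1.0 everywhere, shape (2, 50)
```

Keeping the name `logits` for a post-softmax tensor is a cosmetic wrinkle of this change; the value flowing through the Estimator code is a probability distribution.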