"examples/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "afb7f5b83917b157ccc0b5875904c3d2ca05aa1d"
Unverified commit 3311242a authored by moneypi, committed by GitHub

1. update to tf2.x for deep_speech (#8696)

Update to TF 2 for deep_speech
parent ccf7da9d
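Reviewer note: every mechanical rename in this diff follows the standard TF 1.x → 2.x upgrade mapping. A minimal sketch of the pattern, with an illustrative path and values (not code from this repo):

import tensorflow as tf

# File I/O: tf.gfile.* becomes tf.io.gfile.* with lower-case names.
#   tf.gfile.Exists   -> tf.io.gfile.exists
#   tf.gfile.MakeDirs -> tf.io.gfile.makedirs
#   tf.gfile.Open     -> tf.io.gfile.GFile
#   tf.gfile.Remove   -> tf.io.gfile.remove
#   tf.gfile.Walk     -> tf.io.gfile.walk
if not tf.io.gfile.exists("/tmp/example_dir"):  # illustrative path
  tf.io.gfile.makedirs("/tmp/example_dir")

# Symbols with no native TF 2 equivalent stay reachable under tf.compat.v1:
# tf.logging, tf.GraphKeys, tf.train.AdamOptimizer, tf.set_random_seed, ...
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

# Deprecated cast helpers become explicit tf.cast calls:
x = tf.cast(tf.constant([1, 2, 3]), dtype=tf.float32)  # was tf.to_float(...)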
@@ -71,8 +71,8 @@ class DatasetConfig(object):
     """
     self.audio_config = audio_config
-    assert tf.gfile.Exists(data_path)
-    assert tf.gfile.Exists(vocab_file_path)
+    assert tf.io.gfile.exists(data_path)
+    assert tf.io.gfile.exists(vocab_file_path)
     self.data_path = data_path
     self.vocab_file_path = vocab_file_path
     self.sortagrad = sortagrad
@@ -125,8 +125,8 @@ def _preprocess_data(file_path):
     A list of tuples (wav_filename, wav_filesize, transcript) sorted by
     file_size.
   """
-  tf.logging.info("Loading data set {}".format(file_path))
-  with tf.gfile.Open(file_path, "r") as f:
+  tf.compat.v1.logging.info("Loading data set {}".format(file_path))
+  with tf.io.gfile.GFile(file_path, "r") as f:
     lines = f.read().splitlines()
   # Skip the csv header in lines[0].
   lines = lines[1:]
...
@@ -59,13 +59,13 @@ def download_and_extract(directory, url):
     url: the url to download the data file.
   """
-  if not tf.gfile.Exists(directory):
-    tf.gfile.MakeDirs(directory)
+  if not tf.io.gfile.exists(directory):
+    tf.io.gfile.makedirs(directory)
   _, tar_filepath = tempfile.mkstemp(suffix=".tar.gz")
   try:
-    tf.logging.info("Downloading %s to %s" % (url, tar_filepath))
+    tf.compat.v1.logging.info("Downloading %s to %s" % (url, tar_filepath))
     def _progress(count, block_size, total_size):
       sys.stdout.write("\r>> Downloading {} {:.1f}%".format(
@@ -75,12 +75,12 @@ def download_and_extract(directory, url):
     urllib.request.urlretrieve(url, tar_filepath, _progress)
     print()
     statinfo = os.stat(tar_filepath)
-    tf.logging.info(
+    tf.compat.v1.logging.info(
         "Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))
     with tarfile.open(tar_filepath, "r") as tar:
       tar.extractall(directory)
   finally:
-    tf.gfile.Remove(tar_filepath)
+    tf.io.gfile.remove(tar_filepath)
@@ -112,18 +112,18 @@ def convert_audio_and_split_transcript(input_dir, source_name, target_name,
     output_file: the name of the newly generated csv file. e.g. test-clean.csv
   """
-  tf.logging.info("Preprocessing audio and transcript for %s" % source_name)
+  tf.compat.v1.logging.info("Preprocessing audio and transcript for %s" % source_name)
   source_dir = os.path.join(input_dir, source_name)
   target_dir = os.path.join(input_dir, target_name)
-  if not tf.gfile.Exists(target_dir):
-    tf.gfile.MakeDirs(target_dir)
+  if not tf.io.gfile.exists(target_dir):
+    tf.io.gfile.makedirs(target_dir)
   files = []
   tfm = Transformer()
   # Convert all FLAC file into WAV format. At the same time, generate the csv
   # file.
-  for root, _, filenames in tf.gfile.Walk(source_dir):
+  for root, _, filenames in tf.io.gfile.walk(source_dir):
     for filename in fnmatch.filter(filenames, "*.trans.txt"):
       trans_file = os.path.join(root, filename)
       with codecs.open(trans_file, "r", "utf-8") as fin:
@@ -137,7 +137,7 @@ def convert_audio_and_split_transcript(input_dir, source_name, target_name,
         # Convert FLAC to WAV.
         flac_file = os.path.join(root, seqid + ".flac")
         wav_file = os.path.join(target_dir, seqid + ".wav")
-        if not tf.gfile.Exists(wav_file):
+        if not tf.io.gfile.exists(wav_file):
           tfm.build(flac_file, wav_file)
         wav_filesize = os.path.getsize(wav_file)
@@ -149,7 +149,7 @@ def convert_audio_and_split_transcript(input_dir, source_name, target_name,
   df = pandas.DataFrame(
       data=files, columns=["wav_filename", "wav_filesize", "transcript"])
   df.to_csv(csv_file_path, index=False, sep="\t")
-  tf.logging.info("Successfully generated csv file {}".format(csv_file_path))
+  tf.compat.v1.logging.info("Successfully generated csv file {}".format(csv_file_path))
@@ -160,10 +160,10 @@ def download_and_process_datasets(directory, datasets):
     datasets: list of dataset names that will be downloaded and processed.
   """
-  tf.logging.info("Preparing LibriSpeech dataset: {}".format(
+  tf.compat.v1.logging.info("Preparing LibriSpeech dataset: {}".format(
       ",".join(datasets)))
   for dataset in datasets:
-    tf.logging.info("Preparing dataset %s", dataset)
+    tf.compat.v1.logging.info("Preparing dataset %s", dataset)
     dataset_dir = os.path.join(directory, dataset)
     download_and_extract(dataset_dir, LIBRI_SPEECH_URLS[dataset])
     convert_audio_and_split_transcript(
@@ -185,8 +185,8 @@ def define_data_download_flags():
 def main(_):
-  if not tf.gfile.Exists(FLAGS.data_dir):
-    tf.gfile.MakeDirs(FLAGS.data_dir)
+  if not tf.io.gfile.exists(FLAGS.data_dir):
+    tf.io.gfile.makedirs(FLAGS.data_dir)
   if FLAGS.train_only:
     download_and_process_datasets(
@@ -202,7 +202,7 @@ def main(_):
 if __name__ == "__main__":
-  tf.logging.set_verbosity(tf.logging.INFO)
+  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
   define_data_download_flags()
   FLAGS = absl_flags.FLAGS
   absl_app.run(main)
@@ -61,25 +61,10 @@ def compute_length_after_conv(max_time_steps, ctc_time_steps, input_length):
   Returns:
     the ctc_input_length after convolution layer.
   """
-  ctc_input_length = tf.to_float(tf.multiply(
-      input_length, ctc_time_steps))
-  return tf.to_int32(tf.floordiv(
-      ctc_input_length, tf.to_float(max_time_steps)))
-
-
-def ctc_loss(label_length, ctc_input_length, labels, logits):
-  """Computes the ctc loss for the current batch of predictions."""
-  label_length = tf.to_int32(tf.squeeze(label_length))
-  ctc_input_length = tf.to_int32(tf.squeeze(ctc_input_length))
-  sparse_labels = tf.to_int32(
-      tf.keras.backend.ctc_label_dense_to_sparse(labels, label_length))
-  y_pred = tf.log(tf.transpose(
-      logits, perm=[1, 0, 2]) + tf.keras.backend.epsilon())
-
-  return tf.expand_dims(
-      tf.nn.ctc_loss(labels=sparse_labels, inputs=y_pred,
-                     sequence_length=ctc_input_length),
-      axis=1)
+  ctc_input_length = tf.cast(tf.multiply(
+      input_length, ctc_time_steps), dtype=tf.float32)
+  return tf.cast(tf.math.floordiv(
+      ctc_input_length, tf.cast(max_time_steps, dtype=tf.float32)), dtype=tf.int32)
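For context on the hunk above: compute_length_after_conv rescales an utterance's true (unpadded) length into the time base of the post-convolution feature map; only the deprecated cast helpers changed. A worked sketch with illustrative numbers:

import tensorflow as tf

# Illustrative shapes: 80 padded spectrogram frames shrink to 40 time steps
# after the conv stack, so a true length of 60 maps to
# floor(60 * 40 / 80) = 30 valid CTC input steps.
max_time_steps = tf.constant(80)
ctc_time_steps = tf.constant(40)
input_length = tf.constant(60)

ctc_input_length = tf.cast(tf.multiply(input_length, ctc_time_steps),
                           dtype=tf.float32)
result = tf.cast(tf.math.floordiv(
    ctc_input_length, tf.cast(max_time_steps, dtype=tf.float32)),
    dtype=tf.int32)  # == 30

The removed ctc_loss helper is not dead code: its job moves to tf.keras.backend.ctc_batch_cost in model_fn below.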
@@ -123,11 +108,11 @@ def evaluate_model(estimator, speech_labels, entries, input_fn_eval):
   total_cer /= num_of_examples
   total_wer /= num_of_examples
-  global_step = estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP)
+  global_step = estimator.get_variable_value(tf.compat.v1.GraphKeys.GLOBAL_STEP)
   eval_results = {
       _WER_KEY: total_wer,
       _CER_KEY: total_cer,
-      tf.GraphKeys.GLOBAL_STEP: global_step,
+      tf.compat.v1.GraphKeys.GLOBAL_STEP: global_step,
   }
   return eval_results
@@ -163,7 +148,7 @@ def model_fn(features, labels, mode, params):
     logits = model(features, training=False)
     predictions = {
         "classes": tf.argmax(logits, axis=2),
-        "probabilities": tf.nn.softmax(logits),
+        "probabilities": logits,
         "logits": logits
     }
     return tf.estimator.EstimatorSpec(
@@ -172,17 +157,16 @@ def model_fn(features, labels, mode, params):
   # In training mode.
   logits = model(features, training=True)
-  probs = tf.nn.softmax(logits)
   ctc_input_length = compute_length_after_conv(
-      tf.shape(features)[1], tf.shape(probs)[1], input_length)
+      tf.shape(features)[1], tf.shape(logits)[1], input_length)
   # Compute CTC loss
-  loss = tf.reduce_mean(ctc_loss(
-      label_length, ctc_input_length, labels, probs))
+  loss = tf.reduce_mean(tf.keras.backend.ctc_batch_cost(
+      labels, logits, ctc_input_length, label_length))
-  optimizer = tf.train.AdamOptimizer(learning_rate=flags_obj.learning_rate)
-  global_step = tf.train.get_or_create_global_step()
+  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=flags_obj.learning_rate)
+  global_step = tf.compat.v1.train.get_or_create_global_step()
   minimize_op = optimizer.minimize(loss, global_step=global_step)
-  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+  update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
   # Create the train_op that groups both minimize_ops and update_ops
   train_op = tf.group(minimize_op, update_ops)
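tf.keras.backend.ctc_batch_cost subsumes everything the deleted ctc_loss helper did by hand: dense-to-sparse label conversion, the time-major transpose, and taking the log of the predictions (so it must be fed post-softmax probabilities, which the model's final Dense layer now produces). A minimal sketch of the call with illustrative shapes:

import tensorflow as tf

batch, label_len, time_steps, num_classes = 2, 5, 20, 29  # illustrative

# Keras CTC reserves the last class index (num_classes - 1) for the blank,
# so example labels are drawn from the remaining classes.
labels = tf.random.uniform((batch, label_len), 0, num_classes - 1, tf.int32)
probs = tf.nn.softmax(tf.random.normal((batch, time_steps, num_classes)))
input_length = tf.fill((batch, 1), time_steps)  # valid time steps per example
label_length = tf.fill((batch, 1), label_len)   # valid labels per example

# ctc_batch_cost returns a (batch, 1) tensor of per-example losses.
loss = tf.reduce_mean(tf.keras.backend.ctc_batch_cost(
    labels, probs, input_length, label_length))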
@@ -239,9 +223,9 @@ def per_device_batch_size(batch_size, num_gpus):
 def run_deep_speech(_):
   """Run deep speech training and eval loop."""
-  tf.set_random_seed(flags_obj.seed)
+  tf.compat.v1.set_random_seed(flags_obj.seed)
   # Data preprocessing
-  tf.logging.info("Data preprocessing...")
+  tf.compat.v1.logging.info("Data preprocessing...")
   train_speech_dataset = generate_dataset(flags_obj.train_data_dir)
   eval_speech_dataset = generate_dataset(flags_obj.eval_data_dir)
@@ -287,7 +271,7 @@ def run_deep_speech(_):
   total_training_cycle = (flags_obj.train_epochs //
                           flags_obj.epochs_between_evals)
   for cycle_index in range(total_training_cycle):
-    tf.logging.info("Starting a training cycle: %d/%d",
+    tf.compat.v1.logging.info("Starting a training cycle: %d/%d",
                     cycle_index + 1, total_training_cycle)
     # Perform batch_wise dataset shuffling
@@ -298,7 +282,7 @@ def run_deep_speech(_):
     estimator.train(input_fn=input_fn_train)
     # Evaluation
-    tf.logging.info("Starting to evaluate...")
+    tf.compat.v1.logging.info("Starting to evaluate...")
     eval_results = evaluate_model(
         estimator, eval_speech_dataset.speech_labels,
@@ -306,7 +290,7 @@ def run_deep_speech(_):
     # Log the WER and CER results.
     benchmark_logger.log_evaluation_result(eval_results)
-    tf.logging.info(
+    tf.compat.v1.logging.info(
         "Iteration {}: WER = {:.2f}, CER = {:.2f}".format(
             cycle_index + 1, eval_results[_WER_KEY], eval_results[_CER_KEY]))
@@ -425,7 +409,7 @@ def main(_):
 if __name__ == "__main__":
-  tf.logging.set_verbosity(tf.logging.INFO)
+  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
   define_deep_speech_flags()
   flags_obj = flags.FLAGS
   absl_app.run(main)
...
@@ -22,9 +22,9 @@ import tensorflow as tf
 # Supported rnn cells.
 SUPPORTED_RNNS = {
-    "lstm": tf.contrib.rnn.BasicLSTMCell,
-    "rnn": tf.contrib.rnn.RNNCell,
-    "gru": tf.contrib.rnn.GRUCell,
+    "lstm": tf.keras.layers.LSTMCell,
+    "rnn": tf.keras.layers.SimpleRNNCell,
+    "gru": tf.keras.layers.GRUCell,
 }
 # Parameters for batch normalization.
@@ -53,9 +53,8 @@ def batch_norm(inputs, training):
   Returns:
     tensor output from batch norm layer.
   """
-  return tf.layers.batch_normalization(
-      inputs=inputs, momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON,
-      fused=True, training=training)
+  return tf.keras.layers.BatchNormalization(
+      momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON)(inputs, training=training)
@@ -81,10 +80,10 @@ def _conv_bn_layer(inputs, padding, filters, kernel_size, strides, layer_id,
   inputs = tf.pad(
       inputs,
       [[0, 0], [padding[0], padding[0]], [padding[1], padding[1]], [0, 0]])
-  inputs = tf.layers.conv2d(
-      inputs=inputs, filters=filters, kernel_size=kernel_size, strides=strides,
-      padding="valid", use_bias=False, activation=tf.nn.relu6,
-      name="cnn_{}".format(layer_id))
+  inputs = tf.keras.layers.Conv2D(
+      filters=filters, kernel_size=kernel_size, strides=strides,
+      padding="valid", use_bias=False, activation=tf.nn.relu6,
+      name="cnn_{}".format(layer_id))(inputs)
   return batch_norm(inputs, training)
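The conv hunk above also shows the call-pattern shift that recurs through this file: tf.layers.conv2d was a function taking the tensor alongside the hyperparameters, while tf.keras.layers.Conv2D is an object configured first and then called on the input. A minimal standalone sketch (shapes illustrative):

import tensorflow as tf

inputs = tf.random.normal((1, 32, 32, 1))  # illustrative NHWC input

# Configure the layer, then apply it; variables are created on first call.
conv = tf.keras.layers.Conv2D(
    filters=32, kernel_size=(3, 3), strides=(2, 2), padding="valid",
    use_bias=False, activation=tf.nn.relu6, name="cnn_example")
outputs = conv(inputs)  # shape (1, 15, 15, 32)

Constructing the layer inline at each call site is acceptable under Estimator, since model_fn builds the graph once per run; a reusable Keras model would normally create layers once in __init__.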
@@ -109,24 +108,16 @@ def _rnn_layer(inputs, rnn_cell, rnn_hidden_size, layer_id, is_batch_norm,
   if is_batch_norm:
     inputs = batch_norm(inputs, training)
-  # Construct forward/backward RNN cells.
-  fw_cell = rnn_cell(num_units=rnn_hidden_size,
-                     name="rnn_fw_{}".format(layer_id))
-  bw_cell = rnn_cell(num_units=rnn_hidden_size,
-                     name="rnn_bw_{}".format(layer_id))
   if is_bidirectional:
-    outputs, _ = tf.nn.bidirectional_dynamic_rnn(
-        cell_fw=fw_cell, cell_bw=bw_cell, inputs=inputs, dtype=tf.float32,
-        swap_memory=True)
-    rnn_outputs = tf.concat(outputs, -1)
+    rnn_outputs = tf.keras.layers.Bidirectional(
+        tf.keras.layers.RNN(rnn_cell(rnn_hidden_size),
+                            return_sequences=True))(inputs)
   else:
-    rnn_outputs = tf.nn.dynamic_rnn(
-        fw_cell, inputs, dtype=tf.float32, swap_memory=True)
+    rnn_outputs = tf.keras.layers.RNN(
+        rnn_cell(rnn_hidden_size), return_sequences=True)(inputs)
   return rnn_outputs
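The Keras cells in SUPPORTED_RNNS operate on a single time step; tf.keras.layers.RNN unrolls a cell over the time axis (return_sequences=True keeps the full per-step output, matching dynamic_rnn), and the Bidirectional wrapper runs the layer forward and backward and concatenates the two output sequences, replacing the old bidirectional_dynamic_rnn + tf.concat pair. A minimal sketch with illustrative shapes:

import tensorflow as tf

inputs = tf.random.normal((2, 10, 16))  # (batch, time, features), illustrative
hidden = 8

layer = tf.keras.layers.Bidirectional(
    tf.keras.layers.RNN(tf.keras.layers.GRUCell(hidden),
                        return_sequences=True))
outputs = layer(inputs)  # (2, 10, 16): forward and backward outputs concatenated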
@@ -179,7 +170,8 @@ class DeepSpeech2(object):
     # FC layer with batch norm.
     inputs = batch_norm(inputs, training)
-    logits = tf.layers.dense(inputs, self.num_classes, use_bias=self.use_bias)
+    logits = tf.keras.layers.Dense(
+        self.num_classes, use_bias=self.use_bias, activation="softmax")(inputs)
     return logits
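Moving the softmax into the final Dense layer ties the earlier model_fn changes together: ctc_batch_cost requires probabilities rather than raw logits, and the "probabilities" entry in the predictions dict can now pass the model output through unchanged (the variable keeps its old logits name, though it now holds probabilities). A minimal sketch of the new head with illustrative shapes:

import tensorflow as tf

num_classes = 29  # illustrative character vocabulary size
inputs = tf.random.normal((2, 10, 64))  # (batch, time, features), illustrative

# Dense acts on the last axis, so every time step gets its own softmax
# distribution over the output classes.
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(inputs)
# tf.reduce_sum(outputs, axis=-1) is ~1.0 at every (batch, time) position.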