Unverified commit 2986bcaf, authored by moneypi and committed by GitHub

replace tf.compat.v1.logging with absl.logging for deep_speech (#9222)

parent 785f1a18
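
The change swaps the deprecated TF1 logging shim for absl-py's logging module, which TensorFlow itself already depends on. A minimal standalone sketch of the before/after pattern (illustrative only, not code from this repo):

  # Before: logging through the TF1 compatibility shim.
  import tensorflow as tf
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  tf.compat.v1.logging.info("processed %d records", 42)

  # After: absl's logging module, used directly.
  from absl import logging
  logging.set_verbosity(logging.INFO)
  logging.info("processed %d records", 42)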
@@ -24,6 +24,7 @@ import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import soundfile
 import tensorflow as tf
+from absl import logging
 # pylint: enable=g-bad-import-order
 
 import data.featurizer as featurizer  # pylint: disable=g-bad-import-order
@@ -125,7 +126,7 @@ def _preprocess_data(file_path):
     A list of tuples (wav_filename, wav_filesize, transcript) sorted by
     file_size.
   """
-  tf.compat.v1.logging.info("Loading data set {}".format(file_path))
+  logging.info("Loading data set {}".format(file_path))
   with tf.io.gfile.GFile(file_path, "r") as f:
     lines = f.read().splitlines()
   # Skip the csv header in lines[0].
...
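
Both string-formatting styles that appear in these hunks work unchanged with absl.logging; a small standalone sketch (the variable name is illustrative):

  from absl import logging

  file_path = "test-clean.csv"
  # Eager str.format, as in the hunk above; the message is built up front.
  logging.info("Loading data set {}".format(file_path))
  # printf-style arguments, as in later hunks; formatting is deferred
  # to the logging call itself.
  logging.info("Loading data set %s", file_path)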
@@ -32,6 +32,7 @@ import pandas
 from six.moves import urllib
 from sox import Transformer
 import tensorflow as tf
+from absl import logging
 
 LIBRI_SPEECH_URLS = {
     "train-clean-100":
@@ -65,7 +66,7 @@ def download_and_extract(directory, url):
   _, tar_filepath = tempfile.mkstemp(suffix=".tar.gz")
 
   try:
-    tf.compat.v1.logging.info("Downloading %s to %s" % (url, tar_filepath))
+    logging.info("Downloading %s to %s" % (url, tar_filepath))
 
     def _progress(count, block_size, total_size):
       sys.stdout.write("\r>> Downloading {} {:.1f}%".format(
@@ -75,7 +76,7 @@ def download_and_extract(directory, url):
     urllib.request.urlretrieve(url, tar_filepath, _progress)
     print()
     statinfo = os.stat(tar_filepath)
-    tf.compat.v1.logging.info(
+    logging.info(
        "Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))
     with tarfile.open(tar_filepath, "r") as tar:
       tar.extractall(directory)
@@ -112,7 +113,7 @@ def convert_audio_and_split_transcript(input_dir, source_name, target_name,
     output_file: the name of the newly generated csv file. e.g. test-clean.csv
   """
-  tf.compat.v1.logging.info("Preprocessing audio and transcript for %s" % source_name)
+  logging.info("Preprocessing audio and transcript for %s" % source_name)
 
   source_dir = os.path.join(input_dir, source_name)
   target_dir = os.path.join(input_dir, target_name)
@@ -149,7 +150,7 @@ def convert_audio_and_split_transcript(input_dir, source_name, target_name,
   df = pandas.DataFrame(
       data=files, columns=["wav_filename", "wav_filesize", "transcript"])
   df.to_csv(csv_file_path, index=False, sep="\t")
-  tf.compat.v1.logging.info("Successfully generated csv file {}".format(csv_file_path))
+  logging.info("Successfully generated csv file {}".format(csv_file_path))
 
 
 def download_and_process_datasets(directory, datasets):
@@ -160,10 +161,10 @@ def download_and_process_datasets(directory, datasets):
     datasets: list of dataset names that will be downloaded and processed.
   """
-  tf.compat.v1.logging.info("Preparing LibriSpeech dataset: {}".format(
+  logging.info("Preparing LibriSpeech dataset: {}".format(
       ",".join(datasets)))
   for dataset in datasets:
-    tf.compat.v1.logging.info("Preparing dataset %s", dataset)
+    logging.info("Preparing dataset %s", dataset)
     dataset_dir = os.path.join(directory, dataset)
     download_and_extract(dataset_dir, LIBRI_SPEECH_URLS[dataset])
     convert_audio_and_split_transcript(
@@ -202,7 +203,7 @@ def main(_):
 
 
 if __name__ == "__main__":
-  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
+  logging.set_verbosity(logging.INFO)
   define_data_download_flags()
   FLAGS = absl_flags.FLAGS
   absl_app.run(main)
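
The hunks above end with this script writing a tab-separated manifest of (wav_filename, wav_filesize, transcript) rows. A minimal, hypothetical sketch of reading that manifest back, sorted by increasing audio size as the data pipeline expects (load_manifest is illustrative, not a function from this repo):

  import csv

  def load_manifest(csv_path):
    """Read the tab-separated manifest and sort rows by audio file size."""
    with open(csv_path, newline="") as f:
      rows = list(csv.DictReader(f, delimiter="\t"))
    rows.sort(key=lambda r: int(r["wav_filesize"]))
    return [(r["wav_filename"], int(r["wav_filesize"]), r["transcript"])
            for r in rows]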
@@ -21,6 +21,7 @@ import os
 
 # pylint: disable=g-bad-import-order
 from absl import app as absl_app
 from absl import flags
+from absl import logging
 import tensorflow as tf
 # pylint: enable=g-bad-import-order
@@ -225,7 +226,7 @@ def run_deep_speech(_):
   """Run deep speech training and eval loop."""
   tf.compat.v1.set_random_seed(flags_obj.seed)
 
   # Data preprocessing
-  tf.compat.v1.logging.info("Data preprocessing...")
+  logging.info("Data preprocessing...")
   train_speech_dataset = generate_dataset(flags_obj.train_data_dir)
   eval_speech_dataset = generate_dataset(flags_obj.eval_data_dir)
@@ -271,7 +272,7 @@ def run_deep_speech(_):
   total_training_cycle = (flags_obj.train_epochs //
                           flags_obj.epochs_between_evals)
   for cycle_index in range(total_training_cycle):
-    tf.compat.v1.logging.info("Starting a training cycle: %d/%d",
+    logging.info("Starting a training cycle: %d/%d",
                               cycle_index + 1, total_training_cycle)
 
     # Perform batch_wise dataset shuffling
@@ -282,7 +283,7 @@ def run_deep_speech(_):
     estimator.train(input_fn=input_fn_train)
 
     # Evaluation
-    tf.compat.v1.logging.info("Starting to evaluate...")
+    logging.info("Starting to evaluate...")
 
     eval_results = evaluate_model(
         estimator, eval_speech_dataset.speech_labels,
@@ -290,7 +291,7 @@ def run_deep_speech(_):
 
     # Log the WER and CER results.
     benchmark_logger.log_evaluation_result(eval_results)
-    tf.compat.v1.logging.info(
+    logging.info(
        "Iteration {}: WER = {:.2f}, CER = {:.2f}".format(
            cycle_index + 1, eval_results[_WER_KEY], eval_results[_CER_KEY]))
 
@@ -409,7 +410,7 @@ def main(_):
 
 
 if __name__ == "__main__":
-  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
+  logging.set_verbosity(logging.INFO)
   define_deep_speech_flags()
   flags_obj = flags.FLAGS
   absl_app.run(main)
...