Unverified Commit c87c3965 authored by Hongkun Yu's avatar Hongkun Yu Committed by GitHub
Browse files

Update deep_speech.py (#8694)

parent 53fd242a
...@@ -28,8 +28,6 @@ import data.dataset as dataset ...@@ -28,8 +28,6 @@ import data.dataset as dataset
import decoder import decoder
import deep_speech_model import deep_speech_model
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.r1.utils.logs import hooks_helper
from official.r1.utils.logs import logger
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers from official.utils.misc import model_helpers
...@@ -276,16 +274,6 @@ def run_deep_speech(_): ...@@ -276,16 +274,6 @@ def run_deep_speech(_):
"use_bias": flags_obj.use_bias "use_bias": flags_obj.use_bias
} }
dataset_name = "LibriSpeech"
benchmark_logger = logger.get_benchmark_logger()
benchmark_logger.log_run_info("deep_speech", dataset_name, run_params,
test_id=flags_obj.benchmark_test_id)
train_hooks = hooks_helper.get_train_hooks(
flags_obj.hooks,
model_dir=flags_obj.model_dir,
batch_size=flags_obj.batch_size)
per_replica_batch_size = per_device_batch_size(flags_obj.batch_size, num_gpus) per_replica_batch_size = per_device_batch_size(flags_obj.batch_size, num_gpus)
def input_fn_train(): def input_fn_train():
...@@ -307,7 +295,7 @@ def run_deep_speech(_): ...@@ -307,7 +295,7 @@ def run_deep_speech(_):
train_speech_dataset.entries, cycle_index, flags_obj.sortagrad, train_speech_dataset.entries, cycle_index, flags_obj.sortagrad,
flags_obj.batch_size) flags_obj.batch_size)
estimator.train(input_fn=input_fn_train, hooks=train_hooks) estimator.train(input_fn=input_fn_train)
# Evaluation # Evaluation
tf.logging.info("Starting to evaluate...") tf.logging.info("Starting to evaluate...")
...@@ -433,8 +421,7 @@ def define_deep_speech_flags(): ...@@ -433,8 +421,7 @@ def define_deep_speech_flags():
def main(_): def main(_):
with logger.benchmark_context(flags_obj): run_deep_speech(flags_obj)
run_deep_speech(flags_obj)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment