Unverified Commit c87c3965 authored by Hongkun Yu's avatar Hongkun Yu Committed by GitHub
Browse files

Update deep_speech.py (#8694)

parent 53fd242a
......@@ -28,8 +28,6 @@ import data.dataset as dataset
import decoder
import deep_speech_model
from official.utils.flags import core as flags_core
from official.r1.utils.logs import hooks_helper
from official.r1.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers
......@@ -276,16 +274,6 @@ def run_deep_speech(_):
"use_bias": flags_obj.use_bias
}
dataset_name = "LibriSpeech"
benchmark_logger = logger.get_benchmark_logger()
benchmark_logger.log_run_info("deep_speech", dataset_name, run_params,
test_id=flags_obj.benchmark_test_id)
train_hooks = hooks_helper.get_train_hooks(
flags_obj.hooks,
model_dir=flags_obj.model_dir,
batch_size=flags_obj.batch_size)
per_replica_batch_size = per_device_batch_size(flags_obj.batch_size, num_gpus)
def input_fn_train():
......@@ -307,7 +295,7 @@ def run_deep_speech(_):
train_speech_dataset.entries, cycle_index, flags_obj.sortagrad,
flags_obj.batch_size)
estimator.train(input_fn=input_fn_train, hooks=train_hooks)
estimator.train(input_fn=input_fn_train)
# Evaluation
tf.logging.info("Starting to evaluate...")
......@@ -433,8 +421,7 @@ def define_deep_speech_flags():
def main(_):
with logger.benchmark_context(flags_obj):
run_deep_speech(flags_obj)
run_deep_speech(flags_obj)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment