Unverified Commit c0a380d2 authored by Qianli Scott Zhu's avatar Qianli Scott Zhu Committed by GitHub
Browse files

Update the wide_deep and transformer code for latest benchmark config. (#4246)

* Update the wide_deep code for latest benchmark config.

* Also update the transformer benchmark code.
parent b9ca525f
...@@ -433,7 +433,7 @@ def run_transformer(flags_obj): ...@@ -433,7 +433,7 @@ def run_transformer(flags_obj):
tensors_to_log=TENSORS_TO_LOG, # used for logging hooks tensors_to_log=TENSORS_TO_LOG, # used for logging hooks
batch_size=params.batch_size # for ExamplesPerSecondHook batch_size=params.batch_size # for ExamplesPerSecondHook
) )
benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir) benchmark_logger = logger.config_benchmark_logger(flags_obj)
benchmark_logger.log_run_info( benchmark_logger.log_run_info(
model_name="transformer", model_name="transformer",
dataset_name="wmt_translate_ende", dataset_name="wmt_translate_ende",
......
...@@ -245,7 +245,7 @@ def run_wide_deep(flags_obj): ...@@ -245,7 +245,7 @@ def run_wide_deep(flags_obj):
'model_type': flags_obj.model_type, 'model_type': flags_obj.model_type,
} }
benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir) benchmark_logger = logger.config_benchmark_logger(flags_obj)
benchmark_logger.log_run_info('wide_deep', 'Census Income', run_params) benchmark_logger.log_run_info('wide_deep', 'Census Income', run_params)
loss_prefix = LOSS_PREFIX.get(flags_obj.model_type, '') loss_prefix = LOSS_PREFIX.get(flags_obj.model_type, '')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment