Commit 909ee1b3 authored by Taylor Robie's avatar Taylor Robie Committed by Katherine Wu
Browse files

use existing inter and intra flags, and fix wide deep test. (#5110)

parent b64f67d4
...@@ -95,7 +95,8 @@ class BaseTest(tf.test.TestCase): ...@@ -95,7 +95,8 @@ class BaseTest(tf.test.TestCase):
"""Ensure that model trains and minimizes loss.""" """Ensure that model trains and minimizes loss."""
model = census_main.build_estimator( model = census_main.build_estimator(
self.temp_dir, model_type, self.temp_dir, model_type,
model_column_fn=census_dataset.build_model_columns) model_column_fn=census_dataset.build_model_columns,
inter_op=0, intra_op=0)
# Train for 1 step to initialize model and evaluate initial loss # Train for 1 step to initialize model and evaluate initial loss
def get_input_fn(num_epochs, shuffle, batch_size): def get_input_fn(num_epochs, shuffle, batch_size):
......
...@@ -38,6 +38,10 @@ def define_wide_deep_flags(): ...@@ -38,6 +38,10 @@ def define_wide_deep_flags():
"""Add supervised learning flags, as well as wide-deep model type.""" """Add supervised learning flags, as well as wide-deep model type."""
flags_core.define_base() flags_core.define_base()
flags_core.define_benchmark() flags_core.define_benchmark()
flags_core.define_performance(
num_parallel_calls=False, inter_op=True, intra_op=True,
synthetic_data=False, max_train_steps=False, dtype=False,
all_reduce_alg=False)
flags.adopt_module_key_flags(flags_core) flags.adopt_module_key_flags(flags_core)
...@@ -48,14 +52,6 @@ def define_wide_deep_flags(): ...@@ -48,14 +52,6 @@ def define_wide_deep_flags():
flags.DEFINE_boolean( flags.DEFINE_boolean(
name="download_if_missing", default=True, help=flags_core.help_wrap( name="download_if_missing", default=True, help=flags_core.help_wrap(
"Download data to data_dir if it is not already present.")) "Download data to data_dir if it is not already present."))
flags.DEFINE_integer(
name="inter_op_parallelism_threads", short_name="inter", default=0,
help="Number of threads to use for inter-op parallelism. "
"If left as default value of 0, the system will pick an appropriate number.")
flags.DEFINE_integer(
name="intra_op_parallelism_threads", short_name="intra", default=0,
help="Number of threads to use for intra-op parallelism. "
"If left as default value of 0, the system will pick an appropriate number.")
def export_model(model, model_type, export_dir, model_column_fn): def export_model(model, model_type, export_dir, model_column_fn):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment