"magic_pdf/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "3aa8ccdceb7c42e1463ecd78d01a1b1946dab2b1"
Commit 909ee1b3 authored by Taylor Robie's avatar Taylor Robie Committed by Katherine Wu
Browse files

use existing inter and intra flags, and fix wide deep test. (#5110)

parent b64f67d4
...@@ -95,7 +95,8 @@ class BaseTest(tf.test.TestCase): ...@@ -95,7 +95,8 @@ class BaseTest(tf.test.TestCase):
"""Ensure that model trains and minimizes loss.""" """Ensure that model trains and minimizes loss."""
model = census_main.build_estimator( model = census_main.build_estimator(
self.temp_dir, model_type, self.temp_dir, model_type,
model_column_fn=census_dataset.build_model_columns) model_column_fn=census_dataset.build_model_columns,
inter_op=0, intra_op=0)
# Train for 1 step to initialize model and evaluate initial loss # Train for 1 step to initialize model and evaluate initial loss
def get_input_fn(num_epochs, shuffle, batch_size): def get_input_fn(num_epochs, shuffle, batch_size):
......
...@@ -38,6 +38,10 @@ def define_wide_deep_flags(): ...@@ -38,6 +38,10 @@ def define_wide_deep_flags():
"""Add supervised learning flags, as well as wide-deep model type.""" """Add supervised learning flags, as well as wide-deep model type."""
flags_core.define_base() flags_core.define_base()
flags_core.define_benchmark() flags_core.define_benchmark()
flags_core.define_performance(
num_parallel_calls=False, inter_op=True, intra_op=True,
synthetic_data=False, max_train_steps=False, dtype=False,
all_reduce_alg=False)
flags.adopt_module_key_flags(flags_core) flags.adopt_module_key_flags(flags_core)
...@@ -48,14 +52,6 @@ def define_wide_deep_flags(): ...@@ -48,14 +52,6 @@ def define_wide_deep_flags():
flags.DEFINE_boolean( flags.DEFINE_boolean(
name="download_if_missing", default=True, help=flags_core.help_wrap( name="download_if_missing", default=True, help=flags_core.help_wrap(
"Download data to data_dir if it is not already present.")) "Download data to data_dir if it is not already present."))
flags.DEFINE_integer(
name="inter_op_parallelism_threads", short_name="inter", default=0,
help="Number of threads to use for inter-op parallelism. "
"If left as default value of 0, the system will pick an appropriate number.")
flags.DEFINE_integer(
name="intra_op_parallelism_threads", short_name="intra", default=0,
help="Number of threads to use for intra-op parallelism. "
"If left as default value of 0, the system will pick an appropriate number.")
def export_model(model, model_type, export_dir, model_column_fn): def export_model(model, model_type, export_dir, model_column_fn):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment