Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
909ee1b3
Commit
909ee1b3
authored
Aug 16, 2018
by
Taylor Robie
Committed by
Katherine Wu
Aug 16, 2018
Browse files
use existing inter and intra flags, and fix wide deep test. (#5110)
parent
b64f67d4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
9 deletions
+6
-9
official/wide_deep/census_test.py
official/wide_deep/census_test.py
+2
-1
official/wide_deep/wide_deep_run_loop.py
official/wide_deep/wide_deep_run_loop.py
+4
-8
No files found.
official/wide_deep/census_test.py
View file @
909ee1b3
...
...
@@ -95,7 +95,8 @@ class BaseTest(tf.test.TestCase):
"""Ensure that model trains and minimizes loss."""
model
=
census_main
.
build_estimator
(
self
.
temp_dir
,
model_type
,
model_column_fn
=
census_dataset
.
build_model_columns
)
model_column_fn
=
census_dataset
.
build_model_columns
,
inter_op
=
0
,
intra_op
=
0
)
# Train for 1 step to initialize model and evaluate initial loss
def
get_input_fn
(
num_epochs
,
shuffle
,
batch_size
):
...
...
official/wide_deep/wide_deep_run_loop.py
View file @
909ee1b3
...
...
@@ -38,6 +38,10 @@ def define_wide_deep_flags():
"""Add supervised learning flags, as well as wide-deep model type."""
flags_core
.
define_base
()
flags_core
.
define_benchmark
()
flags_core
.
define_performance
(
num_parallel_calls
=
False
,
inter_op
=
True
,
intra_op
=
True
,
synthetic_data
=
False
,
max_train_steps
=
False
,
dtype
=
False
,
all_reduce_alg
=
False
)
flags
.
adopt_module_key_flags
(
flags_core
)
...
...
@@ -48,14 +52,6 @@ def define_wide_deep_flags():
flags
.
DEFINE_boolean
(
name
=
"download_if_missing"
,
default
=
True
,
help
=
flags_core
.
help_wrap
(
"Download data to data_dir if it is not already present."
))
flags
.
DEFINE_integer
(
name
=
"inter_op_parallelism_threads"
,
short_name
=
"inter"
,
default
=
0
,
help
=
"Number of threads to use for inter-op parallelism. "
"If left as default value of 0, the system will pick an appropriate number."
)
flags
.
DEFINE_integer
(
name
=
"intra_op_parallelism_threads"
,
short_name
=
"intra"
,
default
=
0
,
help
=
"Number of threads to use for intra-op parallelism. "
"If left as default value of 0, the system will pick an appropriate number."
)
def
export_model
(
model
,
model_type
,
export_dir
,
model_column_fn
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment