Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
d4a4dd04
Unverified
Commit
d4a4dd04
authored
Mar 06, 2018
by
Karmel Allison
Committed by
GitHub
Mar 06, 2018
Browse files
Adding thread args back in, with allow_soft_placement (#3533)
parent
e029542a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
2 deletions
+22
-2
official/resnet/resnet.py
official/resnet/resnet.py
+22
-2
No files found.
official/resnet/resnet.py
View file @
d4a4dd04
...
@@ -609,8 +609,18 @@ def resnet_main(flags, model_function, input_function):
...
@@ -609,8 +609,18 @@ def resnet_main(flags, model_function, input_function):
model_function
,
model_function
,
loss_reduction
=
tf
.
losses
.
Reduction
.
MEAN
)
loss_reduction
=
tf
.
losses
.
Reduction
.
MEAN
)
# Set up a RunConfig to only save checkpoints once per training cycle.
# Create session config based on values of inter_op_parallelism_threads and
run_config
=
tf
.
estimator
.
RunConfig
().
replace
(
save_checkpoints_secs
=
1e9
)
# intra_op_parallelism_threads. Note that we default to having
# allow_soft_placement = True, which is required for multi-GPU and not
# harmful for other modes.
session_config
=
tf
.
ConfigProto
(
inter_op_parallelism_threads
=
flags
.
inter_op_parallelism_threads
,
intra_op_parallelism_threads
=
flags
.
intra_op_parallelism_threads
,
allow_soft_placement
=
True
)
# Set up a RunConfig to save checkpoint and set session config.
run_config
=
tf
.
estimator
.
RunConfig
().
replace
(
save_checkpoints_secs
=
1e9
,
session_config
=
session_config
)
classifier
=
tf
.
estimator
.
Estimator
(
classifier
=
tf
.
estimator
.
Estimator
(
model_fn
=
model_function
,
model_dir
=
flags
.
model_dir
,
config
=
run_config
,
model_fn
=
model_function
,
model_dir
=
flags
.
model_dir
,
config
=
run_config
,
params
=
{
params
=
{
...
@@ -706,3 +716,13 @@ class ResnetArgParser(argparse.ArgumentParser):
...
@@ -706,3 +716,13 @@ class ResnetArgParser(argparse.ArgumentParser):
help
=
'If set, use fake data (zeroes) instead of a real dataset. '
help
=
'If set, use fake data (zeroes) instead of a real dataset. '
'This mode is useful for performance debugging, as it removes '
'This mode is useful for performance debugging, as it removes '
'input processing steps, but will not learn anything.'
)
'input processing steps, but will not learn anything.'
)
self
.
add_argument
(
'--inter_op_parallelism_threads'
,
type
=
int
,
default
=
0
,
help
=
'Number of inter_op_parallelism_threads to use for CPU. '
'See TensorFlow config.proto for details.'
)
self
.
add_argument
(
'--intra_op_parallelism_threads'
,
type
=
int
,
default
=
0
,
help
=
'Number of intra_op_parallelism_threads to use for CPU. '
'See TensorFlow config.proto for details.'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment