Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
89def413
Commit
89def413
authored
Jun 25, 2020
by
Kaushik Shivakumar
Browse files
fix tpu training issues
parent
aca51294
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
3 deletions
+4
-3
research/object_detection/model_lib_v2.py
research/object_detection/model_lib_v2.py
+2
-2
research/object_detection/model_main_tf2.py
research/object_detection/model_main_tf2.py
+2
-1
No files found.
research/object_detection/model_lib_v2.py
View file @
89def413
...
...
@@ -330,7 +330,7 @@ def load_fine_tune_checkpoint(
labels
)
strategy
=
tf
.
compat
.
v2
.
distribute
.
get_strategy
()
strategy
.
run
(
strategy
.
experimental_run_v2
(
_dummy_computation_fn
,
args
=
(
features
,
labels
,
...
...
@@ -570,7 +570,7 @@ def train_loop(
def
_sample_and_train
(
strategy
,
train_step_fn
,
data_iterator
):
features
,
labels
=
data_iterator
.
next
()
per_replica_losses
=
strategy
.
run
(
per_replica_losses
=
strategy
.
experimental_run_v2
(
train_step_fn
,
args
=
(
features
,
labels
))
# TODO(anjalisridhar): explore if it is safe to remove the
## num_replicas scaling of the loss and switch this to a ReduceOp.Mean
...
...
research/object_detection/model_main_tf2.py
View file @
89def413
...
...
@@ -42,6 +42,7 @@ from object_detection import model_lib_v2
flags
.
DEFINE_string
(
'pipeline_config_path'
,
None
,
'Path to pipeline config '
'file.'
)
flags
.
DEFINE_integer
(
'num_train_steps'
,
None
,
'Number of train steps.'
)
flags
.
DEFINE_bool
(
'use_tpu'
,
False
,
'Whether to use TPUs'
)
flags
.
DEFINE_bool
(
'eval_on_train_data'
,
False
,
'Enable evaluating on train '
'data (only supported in distributed training).'
)
flags
.
DEFINE_integer
(
'sample_1_of_n_eval_examples'
,
None
,
'Will sample one of '
...
...
@@ -84,7 +85,7 @@ def main(unused_argv):
checkpoint_dir
=
FLAGS
.
checkpoint_dir
,
wait_interval
=
300
,
timeout
=
FLAGS
.
eval_timeout
)
else
:
if
tf
.
config
.
get_visible_devices
(
'TPU'
)
:
if
FLAGS
.
use_tpu
:
resolver
=
tf
.
distribute
.
cluster_resolver
.
TPUClusterResolver
()
tf
.
config
.
experimental_connect_to_cluster
(
resolver
)
tf
.
tpu
.
experimental
.
initialize_tpu_system
(
resolver
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment