Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
7e810001
Commit
7e810001
authored
Apr 09, 2018
by
Zhichao Lu
Committed by
pkulzc
Apr 13, 2018
Browse files
Access TPUEstimator and CrossShardOptimizer from tf namesspace.
PiperOrigin-RevId: 192226678
parent
b0c5c3b5
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
5 deletions
+5
-5
research/object_detection/model_lib.py
research/object_detection/model_lib.py
+2
-4
research/object_detection/protos/train.proto
research/object_detection/protos/train.proto
+3
-1
No files found.
research/object_detection/model_lib.py
View file @
7e810001
...
@@ -22,8 +22,6 @@ import functools
...
@@ -22,8 +22,6 @@ import functools
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.contrib.tpu.python.tpu
import
tpu_estimator
from
tensorflow.contrib.tpu.python.tpu
import
tpu_optimizer
from
object_detection
import
eval_util
from
object_detection
import
eval_util
from
object_detection
import
inputs
from
object_detection
import
inputs
from
object_detection.builders
import
model_builder
from
object_detection.builders
import
model_builder
...
@@ -291,7 +289,7 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
...
@@ -291,7 +289,7 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
:
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
:
if
use_tpu
:
if
use_tpu
:
training_optimizer
=
t
pu_optimizer
.
CrossShardOptimizer
(
training_optimizer
=
t
f
.
contrib
.
tpu
.
CrossShardOptimizer
(
training_optimizer
)
training_optimizer
)
# Optionally freeze some layers by setting their gradients to be zero.
# Optionally freeze some layers by setting their gradients to be zero.
...
@@ -490,7 +488,7 @@ def create_estimator_and_inputs(run_config,
...
@@ -490,7 +488,7 @@ def create_estimator_and_inputs(run_config,
model_fn
=
model_fn_creator
(
detection_model_fn
,
configs
,
hparams
,
use_tpu
)
model_fn
=
model_fn_creator
(
detection_model_fn
,
configs
,
hparams
,
use_tpu
)
if
use_tpu_estimator
:
if
use_tpu_estimator
:
estimator
=
t
pu_estimator
.
TPUEstimator
(
estimator
=
t
f
.
contrib
.
tpu
.
TPUEstimator
(
model_fn
=
model_fn
,
model_fn
=
model_fn
,
train_batch_size
=
train_config
.
batch_size
,
train_batch_size
=
train_config
.
batch_size
,
# For each core, only batch size 1 is supported for eval.
# For each core, only batch size 1 is supported for eval.
...
...
research/object_detection/protos/train.proto
View file @
7e810001
...
@@ -7,7 +7,9 @@ import "object_detection/protos/preprocessor.proto";
...
@@ -7,7 +7,9 @@ import "object_detection/protos/preprocessor.proto";
// Message for configuring DetectionModel training jobs (train.py).
// Message for configuring DetectionModel training jobs (train.py).
message
TrainConfig
{
message
TrainConfig
{
// Input queue batch size.
// Effective batch size to use for training.
// For TPU (or sync SGD jobs), the batch size per core (or GPU) is going to be
// `batch_size` / number of cores (or `batch_size` / number of GPUs).
optional
uint32
batch_size
=
1
[
default
=
32
];
optional
uint32
batch_size
=
1
[
default
=
32
];
// Data augmentation options.
// Data augmentation options.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment