Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9e8d7643
"vscode:/vscode.git/clone" did not exist on "3ea469b8760e96aed3ecd251aeac4eba1c079a22"
Commit
9e8d7643
authored
Sep 15, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Sep 15, 2020
Browse files
Internal change
PiperOrigin-RevId: 331861386
parent
a121a29f
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
24 additions
and
8 deletions
+24
-8
official/core/train_lib.py
official/core/train_lib.py
+7
-2
official/core/train_lib_test.py
official/core/train_lib_test.py
+14
-3
official/core/train_utils.py
official/core/train_utils.py
+1
-1
official/modeling/hyperparams/config_definitions.py
official/modeling/hyperparams/config_definitions.py
+2
-2
No files found.
official/core/train_lib.py
View file @
9e8d7643
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""TFM common training driver library."""
"""TFM common training driver library."""
# pytype: disable=attribute-error
import
copy
import
copy
import
json
import
json
import
os
import
os
...
@@ -219,9 +219,14 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
...
@@ -219,9 +219,14 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
elif
mode
==
'eval'
:
elif
mode
==
'eval'
:
controller
.
evaluate
(
steps
=
params
.
trainer
.
validation_steps
)
controller
.
evaluate
(
steps
=
params
.
trainer
.
validation_steps
)
elif
mode
==
'continuous_eval'
:
elif
mode
==
'continuous_eval'
:
def
timeout_fn
():
if
trainer
.
global_step
.
numpy
()
>=
params
.
trainer
.
train_steps
:
return
True
return
False
controller
.
evaluate_continuously
(
controller
.
evaluate_continuously
(
steps
=
params
.
trainer
.
validation_steps
,
steps
=
params
.
trainer
.
validation_steps
,
timeout
=
params
.
trainer
.
continuous_eval_timeout
)
timeout
=
params
.
trainer
.
continuous_eval_timeout
,
timeout_fn
=
timeout_fn
)
else
:
else
:
raise
NotImplementedError
(
'The mode is not implemented: %s'
%
mode
)
raise
NotImplementedError
(
'The mode is not implemented: %s'
%
mode
)
...
...
official/core/train_lib_test.py
View file @
9e8d7643
...
@@ -49,6 +49,7 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -49,6 +49,7 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
'train_steps'
:
10
,
'train_steps'
:
10
,
'validation_steps'
:
5
,
'validation_steps'
:
5
,
'validation_interval'
:
10
,
'validation_interval'
:
10
,
'continuous_eval_timeout'
:
1
,
'optimizer_config'
:
{
'optimizer_config'
:
{
'optimizer'
:
{
'optimizer'
:
{
'type'
:
'sgd'
,
'type'
:
'sgd'
,
...
@@ -97,9 +98,19 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -97,9 +98,19 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertEmpty
(
logs
)
self
.
assertEmpty
(
logs
)
self
.
assertNotEmpty
(
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
glob
(
os
.
path
.
join
(
model_dir
,
'params.yaml'
)))
tf
.
io
.
gfile
.
glob
(
os
.
path
.
join
(
model_dir
,
'params.yaml'
)))
if
flag_mode
!=
'eval'
:
if
flag_mode
==
'eval'
:
return
self
.
assertNotEmpty
(
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
glob
(
os
.
path
.
join
(
model_dir
,
'checkpoint'
)))
tf
.
io
.
gfile
.
glob
(
os
.
path
.
join
(
model_dir
,
'checkpoint'
)))
# Tests continuous evaluation.
_
,
logs
=
train_lib
.
run_experiment
(
distribution_strategy
=
distribution_strategy
,
task
=
task
,
mode
=
'continuous_eval'
,
params
=
params
,
model_dir
=
model_dir
,
run_post_eval
=
run_post_eval
)
print
(
logs
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
official/core/train_utils.py
View file @
9e8d7643
...
@@ -38,7 +38,7 @@ def create_trainer(
...
@@ -38,7 +38,7 @@ def create_trainer(
model_dir
:
str
,
model_dir
:
str
,
train
:
bool
,
train
:
bool
,
evaluate
:
bool
,
evaluate
:
bool
,
checkpoint_exporter
:
Any
=
None
):
checkpoint_exporter
:
Any
=
None
)
->
base_trainer
.
Trainer
:
"""Create trainer."""
"""Create trainer."""
del
model_dir
del
model_dir
logging
.
info
(
'Running default trainer.'
)
logging
.
info
(
'Running default trainer.'
)
...
...
official/modeling/hyperparams/config_definitions.py
View file @
9e8d7643
...
@@ -189,7 +189,7 @@ class TrainerConfig(base_config.Config):
...
@@ -189,7 +189,7 @@ class TrainerConfig(base_config.Config):
continuous_eval_timeout: maximum number of seconds to wait between
continuous_eval_timeout: maximum number of seconds to wait between
checkpoints, if set to None, continuous eval will wait indefinitely. This
checkpoints, if set to None, continuous eval will wait indefinitely. This
is only used continuous_train_and_eval and continuous_eval modes. Default
is only used continuous_train_and_eval and continuous_eval modes. Default
value is
24
hrs.
value is
1
hrs.
train_steps: number of train steps.
train_steps: number of train steps.
validation_steps: number of eval steps. If `None`, the entire eval dataset
validation_steps: number of eval steps. If `None`, the entire eval dataset
is used.
is used.
...
@@ -218,7 +218,7 @@ class TrainerConfig(base_config.Config):
...
@@ -218,7 +218,7 @@ class TrainerConfig(base_config.Config):
checkpoint_interval
:
int
=
1000
checkpoint_interval
:
int
=
1000
# Checkpoint manager.
# Checkpoint manager.
max_to_keep
:
int
=
5
max_to_keep
:
int
=
5
continuous_eval_timeout
:
int
=
24
*
60
*
60
continuous_eval_timeout
:
int
=
60
*
60
# Train/Eval routines.
# Train/Eval routines.
train_steps
:
int
=
0
train_steps
:
int
=
0
validation_steps
:
Optional
[
int
]
=
None
validation_steps
:
Optional
[
int
]
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment