Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
6e8f1284
Commit
6e8f1284
authored
Sep 16, 2020
by
A. Unique TensorFlower
Browse files
add TPUStrategy to the support list of BackupAndRestore callback.
PiperOrigin-RevId: 332042245
parent
5dcfd2c5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
11 deletions
+21
-11
official/modeling/hyperparams/config_definitions.py
official/modeling/hyperparams/config_definitions.py
+3
-0
official/vision/image_classification/callbacks.py
official/vision/image_classification/callbacks.py
+16
-10
official/vision/image_classification/classifier_trainer.py
official/vision/image_classification/classifier_trainer.py
+2
-1
No files found.
official/modeling/hyperparams/config_definitions.py
View file @
6e8f1284
...
...
@@ -165,12 +165,15 @@ class CallbacksConfig(base_config.Config):
Attributes:
enable_checkpoint_and_export: Whether or not to enable checkpoints as a
Callback. Defaults to True.
enable_backup_and_restore: Whether or not to add BackupAndRestore
callback. Defaults to True.
enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
Defaults to True.
enable_time_history: Whether or not to enable TimeHistory Callbacks.
Defaults to True.
"""
enable_checkpoint_and_export
:
bool
=
True
enable_backup_and_restore
:
bool
=
False
enable_tensorboard
:
bool
=
True
enable_time_history
:
bool
=
True
...
...
official/vision/image_classification/callbacks.py
View file @
6e8f1284
...
...
@@ -29,16 +29,18 @@ from official.modeling import optimization
from
official.utils.misc
import
keras_utils
def
get_callbacks
(
model_checkpoint
:
bool
=
True
,
include_tensorboard
:
bool
=
True
,
time_history
:
bool
=
True
,
track_lr
:
bool
=
True
,
write_model_weights
:
bool
=
True
,
apply_moving_average
:
bool
=
False
,
initial_step
:
int
=
0
,
batch_size
:
int
=
0
,
log_steps
:
int
=
0
,
model_dir
:
str
=
None
)
->
List
[
tf
.
keras
.
callbacks
.
Callback
]:
def
get_callbacks
(
model_checkpoint
:
bool
=
True
,
include_tensorboard
:
bool
=
True
,
time_history
:
bool
=
True
,
track_lr
:
bool
=
True
,
write_model_weights
:
bool
=
True
,
apply_moving_average
:
bool
=
False
,
initial_step
:
int
=
0
,
batch_size
:
int
=
0
,
log_steps
:
int
=
0
,
model_dir
:
str
=
None
,
backup_and_restore
:
bool
=
False
)
->
List
[
tf
.
keras
.
callbacks
.
Callback
]:
"""Get all callbacks."""
model_dir
=
model_dir
or
''
callbacks
=
[]
...
...
@@ -47,6 +49,10 @@ def get_callbacks(model_checkpoint: bool = True,
callbacks
.
append
(
tf
.
keras
.
callbacks
.
ModelCheckpoint
(
ckpt_full_path
,
save_weights_only
=
True
,
verbose
=
1
))
if
backup_and_restore
:
backup_dir
=
os
.
path
.
join
(
model_dir
,
'tmp'
)
callbacks
.
append
(
tf
.
keras
.
callbacks
.
experimental
.
BackupAndRestore
(
backup_dir
))
if
include_tensorboard
:
callbacks
.
append
(
CustomTensorBoard
(
...
...
official/vision/image_classification/classifier_trainer.py
View file @
6e8f1284
...
...
@@ -368,7 +368,8 @@ def train_and_eval(
initial_step
=
initial_epoch
*
train_steps
,
batch_size
=
train_builder
.
global_batch_size
,
log_steps
=
params
.
train
.
time_history
.
log_steps
,
model_dir
=
params
.
model_dir
)
model_dir
=
params
.
model_dir
,
backup_and_restore
=
params
.
train
.
callbacks
.
enable_backup_and_restore
)
serialize_config
(
params
=
params
,
model_dir
=
params
.
model_dir
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment