Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
795a3f7d
Commit
795a3f7d
authored
Apr 15, 2020
by
A. Unique TensorFlower
Browse files
Fix TF2 3D Unet to standard model garden recommended style.
PiperOrigin-RevId: 306752053
parent
5741cef6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
13 deletions
+14
-13
official/nlp/nhnet/trainer.py
official/nlp/nhnet/trainer.py
+2
-13
official/utils/misc/keras_utils.py
official/utils/misc/keras_utils.py
+12
-0
No files found.
official/nlp/nhnet/trainer.py
View file @
795a3f7d
...
...
@@ -33,6 +33,7 @@ from official.nlp.nhnet import models
from
official.nlp.nhnet
import
optimizer
from
official.nlp.transformer
import
metrics
as
transformer_metrics
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
keras_utils
FLAGS
=
flags
.
FLAGS
...
...
@@ -122,18 +123,6 @@ class Trainer(tf.keras.Model):
}
class
SimpleCheckpoint
(
tf
.
keras
.
callbacks
.
Callback
):
"""Keras callback to save tf.train.Checkpoints."""
def
__init__
(
self
,
checkpoint_manager
):
super
(
SimpleCheckpoint
,
self
).
__init__
()
self
.
checkpoint_manager
=
checkpoint_manager
def
on_epoch_end
(
self
,
epoch
,
logs
=
None
):
step_counter
=
self
.
checkpoint_manager
.
_step_counter
.
numpy
()
self
.
checkpoint_manager
.
save
(
checkpoint_number
=
step_counter
)
def
train
(
params
,
strategy
,
dataset
=
None
):
"""Runs training."""
...
...
@@ -168,7 +157,7 @@ def train(params, strategy, dataset=None):
if
checkpoint_manager
.
restore_or_initialize
():
logging
.
info
(
"Training restored from the checkpoints in: %s"
,
FLAGS
.
model_dir
)
checkpoint_callback
=
SimpleCheckpoint
(
checkpoint_manager
)
checkpoint_callback
=
keras_utils
.
SimpleCheckpoint
(
checkpoint_manager
)
# Trains the model.
steps_per_epoch
=
min
(
FLAGS
.
train_steps
,
FLAGS
.
checkpoint_interval
)
...
...
official/utils/misc/keras_utils.py
View file @
795a3f7d
...
...
@@ -164,6 +164,18 @@ def get_profiler_callback(model_dir, profile_steps, enable_tensorboard,
return
ProfilerCallback
(
model_dir
,
start_step
,
stop_step
,
steps_per_epoch
)
class
SimpleCheckpoint
(
tf
.
keras
.
callbacks
.
Callback
):
"""Keras callback to save tf.train.Checkpoints."""
def
__init__
(
self
,
checkpoint_manager
):
super
(
SimpleCheckpoint
,
self
).
__init__
()
self
.
checkpoint_manager
=
checkpoint_manager
def
on_epoch_end
(
self
,
epoch
,
logs
=
None
):
step_counter
=
self
.
checkpoint_manager
.
_step_counter
.
numpy
()
# pylint: disable=protected-access
self
.
checkpoint_manager
.
save
(
checkpoint_number
=
step_counter
)
class
ProfilerCallback
(
tf
.
keras
.
callbacks
.
Callback
):
"""Save profiles in specified step range to log directory."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment