Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
6da061c0
Commit
6da061c0
authored
Jun 16, 2020
by
Abdullah Rashwan
Committed by
A. Unique TensorFlower
Jun 16, 2020
Browse files
Internal change
PiperOrigin-RevId: 316797555
parent
e4f04456
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
4 deletions
+21
-4
official/modeling/hyperparams/config_definitions.py
official/modeling/hyperparams/config_definitions.py
+16
-0
official/nlp/transformer/misc.py
official/nlp/transformer/misc.py
+1
-1
official/utils/misc/keras_utils.py
official/utils/misc/keras_utils.py
+4
-3
No files found.
official/modeling/hyperparams/config_definitions.py
View file @
6da061c0
...
@@ -162,6 +162,21 @@ class CallbacksConfig(base_config.Config):
...
@@ -162,6 +162,21 @@ class CallbacksConfig(base_config.Config):
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
TrainerConfig
(
base_config
.
Config
):
class
TrainerConfig
(
base_config
.
Config
):
"""Configuration for trainer.
Attributes:
optimizer_config: optimizer config, it includes optimizer, learning rate,
and warmup schedule configs.
train_tf_while_loop: whether or not to use tf while loop.
train_tf_function: whether or not to use tf_function for training loop.
eval_tf_function: whether or not to use tf_function for eval.
steps_per_loop: number of steps per loop.
summary_interval: number of steps between each summary.
checkpoint_interval: number of steps between checkpoints.
max_to_keep: max checkpoints to keep.
continuous_eval_timeout: maximum number of seconds to wait between
checkpoints; if set to None, continuous eval will wait indefinitely.
"""
optimizer_config
:
OptimizationConfig
=
OptimizationConfig
()
optimizer_config
:
OptimizationConfig
=
OptimizationConfig
()
train_tf_while_loop
:
bool
=
True
train_tf_while_loop
:
bool
=
True
train_tf_function
:
bool
=
True
train_tf_function
:
bool
=
True
...
@@ -170,6 +185,7 @@ class TrainerConfig(base_config.Config):
...
@@ -170,6 +185,7 @@ class TrainerConfig(base_config.Config):
summary_interval
:
int
=
1000
summary_interval
:
int
=
1000
checkpoint_interval
:
int
=
1000
checkpoint_interval
:
int
=
1000
max_to_keep
:
int
=
5
max_to_keep
:
int
=
5
continuous_eval_timeout
:
Optional
[
int
]
=
None
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
...
official/nlp/transformer/misc.py
View file @
6da061c0
...
@@ -218,7 +218,7 @@ def get_callbacks():
...
@@ -218,7 +218,7 @@ def get_callbacks():
time_callback
=
keras_utils
.
TimeHistory
(
time_callback
=
keras_utils
.
TimeHistory
(
FLAGS
.
batch_size
,
FLAGS
.
batch_size
,
FLAGS
.
log_steps
,
FLAGS
.
log_steps
,
FLAGS
.
model_dir
if
FLAGS
.
enable_tensorboard
else
None
)
logdir
=
FLAGS
.
model_dir
if
FLAGS
.
enable_tensorboard
else
None
)
callbacks
.
append
(
time_callback
)
callbacks
.
append
(
time_callback
)
if
FLAGS
.
enable_tensorboard
:
if
FLAGS
.
enable_tensorboard
:
...
...
official/utils/misc/keras_utils.py
View file @
6da061c0
...
@@ -41,12 +41,13 @@ class BatchTimestamp(object):
...
@@ -41,12 +41,13 @@ class BatchTimestamp(object):
class
TimeHistory
(
tf
.
keras
.
callbacks
.
Callback
):
class
TimeHistory
(
tf
.
keras
.
callbacks
.
Callback
):
"""Callback for Keras models."""
"""Callback for Keras models."""
def
__init__
(
self
,
batch_size
,
log_steps
,
logdir
=
None
):
def
__init__
(
self
,
batch_size
,
log_steps
,
initial_step
=
0
,
logdir
=
None
):
"""Callback for logging performance.
"""Callback for logging performance.
Args:
Args:
batch_size: Total batch size.
batch_size: Total batch size.
log_steps: Interval of steps between logging of batch level stats.
log_steps: Interval of steps between logging of batch level stats.
initial_step: Optional, initial step.
logdir: Optional directory to write TensorBoard summaries.
logdir: Optional directory to write TensorBoard summaries.
"""
"""
# TODO(wcromar): remove this parameter and rely on `logs` parameter of
# TODO(wcromar): remove this parameter and rely on `logs` parameter of
...
@@ -54,8 +55,8 @@ class TimeHistory(tf.keras.callbacks.Callback):
...
@@ -54,8 +55,8 @@ class TimeHistory(tf.keras.callbacks.Callback):
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
super
(
TimeHistory
,
self
).
__init__
()
super
(
TimeHistory
,
self
).
__init__
()
self
.
log_steps
=
log_steps
self
.
log_steps
=
log_steps
self
.
last_log_step
=
0
self
.
last_log_step
=
initial_step
self
.
steps_before_epoch
=
0
self
.
steps_before_epoch
=
initial_step
self
.
steps_in_epoch
=
0
self
.
steps_in_epoch
=
0
self
.
start_time
=
None
self
.
start_time
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment