Commit 4c57e52b authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 312391897
parent 02c7112e
......@@ -341,6 +341,7 @@ def train_and_eval(
metrics_map = _get_metrics(one_hot)
metrics = [metrics_map[metric] for metric in params.train.metrics]
steps_per_loop = train_steps if params.train.set_epoch_loop else 1
if one_hot:
loss_obj = tf.keras.losses.CategoricalCrossentropy(
......@@ -350,7 +351,7 @@ def train_and_eval(
model.compile(optimizer=optimizer,
loss=loss_obj,
metrics=metrics,
experimental_steps_per_execution=params.train.steps_per_loop)
experimental_steps_per_execution=steps_per_loop)
initial_epoch = 0
if params.train.resume_checkpoint:
......
......@@ -82,8 +82,10 @@ class TrainConfig(base_config.Config):
callbacks: An instance of CallbacksConfig.
metrics: An instance of MetricsConfig.
tensorboard: An instance of TensorboardConfig.
steps_per_loop: The number of batches to run during each `tf.function`
call during training, which can increase training speed.
set_epoch_loop: Whether or not to set `experimental_steps_per_execution` to
equal the number of training steps in `model.compile`. This reduces the
number of callbacks run per epoch which significantly improves end-to-end
TPU training time.
"""
resume_checkpoint: bool = None
......@@ -93,7 +95,7 @@ class TrainConfig(base_config.Config):
metrics: MetricsConfig = None
tensorboard: TensorboardConfig = TensorboardConfig()
time_history: TimeHistoryConfig = TimeHistoryConfig()
steps_per_loop: int = None
set_epoch_loop: bool = False
@dataclasses.dataclass
......
......@@ -55,7 +55,7 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorboardConfig(track_lr=True,
write_model_weights=False),
steps_per_loop=1)
set_epoch_loop=False)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
epochs_between_evals=1,
steps=None)
......@@ -88,7 +88,7 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorboardConfig(track_lr=True,
write_model_weights=False),
steps_per_loop=1)
set_epoch_loop=False)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
epochs_between_evals=1,
steps=None)
......
......@@ -47,7 +47,6 @@ model:
train:
resume_checkpoint: True
epochs: 500
# 313 * batch_size = dataset_size
steps_per_loop: 313
set_epoch_loop: True
evaluation:
epochs_between_evals: 1
......@@ -46,7 +46,6 @@ model:
train:
resume_checkpoint: True
epochs: 500
# 313 * batch_size = dataset_size
steps_per_loop: 313
set_epoch_loop: True
evaluation:
epochs_between_evals: 1
......@@ -42,7 +42,7 @@ model:
decay: 0.9
epsilon: 0.001
moving_average_decay: 0.
lookahead: false
lookahead: False
learning_rate:
name: 'piecewise_constant_with_warmup'
loss:
......@@ -52,7 +52,6 @@ train:
enable_checkpoint_and_export: True
resume_checkpoint: True
epochs: 90
# 313 * batch_size = dataset_size
steps_per_loop: 313
set_epoch_loop: True
evaluation:
epochs_between_evals: 1
Markdown is supported
Attach a file by drag &amp; drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment