".github/git@developer.sourcefind.cn:change/sglang.git" did not exist on "6c856b4f3a4e63a25f5adc3388bf79ac2a6e4f72"
Commit 4c57e52b authored by Allen Wang, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 312391897
parent 02c7112e
@@ -341,6 +341,7 @@ def train_and_eval(
   metrics_map = _get_metrics(one_hot)
   metrics = [metrics_map[metric] for metric in params.train.metrics]
+  steps_per_loop = train_steps if params.train.set_epoch_loop else 1

   if one_hot:
     loss_obj = tf.keras.losses.CategoricalCrossentropy(
@@ -350,7 +351,7 @@ def train_and_eval(
   model.compile(optimizer=optimizer,
                 loss=loss_obj,
                 metrics=metrics,
-                experimental_steps_per_execution=params.train.steps_per_loop)
+                experimental_steps_per_execution=steps_per_loop)

   initial_epoch = 0
   if params.train.resume_checkpoint:
...
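For readers outside this codebase: `experimental_steps_per_execution` is the Keras `Model.compile` argument (TF 2.2/2.3 era; renamed `steps_per_execution` in later releases) that controls how many batches each internal `tf.function` call runs. A minimal sketch of the pattern the hunk above implements, using a placeholder model and placeholder values rather than this repository's `params`:

import tensorflow as tf

# Placeholder values standing in for params.train and the derived train_steps.
set_epoch_loop = True
train_steps = 313  # assumed steps per epoch, not taken from this repo's configs

# Run a whole epoch inside one tf.function call when the flag is on,
# otherwise fall back to the Keras default of one step per execution.
steps_per_loop = train_steps if set_epoch_loop else 1

model = tf.keras.Sequential([tf.keras.layers.Dense(10, activation='softmax')])
model.compile(
    optimizer=tf.keras.optimizers.SGD(),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=['accuracy'],
    # Renamed to `steps_per_execution` in TF releases after this commit.
    experimental_steps_per_execution=steps_per_loop)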
@@ -82,8 +82,10 @@ class TrainConfig(base_config.Config):
     callbacks: An instance of CallbacksConfig.
     metrics: An instance of MetricsConfig.
     tensorboard: An instance of TensorboardConfig.
-    steps_per_loop: The number of batches to run during each `tf.function`
-      call during training, which can increase training speed.
+    set_epoch_loop: Whether or not to set `experimental_steps_per_execution` to
+      equal the number of training steps in `model.compile`. This reduces the
+      number of callbacks run per epoch which significantly improves end-to-end
+      TPU training time.
   """
   resume_checkpoint: bool = None
@@ -93,7 +95,7 @@ class TrainConfig(base_config.Config):
   metrics: MetricsConfig = None
   tensorboard: TensorboardConfig = TensorboardConfig()
   time_history: TimeHistoryConfig = TimeHistoryConfig()
-  steps_per_loop: int = None
+  set_epoch_loop: bool = False

 @dataclasses.dataclass
...
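A stripped-down sketch of how the new boolean field slots into the dataclass config and gets consumed; the `TrainConfig` below is an illustrative stand-in with only the fields needed here, not the project's full class:

import dataclasses

@dataclasses.dataclass
class TrainConfig:
  """Illustrative stand-in for the project's TrainConfig."""
  resume_checkpoint: bool = False
  epochs: int = 90
  # When True, experimental_steps_per_execution is set to one epoch's worth
  # of steps in model.compile, so per-batch Keras callback overhead is paid
  # once per epoch instead of once per step.
  set_epoch_loop: bool = False

params_train = TrainConfig(epochs=500, set_epoch_loop=True)
train_steps = 313  # hypothetical steps per epoch
steps_per_loop = train_steps if params_train.set_epoch_loop else 1
assert steps_per_loop == 313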
@@ -55,7 +55,7 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
       time_history=base_configs.TimeHistoryConfig(log_steps=100),
       tensorboard=base_configs.TensorboardConfig(track_lr=True,
                                                  write_model_weights=False),
-      steps_per_loop=1)
+      set_epoch_loop=False)
   evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
       epochs_between_evals=1,
       steps=None)
@@ -88,7 +88,7 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
       time_history=base_configs.TimeHistoryConfig(log_steps=100),
       tensorboard=base_configs.TensorboardConfig(track_lr=True,
                                                  write_model_weights=False),
-      steps_per_loop=1)
+      set_epoch_loop=False)
   evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
       epochs_between_evals=1,
       steps=None)
...
@@ -47,7 +47,6 @@ model:
 train:
   resume_checkpoint: True
   epochs: 500
-  # 313 * batch_size = dataset_size
-  steps_per_loop: 313
+  set_epoch_loop: True
 evaluation:
   epochs_between_evals: 1
...
@@ -46,7 +46,6 @@ model:
 train:
   resume_checkpoint: True
   epochs: 500
-  # 313 * batch_size = dataset_size
-  steps_per_loop: 313
+  set_epoch_loop: True
 evaluation:
   epochs_between_evals: 1
@@ -42,7 +42,7 @@ model:
     decay: 0.9
     epsilon: 0.001
     moving_average_decay: 0.
-    lookahead: false
+    lookahead: False
   learning_rate:
     name: 'piecewise_constant_with_warmup'
   loss:
@@ -52,7 +52,6 @@ train:
   enable_checkpoint_and_export: True
   resume_checkpoint: True
   epochs: 90
-  # 313 * batch_size = dataset_size
-  steps_per_loop: 313
+  set_epoch_loop: True
 evaluation:
   epochs_between_evals: 1
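The removed `steps_per_loop: 313` entries were doing by hand what `set_epoch_loop: True` now does automatically: per the old comment `# 313 * batch_size = dataset_size`, 313 was one epoch's worth of batches at that config's batch size, so each loop already spanned a full epoch. A quick sketch of that arithmetic with illustrative numbers (the dataset size and batch size below are assumptions, not values read from these configs):

import math

dataset_size = 1_281_167  # assumed ImageNet-scale training set
batch_size = 4096         # assumed global batch size

steps_per_epoch = math.ceil(dataset_size / batch_size)
print(steps_per_epoch)  # 313 for these assumed values

# With set_epoch_loop: True the trainer derives this value itself and passes
# it to model.compile as experimental_steps_per_execution.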