Commit 97e6a524 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 450097732
parent e83b444f
...@@ -151,24 +151,26 @@ def add_trainer( ...@@ -151,24 +151,26 @@ def add_trainer(
eval_batch_size: int, eval_batch_size: int,
learning_rate: float = 0.0001, learning_rate: float = 0.0001,
train_epochs: int = 50, train_epochs: int = 50,
num_train_examples: int = YT8M_TRAIN_EXAMPLES,
num_val_examples: int = YT8M_VAL_EXAMPLES,
): ):
"""Add and config a trainer to the experiment config.""" """Add and config a trainer to the experiment config."""
if YT8M_TRAIN_EXAMPLES <= 0: if num_train_examples <= 0:
raise ValueError('Wrong train dataset size {!r}'.format( raise ValueError('Wrong train dataset size {!r}'.format(
experiment.task.train_data)) experiment.task.train_data))
if YT8M_VAL_EXAMPLES <= 0: if num_val_examples <= 0:
raise ValueError('Wrong validation dataset size {!r}'.format( raise ValueError('Wrong validation dataset size {!r}'.format(
experiment.task.validation_data)) experiment.task.validation_data))
experiment.task.train_data.global_batch_size = train_batch_size experiment.task.train_data.global_batch_size = train_batch_size
experiment.task.validation_data.global_batch_size = eval_batch_size experiment.task.validation_data.global_batch_size = eval_batch_size
steps_per_epoch = YT8M_TRAIN_EXAMPLES // train_batch_size steps_per_epoch = num_train_examples // train_batch_size
steps_per_loop = 500 steps_per_loop = 500
experiment.trainer = cfg.TrainerConfig( experiment.trainer = cfg.TrainerConfig(
steps_per_loop=steps_per_loop, steps_per_loop=steps_per_loop,
summary_interval=steps_per_loop, summary_interval=steps_per_loop,
checkpoint_interval=steps_per_loop, checkpoint_interval=steps_per_loop,
train_steps=train_epochs * steps_per_epoch, train_steps=train_epochs * steps_per_epoch,
validation_steps=YT8M_VAL_EXAMPLES // eval_batch_size, validation_steps=num_val_examples // eval_batch_size,
validation_interval=steps_per_loop, validation_interval=steps_per_loop,
optimizer_config=optimization.OptimizationConfig({ optimizer_config=optimization.OptimizationConfig({
'optimizer': { 'optimizer': {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment