Commit e2b671b5 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 410840853
parent 9c0d7874
......@@ -28,7 +28,7 @@ from official.core import config_definitions
from official.modeling import optimization
class PruningActions:
class PruningAction:
"""Train action to updates pruning related information.
This action updates pruning steps at the end of trainig loop, and log
......@@ -66,7 +66,7 @@ class PruningActions:
"""Update pruning step and log pruning summaries.
Args:
output: The train output to test.
output: The train output.
"""
self.update_pruning_step.on_epoch_end(batch=None)
self.pruning_summaries.on_epoch_begin(epoch=None)
......@@ -81,8 +81,11 @@ class EMACheckpointing:
than training.
"""
def __init__(self, export_dir: str, optimizer: tf.keras.optimizers.Optimizer,
checkpoint: tf.train.Checkpoint, max_to_keep: int = 1):
def __init__(self,
export_dir: str,
optimizer: tf.keras.optimizers.Optimizer,
checkpoint: tf.train.Checkpoint,
max_to_keep: int = 1):
"""Initializes the instance.
Args:
......@@ -99,8 +102,7 @@ class EMACheckpointing:
'EMACheckpointing action')
export_dir = os.path.join(export_dir, 'ema_checkpoints')
tf.io.gfile.makedirs(
os.path.dirname(export_dir))
tf.io.gfile.makedirs(os.path.dirname(export_dir))
self._optimizer = optimizer
self._checkpoint = checkpoint
self._checkpoint_manager = tf.train.CheckpointManager(
......@@ -113,7 +115,7 @@ class EMACheckpointing:
"""Swaps model weights, and saves the checkpoint.
Args:
output: The train or eval output to test.
output: The train or eval output.
"""
self._optimizer.swap_weights()
self._checkpoint_manager.save(checkpoint_number=self._optimizer.iterations)
......@@ -173,8 +175,7 @@ class RecoveryCondition:
@gin.configurable
def get_eval_actions(
params: config_definitions.ExperimentConfig,
def get_eval_actions(params: config_definitions.ExperimentConfig,
trainer: base_trainer.Trainer,
model_dir: str) -> List[orbit.Action]:
"""Gets eval actions for TFM trainer."""
......@@ -202,7 +203,7 @@ def get_train_actions(
# Adds pruning callback actions.
if hasattr(params.task, 'pruning'):
train_actions.append(
PruningActions(
PruningAction(
export_dir=model_dir,
model=trainer.model,
optimizer=trainer.optimizer))
......
......@@ -27,14 +27,16 @@ from official.core import actions
from official.modeling import optimization
class TestModel(tf.Module):
class TestModel(tf.keras.Model):
def __init__(self):
self.value = tf.Variable(0)
super().__init__()
self.value = tf.Variable(0.0)
self.dense = tf.keras.layers.Dense(2)
_ = self.dense(tf.zeros((2, 2), tf.float32))
@tf.function(input_signature=[])
def __call__(self):
return self.value
def call(self, x, training=None):
return self.value + x
class ActionsTest(tf.test.TestCase, parameterized.TestCase):
......@@ -43,7 +45,7 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
combinations.combine(
distribution=[
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.one_device_strategy,
],))
def test_ema_checkpointing(self, distribution):
with distribution.scope():
......@@ -62,18 +64,25 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
model.value.assign(3)
# Checks model.value is 3
self.assertEqual(model(), 3)
self.assertEqual(model(0.), 3)
ema_action = actions.EMACheckpointing(directory, optimizer, checkpoint)
ema_action({})
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(directory, 'ema_checkpoints')))
checkpoint.read(tf.train.latest_checkpoint(
checkpoint.read(
tf.train.latest_checkpoint(
os.path.join(directory, 'ema_checkpoints')))
# Checks model.value is 0 after swapping.
self.assertEqual(model(), 0)
self.assertEqual(model(0.), 0)
# Raises an error for a normal optimizer.
with self.assertRaisesRegex(ValueError,
'Optimizer has to be instance of.*'):
_ = actions.EMACheckpointing(directory, tf.keras.optimizers.SGD(),
checkpoint)
@combinations.generate(
combinations.combine(
......@@ -102,6 +111,21 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
with self.assertRaises(RuntimeError):
recover_condition(outputs)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.one_device_strategy,
],))
def test_pruning(self, distribution):
with distribution.scope():
directory = self.get_temp_dir()
model = TestModel()
optimizer = tf.keras.optimizers.SGD()
pruning = actions.PruningAction(directory, model, optimizer)
pruning({})
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment