Commit e2b671b5 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 410840853
parent 9c0d7874
...@@ -28,7 +28,7 @@ from official.core import config_definitions ...@@ -28,7 +28,7 @@ from official.core import config_definitions
from official.modeling import optimization from official.modeling import optimization
class PruningActions: class PruningAction:
"""Train action to updates pruning related information. """Train action to updates pruning related information.
This action updates pruning steps at the end of trainig loop, and log This action updates pruning steps at the end of trainig loop, and log
...@@ -66,7 +66,7 @@ class PruningActions: ...@@ -66,7 +66,7 @@ class PruningActions:
"""Update pruning step and log pruning summaries. """Update pruning step and log pruning summaries.
Args: Args:
output: The train output to test. output: The train output.
""" """
self.update_pruning_step.on_epoch_end(batch=None) self.update_pruning_step.on_epoch_end(batch=None)
self.pruning_summaries.on_epoch_begin(epoch=None) self.pruning_summaries.on_epoch_begin(epoch=None)
...@@ -81,8 +81,11 @@ class EMACheckpointing: ...@@ -81,8 +81,11 @@ class EMACheckpointing:
than training. than training.
""" """
def __init__(self, export_dir: str, optimizer: tf.keras.optimizers.Optimizer, def __init__(self,
checkpoint: tf.train.Checkpoint, max_to_keep: int = 1): export_dir: str,
optimizer: tf.keras.optimizers.Optimizer,
checkpoint: tf.train.Checkpoint,
max_to_keep: int = 1):
"""Initializes the instance. """Initializes the instance.
Args: Args:
...@@ -99,8 +102,7 @@ class EMACheckpointing: ...@@ -99,8 +102,7 @@ class EMACheckpointing:
'EMACheckpointing action') 'EMACheckpointing action')
export_dir = os.path.join(export_dir, 'ema_checkpoints') export_dir = os.path.join(export_dir, 'ema_checkpoints')
tf.io.gfile.makedirs( tf.io.gfile.makedirs(os.path.dirname(export_dir))
os.path.dirname(export_dir))
self._optimizer = optimizer self._optimizer = optimizer
self._checkpoint = checkpoint self._checkpoint = checkpoint
self._checkpoint_manager = tf.train.CheckpointManager( self._checkpoint_manager = tf.train.CheckpointManager(
...@@ -113,7 +115,7 @@ class EMACheckpointing: ...@@ -113,7 +115,7 @@ class EMACheckpointing:
"""Swaps model weights, and saves the checkpoint. """Swaps model weights, and saves the checkpoint.
Args: Args:
output: The train or eval output to test. output: The train or eval output.
""" """
self._optimizer.swap_weights() self._optimizer.swap_weights()
self._checkpoint_manager.save(checkpoint_number=self._optimizer.iterations) self._checkpoint_manager.save(checkpoint_number=self._optimizer.iterations)
...@@ -173,10 +175,9 @@ class RecoveryCondition: ...@@ -173,10 +175,9 @@ class RecoveryCondition:
@gin.configurable @gin.configurable
def get_eval_actions( def get_eval_actions(params: config_definitions.ExperimentConfig,
params: config_definitions.ExperimentConfig, trainer: base_trainer.Trainer,
trainer: base_trainer.Trainer, model_dir: str) -> List[orbit.Action]:
model_dir: str) -> List[orbit.Action]:
"""Gets eval actions for TFM trainer.""" """Gets eval actions for TFM trainer."""
eval_actions = [] eval_actions = []
# Adds ema checkpointing action to save the average weights under # Adds ema checkpointing action to save the average weights under
...@@ -202,7 +203,7 @@ def get_train_actions( ...@@ -202,7 +203,7 @@ def get_train_actions(
# Adds pruning callback actions. # Adds pruning callback actions.
if hasattr(params.task, 'pruning'): if hasattr(params.task, 'pruning'):
train_actions.append( train_actions.append(
PruningActions( PruningAction(
export_dir=model_dir, export_dir=model_dir,
model=trainer.model, model=trainer.model,
optimizer=trainer.optimizer)) optimizer=trainer.optimizer))
......
...@@ -27,14 +27,16 @@ from official.core import actions ...@@ -27,14 +27,16 @@ from official.core import actions
from official.modeling import optimization from official.modeling import optimization
class TestModel(tf.Module): class TestModel(tf.keras.Model):
def __init__(self): def __init__(self):
self.value = tf.Variable(0) super().__init__()
self.value = tf.Variable(0.0)
self.dense = tf.keras.layers.Dense(2)
_ = self.dense(tf.zeros((2, 2), tf.float32))
@tf.function(input_signature=[]) def call(self, x, training=None):
def __call__(self): return self.value + x
return self.value
class ActionsTest(tf.test.TestCase, parameterized.TestCase): class ActionsTest(tf.test.TestCase, parameterized.TestCase):
...@@ -43,7 +45,7 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase): ...@@ -43,7 +45,7 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.cloud_tpu_strategy, strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu, strategy_combinations.one_device_strategy,
],)) ],))
def test_ema_checkpointing(self, distribution): def test_ema_checkpointing(self, distribution):
with distribution.scope(): with distribution.scope():
...@@ -62,18 +64,25 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase): ...@@ -62,18 +64,25 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
model.value.assign(3) model.value.assign(3)
# Checks model.value is 3 # Checks model.value is 3
self.assertEqual(model(), 3) self.assertEqual(model(0.), 3)
ema_action = actions.EMACheckpointing(directory, optimizer, checkpoint) ema_action = actions.EMACheckpointing(directory, optimizer, checkpoint)
ema_action({}) ema_action({})
self.assertNotEmpty( self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(directory, 'ema_checkpoints'))) tf.io.gfile.glob(os.path.join(directory, 'ema_checkpoints')))
checkpoint.read(tf.train.latest_checkpoint( checkpoint.read(
os.path.join(directory, 'ema_checkpoints'))) tf.train.latest_checkpoint(
os.path.join(directory, 'ema_checkpoints')))
# Checks model.value is 0 after swapping. # Checks model.value is 0 after swapping.
self.assertEqual(model(), 0) self.assertEqual(model(0.), 0)
# Raises an error for a normal optimizer.
with self.assertRaisesRegex(ValueError,
'Optimizer has to be instance of.*'):
_ = actions.EMACheckpointing(directory, tf.keras.optimizers.SGD(),
checkpoint)
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
...@@ -102,6 +111,21 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase): ...@@ -102,6 +111,21 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
recover_condition(outputs) recover_condition(outputs)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.one_device_strategy,
],))
def test_pruning(self, distribution):
with distribution.scope():
directory = self.get_temp_dir()
model = TestModel()
optimizer = tf.keras.optimizers.SGD()
pruning = actions.PruningAction(directory, model, optimizer)
pruning({})
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment