"vscode:/vscode.git/clone" did not exist on "8c1d1b08fa3bbff5b5e77d62fc0038418b520224"
Commit 2afac6d1 authored by Rino Lee's avatar Rino Lee Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 384478988
parent bac45446
...@@ -20,12 +20,57 @@ from typing import List ...@@ -20,12 +20,57 @@ from typing import List
import gin import gin
import orbit import orbit
import tensorflow as tf import tensorflow as tf
import tensorflow_model_optimization as tfmot
from official.core import base_trainer from official.core import base_trainer
from official.core import config_definitions from official.core import config_definitions
from official.modeling import optimization from official.modeling import optimization
class PruningActions:
"""Train action to updates pruning related information.
This action updates pruning steps at the end of trainig loop, and log
pruning metrics to tensorboard.
This action must be used when training a pruned model to avoid pruning error.
"""
def __init__(
self,
export_dir: str,
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
):
"""Initializes the instance.
Args:
export_dir: `str` for the export directory of the pruning summaries.
model: `tf.keras.Model` model instance used for training. This will be
used to assign a pruning step to each prunable weight.
optimizer: `tf.keras.optimizers.Optimizer` optimizer instance used for
training. This will be used to find the current training steps.
"""
self._optimizer = optimizer
self.update_pruning_step = tfmot.sparsity.keras.UpdatePruningStep()
self.update_pruning_step.set_model(model)
self.update_pruning_step.on_train_begin()
self.pruning_summaries = tfmot.sparsity.keras.PruningSummaries(
log_dir=export_dir)
model.optimizer = optimizer
self.pruning_summaries.set_model(model)
def __call__(self, output: orbit.runner.Output):
"""Update pruning step and log pruning summaries.
Args:
output: The train output to test.
"""
self.update_pruning_step.on_epoch_end(batch=None)
self.pruning_summaries.on_epoch_begin(epoch=None)
class EMACheckpointing: class EMACheckpointing:
"""Eval action to save checkpoint with average weights when EMA is used. """Eval action to save checkpoint with average weights when EMA is used.
...@@ -92,3 +137,20 @@ def get_eval_actions( ...@@ -92,3 +137,20 @@ def get_eval_actions(
max_to_keep=params.trainer.max_to_keep)) max_to_keep=params.trainer.max_to_keep))
return eval_actions return eval_actions
@gin.configurable
def get_train_actions(params: config_definitions.ExperimentConfig,
trainer: base_trainer.Trainer,
model_dir: str) -> List[orbit.Action]:
"""Gets train actions for TFM trainer."""
train_actions = []
# Adds pruning callback actions.
if hasattr(params.task, 'pruning'):
train_actions.append(
PruningActions(
export_dir=model_dir,
model=trainer.model,
optimizer=trainer.optimizer))
return train_actions
...@@ -105,6 +105,7 @@ def run_experiment( ...@@ -105,6 +105,7 @@ def run_experiment(
(save_summary) else None, (save_summary) else None,
summary_interval=params.trainer.summary_interval if summary_interval=params.trainer.summary_interval if
(save_summary) else None, (save_summary) else None,
train_actions=actions.get_train_actions(params, trainer, model_dir),
eval_actions=actions.get_eval_actions(params, trainer, model_dir)) eval_actions=actions.get_eval_actions(params, trainer, model_dir))
logging.info('Starts to execute mode: %s', mode) logging.info('Starts to execute mode: %s', mode)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment