"...resnet50_tensorflow.git" did not exist on "1efe98bb8e8d98bbffc703a90d88df15fc2ce906"
Commit a75e870b authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 339917189
parent 08d3c799
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# ============================================================================== # ==============================================================================
"""Image classification configuration definition.""" """Image classification configuration definition."""
import os import os
from typing import List from typing import List, Optional
import dataclasses import dataclasses
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
...@@ -63,6 +63,8 @@ class ImageClassificationTask(cfg.TaskConfig): ...@@ -63,6 +63,8 @@ class ImageClassificationTask(cfg.TaskConfig):
validation_data: DataConfig = DataConfig(is_training=False) validation_data: DataConfig = DataConfig(is_training=False)
losses: Losses = Losses() losses: Losses = Losses()
gradient_clip_norm: float = 0.0 gradient_clip_norm: float = 0.0
init_checkpoint: Optional[str] = None
init_checkpoint_modules: str = 'all' # all or backbone
@exp_factory.register_config_factory('image_classification') @exp_factory.register_config_factory('image_classification')
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Image classification task definition.""" """Image classification task definition."""
from absl import logging
import tensorflow as tf import tensorflow as tf
from official.core import base_task from official.core import base_task
from official.core import input_reader from official.core import input_reader
...@@ -46,6 +47,30 @@ class ImageClassificationTask(base_task.Task): ...@@ -46,6 +47,30 @@ class ImageClassificationTask(base_task.Task):
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
return model return model
def initialize(self, model: tf.keras.Model):
"""Loading pretrained checkpoint."""
if not self.task_config.init_checkpoint:
return
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
# Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
assert "Only 'all' or 'backbone' can be used to initialize the model."
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def build_inputs(self, params, input_context=None): def build_inputs(self, params, input_context=None):
"""Builds classification input.""" """Builds classification input."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment