Commit 96faaea8 authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 433133947
parent d6e3a60f
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""All necessary imports for registration.""" """All necessary imports for registration."""
# pylint: disable=unused-import # pylint: disable=unused-import
from official.common import registry_imports from official.vision import registry_imports
from official.vision.beta.projects.simclr.configs import simclr from official.vision.beta.projects.simclr.configs import simclr
from official.vision.beta.projects.simclr.losses import contrastive_losses from official.vision.beta.projects.simclr.losses import contrastive_losses
from official.vision.beta.projects.simclr.modeling import simclr_model from official.vision.beta.projects.simclr.modeling import simclr_model
......
...@@ -20,10 +20,10 @@ from typing import List, Tuple ...@@ -20,10 +20,10 @@ from typing import List, Tuple
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
from official.modeling.multitask import configs as multitask_configs from official.modeling.multitask import configs as multitask_configs
from official.vision.beta.configs import backbones
from official.vision.beta.configs import common
from official.vision.beta.projects.simclr.configs import simclr as simclr_configs from official.vision.beta.projects.simclr.configs import simclr as simclr_configs
from official.vision.beta.projects.simclr.modeling import simclr_model from official.vision.beta.projects.simclr.modeling import simclr_model
from official.vision.configs import backbones
from official.vision.configs import common
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -21,9 +21,9 @@ from official.core import config_definitions as cfg ...@@ -21,9 +21,9 @@ from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
from official.modeling import optimization from official.modeling import optimization
from official.vision.beta.configs import backbones
from official.vision.beta.configs import common
from official.vision.beta.projects.simclr.modeling import simclr_model from official.vision.beta.projects.simclr.modeling import simclr_model
from official.vision.configs import backbones
from official.vision.configs import common
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -40,11 +40,11 @@ from typing import List ...@@ -40,11 +40,11 @@ from typing import List
import tensorflow as tf import tensorflow as tf
from official.vision.beta.dataloaders import decoder
from official.vision.beta.dataloaders import parser
from official.vision.beta.ops import preprocess_ops
from official.vision.beta.projects.simclr.dataloaders import preprocess_ops as simclr_preprocess_ops from official.vision.beta.projects.simclr.dataloaders import preprocess_ops as simclr_preprocess_ops
from official.vision.beta.projects.simclr.modeling import simclr_model from official.vision.beta.projects.simclr.modeling import simclr_model
from official.vision.dataloaders import decoder
from official.vision.dataloaders import parser
from official.vision.ops import preprocess_ops
class Decoder(decoder.Decoder): class Decoder(decoder.Decoder):
......
...@@ -14,15 +14,15 @@ ...@@ -14,15 +14,15 @@
"""Multi-task image multi-taskSimCLR model definition.""" """Multi-task image multi-taskSimCLR model definition."""
from typing import Dict, Text from typing import Dict, Text
from absl import logging
from absl import logging
import tensorflow as tf import tensorflow as tf
from official.modeling.multitask import base_model from official.modeling.multitask import base_model
from official.vision.beta.modeling import backbones
from official.vision.beta.projects.simclr.configs import multitask_config as simclr_multitask_config from official.vision.beta.projects.simclr.configs import multitask_config as simclr_multitask_config
from official.vision.beta.projects.simclr.heads import simclr_head from official.vision.beta.projects.simclr.heads import simclr_head
from official.vision.beta.projects.simclr.modeling import simclr_model from official.vision.beta.projects.simclr.modeling import simclr_model
from official.vision.modeling import backbones
PROJECTION_OUTPUT_KEY = 'projection_outputs' PROJECTION_OUTPUT_KEY = 'projection_outputs'
SUPERVISED_OUTPUT_KEY = 'supervised_outputs' SUPERVISED_OUTPUT_KEY = 'supervised_outputs'
...@@ -110,8 +110,9 @@ class SimCLRMTModel(base_model.MultiTaskBaseModel): ...@@ -110,8 +110,9 @@ class SimCLRMTModel(base_model.MultiTaskBaseModel):
pretrained_items = dict( pretrained_items = dict(
backbone=self._backbone, projection_head=self._projection_head) backbone=self._backbone, projection_head=self._projection_head)
else: else:
assert ("Only 'backbone_projection' or 'backbone' can be used to " raise ValueError(
'initialize the model.') "Only 'backbone_projection' or 'backbone' can be used to "
'initialize the model.')
ckpt = tf.train.Checkpoint(**pretrained_items) ckpt = tf.train.Checkpoint(**pretrained_items)
status = ckpt.read(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
......
...@@ -14,13 +14,12 @@ ...@@ -14,13 +14,12 @@
"""Test for SimCLR model.""" """Test for SimCLR model."""
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from official.vision.beta.modeling import backbones
from official.vision.beta.projects.simclr.heads import simclr_head from official.vision.beta.projects.simclr.heads import simclr_head
from official.vision.beta.projects.simclr.modeling import simclr_model from official.vision.beta.projects.simclr.modeling import simclr_model
from official.vision.modeling import backbones
class SimCLRModelTest(parameterized.TestCase, tf.test.TestCase): class SimCLRModelTest(parameterized.TestCase, tf.test.TestCase):
......
...@@ -36,12 +36,12 @@ from official.core import task_factory ...@@ -36,12 +36,12 @@ from official.core import task_factory
from official.modeling import optimization from official.modeling import optimization
from official.modeling import performance from official.modeling import performance
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.modeling import backbones
from official.vision.beta.projects.simclr.configs import simclr as exp_cfg from official.vision.beta.projects.simclr.configs import simclr as exp_cfg
from official.vision.beta.projects.simclr.dataloaders import simclr_input from official.vision.beta.projects.simclr.dataloaders import simclr_input
from official.vision.beta.projects.simclr.heads import simclr_head from official.vision.beta.projects.simclr.heads import simclr_head
from official.vision.beta.projects.simclr.losses import contrastive_losses from official.vision.beta.projects.simclr.losses import contrastive_losses
from official.vision.beta.projects.simclr.modeling import simclr_model from official.vision.beta.projects.simclr.modeling import simclr_model
from official.vision.modeling import backbones
OptimizationConfig = optimization.OptimizationConfig OptimizationConfig = optimization.OptimizationConfig
RuntimeConfig = config_definitions.RuntimeConfig RuntimeConfig = config_definitions.RuntimeConfig
...@@ -157,7 +157,8 @@ class SimCLRPretrainTask(base_task.Task): ...@@ -157,7 +157,8 @@ class SimCLRPretrainTask(base_task.Task):
status = ckpt.read(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
else: else:
assert "Only 'all' or 'backbone' can be used to initialize the model." raise ValueError(
"Only 'all' or 'backbone' can be used to initialize the model.")
logging.info('Finished loading pretrained checkpoint from %s', logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file) ckpt_dir_or_file)
...@@ -335,7 +336,8 @@ class SimCLRPretrainTask(base_task.Task): ...@@ -335,7 +336,8 @@ class SimCLRPretrainTask(base_task.Task):
def validation_step(self, inputs, model, metrics=None): def validation_step(self, inputs, model, metrics=None):
if self.task_config.model.supervised_head is None: if self.task_config.model.supervised_head is None:
assert 'Skipping eval during pretraining without supervised head.' raise ValueError(
'Skipping eval during pretraining without supervised head.')
features, labels = inputs features, labels = inputs
if self.task_config.evaluation.one_hot: if self.task_config.evaluation.one_hot:
...@@ -467,7 +469,8 @@ class SimCLRFinetuneTask(base_task.Task): ...@@ -467,7 +469,8 @@ class SimCLRFinetuneTask(base_task.Task):
status = ckpt.read(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
else: else:
assert "Only 'all' or 'backbone' can be used to initialize the model." raise ValueError(
"Only 'all' or 'backbone' can be used to initialize the model.")
# If the checkpoint is from pretraining, reset the following parameters # If the checkpoint is from pretraining, reset the following parameters
model.backbone_trainable = self.task_config.model.backbone_trainable model.backbone_trainable = self.task_config.model.backbone_trainable
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment