Commit 96faaea8 authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 433133947
parent d6e3a60f
......@@ -15,7 +15,7 @@
"""All necessary imports for registration."""
# pylint: disable=unused-import
from official.common import registry_imports
from official.vision import registry_imports
from official.vision.beta.projects.simclr.configs import simclr
from official.vision.beta.projects.simclr.losses import contrastive_losses
from official.vision.beta.projects.simclr.modeling import simclr_model
......
......@@ -20,10 +20,10 @@ from typing import List, Tuple
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling.multitask import configs as multitask_configs
from official.vision.beta.configs import backbones
from official.vision.beta.configs import common
from official.vision.beta.projects.simclr.configs import simclr as simclr_configs
from official.vision.beta.projects.simclr.modeling import simclr_model
from official.vision.configs import backbones
from official.vision.configs import common
@dataclasses.dataclass
......
......@@ -21,9 +21,9 @@ from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.vision.beta.configs import backbones
from official.vision.beta.configs import common
from official.vision.beta.projects.simclr.modeling import simclr_model
from official.vision.configs import backbones
from official.vision.configs import common
@dataclasses.dataclass
......
......@@ -40,11 +40,11 @@ from typing import List
import tensorflow as tf
from official.vision.beta.dataloaders import decoder
from official.vision.beta.dataloaders import parser
from official.vision.beta.ops import preprocess_ops
from official.vision.beta.projects.simclr.dataloaders import preprocess_ops as simclr_preprocess_ops
from official.vision.beta.projects.simclr.modeling import simclr_model
from official.vision.dataloaders import decoder
from official.vision.dataloaders import parser
from official.vision.ops import preprocess_ops
class Decoder(decoder.Decoder):
......
......@@ -14,15 +14,15 @@
"""Multi-task image multi-taskSimCLR model definition."""
from typing import Dict, Text
from absl import logging
from absl import logging
import tensorflow as tf
from official.modeling.multitask import base_model
from official.vision.beta.modeling import backbones
from official.vision.beta.projects.simclr.configs import multitask_config as simclr_multitask_config
from official.vision.beta.projects.simclr.heads import simclr_head
from official.vision.beta.projects.simclr.modeling import simclr_model
from official.vision.modeling import backbones
PROJECTION_OUTPUT_KEY = 'projection_outputs'
SUPERVISED_OUTPUT_KEY = 'supervised_outputs'
......@@ -110,8 +110,9 @@ class SimCLRMTModel(base_model.MultiTaskBaseModel):
pretrained_items = dict(
backbone=self._backbone, projection_head=self._projection_head)
else:
assert ("Only 'backbone_projection' or 'backbone' can be used to "
'initialize the model.')
raise ValueError(
"Only 'backbone_projection' or 'backbone' can be used to "
'initialize the model.')
ckpt = tf.train.Checkpoint(**pretrained_items)
status = ckpt.read(ckpt_dir_or_file)
......
......@@ -14,13 +14,12 @@
"""Test for SimCLR model."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.modeling import backbones
from official.vision.beta.projects.simclr.heads import simclr_head
from official.vision.beta.projects.simclr.modeling import simclr_model
from official.vision.modeling import backbones
class SimCLRModelTest(parameterized.TestCase, tf.test.TestCase):
......
......@@ -36,12 +36,12 @@ from official.core import task_factory
from official.modeling import optimization
from official.modeling import performance
from official.modeling import tf_utils
from official.vision.beta.modeling import backbones
from official.vision.beta.projects.simclr.configs import simclr as exp_cfg
from official.vision.beta.projects.simclr.dataloaders import simclr_input
from official.vision.beta.projects.simclr.heads import simclr_head
from official.vision.beta.projects.simclr.losses import contrastive_losses
from official.vision.beta.projects.simclr.modeling import simclr_model
from official.vision.modeling import backbones
OptimizationConfig = optimization.OptimizationConfig
RuntimeConfig = config_definitions.RuntimeConfig
......@@ -157,7 +157,8 @@ class SimCLRPretrainTask(base_task.Task):
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
assert "Only 'all' or 'backbone' can be used to initialize the model."
raise ValueError(
"Only 'all' or 'backbone' can be used to initialize the model.")
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
......@@ -335,7 +336,8 @@ class SimCLRPretrainTask(base_task.Task):
def validation_step(self, inputs, model, metrics=None):
if self.task_config.model.supervised_head is None:
assert 'Skipping eval during pretraining without supervised head.'
raise ValueError(
'Skipping eval during pretraining without supervised head.')
features, labels = inputs
if self.task_config.evaluation.one_hot:
......@@ -467,7 +469,8 @@ class SimCLRFinetuneTask(base_task.Task):
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
assert "Only 'all' or 'backbone' can be used to initialize the model."
raise ValueError(
"Only 'all' or 'backbone' can be used to initialize the model.")
# If the checkpoint is from pretraining, reset the following parameters
model.backbone_trainable = self.task_config.model.backbone_trainable
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment