Unverified Commit cdd61f61 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'panoptic-segmentation' into panoptic-deeplab-modeling

parents 0225b135 a9322830
...@@ -13,20 +13,21 @@ ...@@ -13,20 +13,21 @@
# limitations under the License. # limitations under the License.
"""Panoptic MaskRCNN task definition.""" """Panoptic MaskRCNN task definition."""
from typing import Any, List, Mapping, Optional, Tuple, Dict from typing import Any, Dict, List, Mapping, Optional, Tuple
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.common import dataset_fn from official.common import dataset_fn
from official.core import task_factory from official.core import task_factory
from official.vision.beta.dataloaders import input_reader_factory
from official.vision.beta.evaluation import panoptic_quality_evaluator
from official.vision.beta.evaluation import segmentation_metrics
from official.vision.beta.losses import segmentation_losses
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as exp_cfg from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as exp_cfg
from official.vision.beta.projects.panoptic_maskrcnn.dataloaders import panoptic_maskrcnn_input from official.vision.beta.projects.panoptic_maskrcnn.dataloaders import panoptic_maskrcnn_input
from official.vision.beta.projects.panoptic_maskrcnn.modeling import factory from official.vision.beta.projects.panoptic_maskrcnn.modeling import factory
from official.vision.beta.tasks import maskrcnn from official.vision.dataloaders import input_reader_factory
from official.vision.evaluation import panoptic_quality_evaluator
from official.vision.evaluation import segmentation_metrics
from official.vision.losses import segmentation_losses
from official.vision.tasks import maskrcnn
@task_factory.register_task_cls(exp_cfg.PanopticMaskRCNNTask) @task_factory.register_task_cls(exp_cfg.PanopticMaskRCNNTask)
......
...@@ -18,10 +18,10 @@ import os ...@@ -18,10 +18,10 @@ import os
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from official.vision.beta.configs import decoders as decoder_cfg
from official.vision.beta.configs import semantic_segmentation as segmentation_cfg
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as cfg from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as cfg
from official.vision.beta.projects.panoptic_maskrcnn.tasks import panoptic_maskrcnn from official.vision.beta.projects.panoptic_maskrcnn.tasks import panoptic_maskrcnn
from official.vision.configs import decoders as decoder_cfg
from official.vision.configs import semantic_segmentation as segmentation_cfg
class PanopticMaskRCNNTaskTest(tf.test.TestCase, parameterized.TestCase): class PanopticMaskRCNNTaskTest(tf.test.TestCase, parameterized.TestCase):
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
from absl import app from absl import app
from official.common import flags as tfm_flags from official.common import flags as tfm_flags
from official.vision.beta import train from official.vision import train
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as cfg # pylint: disable=unused-import from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as cfg # pylint: disable=unused-import
from official.vision.beta.projects.panoptic_maskrcnn.tasks import panoptic_maskrcnn as task # pylint: disable=unused-import from official.vision.beta.projects.panoptic_maskrcnn.tasks import panoptic_maskrcnn as task # pylint: disable=unused-import
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""All necessary imports for registration.""" """All necessary imports for registration."""
# pylint: disable=unused-import # pylint: disable=unused-import
from official.common import registry_imports from official.vision import registry_imports
from official.vision.beta.projects.simclr.configs import simclr from official.vision.beta.projects.simclr.configs import simclr
from official.vision.beta.projects.simclr.losses import contrastive_losses from official.vision.beta.projects.simclr.losses import contrastive_losses
from official.vision.beta.projects.simclr.modeling import simclr_model from official.vision.beta.projects.simclr.modeling import simclr_model
......
...@@ -20,10 +20,10 @@ from typing import List, Tuple ...@@ -20,10 +20,10 @@ from typing import List, Tuple
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
from official.modeling.multitask import configs as multitask_configs from official.modeling.multitask import configs as multitask_configs
from official.vision.beta.configs import backbones
from official.vision.beta.configs import common
from official.vision.beta.projects.simclr.configs import simclr as simclr_configs from official.vision.beta.projects.simclr.configs import simclr as simclr_configs
from official.vision.beta.projects.simclr.modeling import simclr_model from official.vision.beta.projects.simclr.modeling import simclr_model
from official.vision.configs import backbones
from official.vision.configs import common
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -21,9 +21,9 @@ from official.core import config_definitions as cfg ...@@ -21,9 +21,9 @@ from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
from official.modeling import optimization from official.modeling import optimization
from official.vision.beta.configs import backbones
from official.vision.beta.configs import common
from official.vision.beta.projects.simclr.modeling import simclr_model from official.vision.beta.projects.simclr.modeling import simclr_model
from official.vision.configs import backbones
from official.vision.configs import common
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -40,11 +40,11 @@ from typing import List ...@@ -40,11 +40,11 @@ from typing import List
import tensorflow as tf import tensorflow as tf
from official.vision.beta.dataloaders import decoder
from official.vision.beta.dataloaders import parser
from official.vision.beta.ops import preprocess_ops
from official.vision.beta.projects.simclr.dataloaders import preprocess_ops as simclr_preprocess_ops from official.vision.beta.projects.simclr.dataloaders import preprocess_ops as simclr_preprocess_ops
from official.vision.beta.projects.simclr.modeling import simclr_model from official.vision.beta.projects.simclr.modeling import simclr_model
from official.vision.dataloaders import decoder
from official.vision.dataloaders import parser
from official.vision.ops import preprocess_ops
class Decoder(decoder.Decoder): class Decoder(decoder.Decoder):
......
...@@ -14,15 +14,15 @@ ...@@ -14,15 +14,15 @@
"""Multi-task image multi-taskSimCLR model definition.""" """Multi-task image multi-taskSimCLR model definition."""
from typing import Dict, Text from typing import Dict, Text
from absl import logging
from absl import logging
import tensorflow as tf import tensorflow as tf
from official.modeling.multitask import base_model from official.modeling.multitask import base_model
from official.vision.beta.modeling import backbones
from official.vision.beta.projects.simclr.configs import multitask_config as simclr_multitask_config from official.vision.beta.projects.simclr.configs import multitask_config as simclr_multitask_config
from official.vision.beta.projects.simclr.heads import simclr_head from official.vision.beta.projects.simclr.heads import simclr_head
from official.vision.beta.projects.simclr.modeling import simclr_model from official.vision.beta.projects.simclr.modeling import simclr_model
from official.vision.modeling import backbones
PROJECTION_OUTPUT_KEY = 'projection_outputs' PROJECTION_OUTPUT_KEY = 'projection_outputs'
SUPERVISED_OUTPUT_KEY = 'supervised_outputs' SUPERVISED_OUTPUT_KEY = 'supervised_outputs'
...@@ -110,8 +110,9 @@ class SimCLRMTModel(base_model.MultiTaskBaseModel): ...@@ -110,8 +110,9 @@ class SimCLRMTModel(base_model.MultiTaskBaseModel):
pretrained_items = dict( pretrained_items = dict(
backbone=self._backbone, projection_head=self._projection_head) backbone=self._backbone, projection_head=self._projection_head)
else: else:
assert ("Only 'backbone_projection' or 'backbone' can be used to " raise ValueError(
'initialize the model.') "Only 'backbone_projection' or 'backbone' can be used to "
'initialize the model.')
ckpt = tf.train.Checkpoint(**pretrained_items) ckpt = tf.train.Checkpoint(**pretrained_items)
status = ckpt.read(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
......
...@@ -14,13 +14,12 @@ ...@@ -14,13 +14,12 @@
"""Test for SimCLR model.""" """Test for SimCLR model."""
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from official.vision.beta.modeling import backbones
from official.vision.beta.projects.simclr.heads import simclr_head from official.vision.beta.projects.simclr.heads import simclr_head
from official.vision.beta.projects.simclr.modeling import simclr_model from official.vision.beta.projects.simclr.modeling import simclr_model
from official.vision.modeling import backbones
class SimCLRModelTest(parameterized.TestCase, tf.test.TestCase): class SimCLRModelTest(parameterized.TestCase, tf.test.TestCase):
......
...@@ -36,12 +36,12 @@ from official.core import task_factory ...@@ -36,12 +36,12 @@ from official.core import task_factory
from official.modeling import optimization from official.modeling import optimization
from official.modeling import performance from official.modeling import performance
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.modeling import backbones
from official.vision.beta.projects.simclr.configs import simclr as exp_cfg from official.vision.beta.projects.simclr.configs import simclr as exp_cfg
from official.vision.beta.projects.simclr.dataloaders import simclr_input from official.vision.beta.projects.simclr.dataloaders import simclr_input
from official.vision.beta.projects.simclr.heads import simclr_head from official.vision.beta.projects.simclr.heads import simclr_head
from official.vision.beta.projects.simclr.losses import contrastive_losses from official.vision.beta.projects.simclr.losses import contrastive_losses
from official.vision.beta.projects.simclr.modeling import simclr_model from official.vision.beta.projects.simclr.modeling import simclr_model
from official.vision.modeling import backbones
OptimizationConfig = optimization.OptimizationConfig OptimizationConfig = optimization.OptimizationConfig
RuntimeConfig = config_definitions.RuntimeConfig RuntimeConfig = config_definitions.RuntimeConfig
...@@ -157,7 +157,8 @@ class SimCLRPretrainTask(base_task.Task): ...@@ -157,7 +157,8 @@ class SimCLRPretrainTask(base_task.Task):
status = ckpt.read(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
else: else:
assert "Only 'all' or 'backbone' can be used to initialize the model." raise ValueError(
"Only 'all' or 'backbone' can be used to initialize the model.")
logging.info('Finished loading pretrained checkpoint from %s', logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file) ckpt_dir_or_file)
...@@ -335,7 +336,8 @@ class SimCLRPretrainTask(base_task.Task): ...@@ -335,7 +336,8 @@ class SimCLRPretrainTask(base_task.Task):
def validation_step(self, inputs, model, metrics=None): def validation_step(self, inputs, model, metrics=None):
if self.task_config.model.supervised_head is None: if self.task_config.model.supervised_head is None:
assert 'Skipping eval during pretraining without supervised head.' raise ValueError(
'Skipping eval during pretraining without supervised head.')
features, labels = inputs features, labels = inputs
if self.task_config.evaluation.one_hot: if self.task_config.evaluation.one_hot:
...@@ -467,7 +469,8 @@ class SimCLRFinetuneTask(base_task.Task): ...@@ -467,7 +469,8 @@ class SimCLRFinetuneTask(base_task.Task):
status = ckpt.read(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
else: else:
assert "Only 'all' or 'backbone' can be used to initialize the model." raise ValueError(
"Only 'all' or 'backbone' can be used to initialize the model.")
# If the checkpoint is from pretraining, reset the following parameters # If the checkpoint is from pretraining, reset the following parameters
model.backbone_trainable = self.task_config.model.backbone_trainable model.backbone_trainable = self.task_config.model.backbone_trainable
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...@@ -714,7 +714,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): ...@@ -714,7 +714,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
'use_depthwise': self._use_depthwise, 'use_depthwise': self._use_depthwise,
'use_residual': self._use_residual, 'use_residual': self._use_residual,
'norm_momentum': self._norm_momentum, 'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon 'norm_epsilon': self._norm_epsilon,
'output_intermediate_endpoints': self._output_intermediate_endpoints
} }
base_config = super(InvertedBottleneckBlock, self).get_config() base_config = super(InvertedBottleneckBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment