Unverified Commit 09d9656f authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'panoptic-segmentation' into panoptic-deeplab-modeling

parents ac671306 49a5706c
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
"""Config template to train Retinanet.""" """Config template to train Retinanet."""
from official.legacy.detection.configs import base_config
from official.modeling.hyperparams import params_dict from official.modeling.hyperparams import params_dict
from official.vision.detection.configs import base_config
# pylint: disable=line-too-long # pylint: disable=line-too-long
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
"""Config to train shapemask on COCO.""" """Config to train shapemask on COCO."""
from official.legacy.detection.configs import base_config
from official.modeling.hyperparams import params_dict from official.modeling.hyperparams import params_dict
from official.vision.detection.configs import base_config
SHAPEMASK_RESNET_FROZEN_VAR_PREFIX = r'(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/' SHAPEMASK_RESNET_FROZEN_VAR_PREFIX = r'(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
......
...@@ -21,8 +21,8 @@ from __future__ import print_function ...@@ -21,8 +21,8 @@ from __future__ import print_function
import collections import collections
import tensorflow as tf import tensorflow as tf
from official.legacy.detection.utils import box_utils
from official.vision.beta.ops import iou_similarity from official.vision.beta.ops import iou_similarity
from official.vision.detection.utils import box_utils
from official.vision.utils.object_detection import argmax_matcher from official.vision.utils.object_detection import argmax_matcher
from official.vision.utils.object_detection import balanced_positive_negative_sampler from official.vision.utils.object_detection import balanced_positive_negative_sampler
from official.vision.utils.object_detection import box_list from official.vision.utils.object_detection import box_list
......
...@@ -18,10 +18,10 @@ from __future__ import absolute_import ...@@ -18,10 +18,10 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from official.vision.detection.dataloader import maskrcnn_parser from official.legacy.detection.dataloader import maskrcnn_parser
from official.vision.detection.dataloader import olnmask_parser from official.legacy.detection.dataloader import olnmask_parser
from official.vision.detection.dataloader import retinanet_parser from official.legacy.detection.dataloader import retinanet_parser
from official.vision.detection.dataloader import shapemask_parser from official.legacy.detection.dataloader import shapemask_parser
def parser_generator(params, mode): def parser_generator(params, mode):
......
...@@ -17,13 +17,11 @@ ...@@ -17,13 +17,11 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from typing import Optional, Text
import tensorflow as tf import tensorflow as tf
from official.legacy.detection.dataloader import factory
from typing import Text, Optional from official.legacy.detection.dataloader import mode_keys as ModeKeys
from official.modeling.hyperparams import params_dict from official.modeling.hyperparams import params_dict
from official.vision.detection.dataloader import factory
from official.vision.detection.dataloader import mode_keys as ModeKeys
class InputFn(object): class InputFn(object):
......
...@@ -16,12 +16,12 @@ ...@@ -16,12 +16,12 @@
import tensorflow as tf import tensorflow as tf
from official.vision.detection.dataloader import anchor from official.legacy.detection.dataloader import anchor
from official.vision.detection.dataloader import mode_keys as ModeKeys from official.legacy.detection.dataloader import mode_keys as ModeKeys
from official.vision.detection.dataloader import tf_example_decoder from official.legacy.detection.dataloader import tf_example_decoder
from official.vision.detection.utils import box_utils from official.legacy.detection.utils import box_utils
from official.vision.detection.utils import dataloader_utils from official.legacy.detection.utils import dataloader_utils
from official.vision.detection.utils import input_utils from official.legacy.detection.utils import input_utils
class Parser(object): class Parser(object):
...@@ -345,13 +345,9 @@ class Parser(object): ...@@ -345,13 +345,9 @@ class Parser(object):
image = tf.cast(image, dtype=tf.bfloat16) image = tf.cast(image, dtype=tf.bfloat16)
# Compute Anchor boxes. # Compute Anchor boxes.
input_anchor = anchor.Anchor( _ = anchor.Anchor(self._min_level, self._max_level, self._num_scales,
self._min_level, self._aspect_ratios, self._anchor_size,
self._max_level, (image_height, image_width))
self._num_scales,
self._aspect_ratios,
self._anchor_size,
(image_height, image_width))
labels = { labels = {
'image_info': image_info, 'image_info': image_info,
......
...@@ -16,11 +16,11 @@ ...@@ -16,11 +16,11 @@
import tensorflow as tf import tensorflow as tf
from official.vision.detection.dataloader import anchor from official.legacy.detection.dataloader import anchor
from official.vision.detection.dataloader.maskrcnn_parser import Parser as MaskrcnnParser from official.legacy.detection.dataloader.maskrcnn_parser import Parser as MaskrcnnParser
from official.vision.detection.utils import box_utils from official.legacy.detection.utils import box_utils
from official.vision.detection.utils import class_utils from official.legacy.detection.utils import class_utils
from official.vision.detection.utils import input_utils from official.legacy.detection.utils import input_utils
class Parser(MaskrcnnParser): class Parser(MaskrcnnParser):
......
...@@ -23,11 +23,11 @@ Focal Loss for Dense Object Detection. arXiv:1708.02002 ...@@ -23,11 +23,11 @@ Focal Loss for Dense Object Detection. arXiv:1708.02002
import tensorflow as tf import tensorflow as tf
from official.vision.detection.dataloader import anchor from official.legacy.detection.dataloader import anchor
from official.vision.detection.dataloader import mode_keys as ModeKeys from official.legacy.detection.dataloader import mode_keys as ModeKeys
from official.vision.detection.dataloader import tf_example_decoder from official.legacy.detection.dataloader import tf_example_decoder
from official.vision.detection.utils import box_utils from official.legacy.detection.utils import box_utils
from official.vision.detection.utils import input_utils from official.legacy.detection.utils import input_utils
def process_source_id(source_id): def process_source_id(source_id):
......
...@@ -23,13 +23,13 @@ arXiv:1904.03239. ...@@ -23,13 +23,13 @@ arXiv:1904.03239.
""" """
import tensorflow as tf import tensorflow as tf
from official.vision.detection.dataloader import anchor from official.legacy.detection.dataloader import anchor
from official.vision.detection.dataloader import mode_keys as ModeKeys from official.legacy.detection.dataloader import mode_keys as ModeKeys
from official.vision.detection.dataloader import tf_example_decoder from official.legacy.detection.dataloader import tf_example_decoder
from official.vision.detection.utils import box_utils from official.legacy.detection.utils import box_utils
from official.vision.detection.utils import class_utils from official.legacy.detection.utils import class_utils
from official.vision.detection.utils import dataloader_utils from official.legacy.detection.utils import dataloader_utils
from official.vision.detection.utils import input_utils from official.legacy.detection.utils import input_utils
def pad_to_size(input_tensor, size): def pad_to_size(input_tensor, size):
......
...@@ -40,11 +40,12 @@ from pycocotools import cocoeval ...@@ -40,11 +40,12 @@ from pycocotools import cocoeval
import six import six
import tensorflow as tf import tensorflow as tf
from official.vision.detection.evaluation import coco_utils from official.legacy.detection.evaluation import coco_utils
from official.vision.detection.utils import class_utils from official.legacy.detection.utils import class_utils
class MetricWrapper(object): class MetricWrapper(object):
"""Metric Wrapper of the COCO evaluator."""
# This is only a wrapper for COCO metric and works on for numpy array. So it # This is only a wrapper for COCO metric and works on for numpy array. So it
# doesn't inherit from tf.keras.layers.Layer or tf.keras.metrics.Metric. # doesn't inherit from tf.keras.layers.Layer or tf.keras.metrics.Metric.
...@@ -52,6 +53,7 @@ class MetricWrapper(object): ...@@ -52,6 +53,7 @@ class MetricWrapper(object):
self._evaluator = evaluator self._evaluator = evaluator
def update_state(self, y_true, y_pred): def update_state(self, y_true, y_pred):
"""Update internal states."""
labels = tf.nest.map_structure(lambda x: x.numpy(), y_true) labels = tf.nest.map_structure(lambda x: x.numpy(), y_true)
outputs = tf.nest.map_structure(lambda x: x.numpy(), y_pred) outputs = tf.nest.map_structure(lambda x: x.numpy(), y_pred)
groundtruths = {} groundtruths = {}
......
...@@ -29,9 +29,9 @@ from pycocotools import mask as mask_api ...@@ -29,9 +29,9 @@ from pycocotools import mask as mask_api
import six import six
import tensorflow as tf import tensorflow as tf
from official.vision.detection.dataloader import tf_example_decoder from official.legacy.detection.dataloader import tf_example_decoder
from official.vision.detection.utils import box_utils from official.legacy.detection.utils import box_utils
from official.vision.detection.utils import mask_utils from official.legacy.detection.utils import mask_utils
class COCOWrapper(coco.COCO): class COCOWrapper(coco.COCO):
......
...@@ -18,7 +18,7 @@ from __future__ import absolute_import ...@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from official.vision.detection.evaluation import coco_evaluator from official.legacy.detection.evaluation import coco_evaluator
def evaluator_generator(params): def evaluator_generator(params):
......
...@@ -21,7 +21,7 @@ from __future__ import print_function ...@@ -21,7 +21,7 @@ from __future__ import print_function
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.vision.detection.executor import distributed_executor as executor from official.legacy.detection.executor import distributed_executor as executor
from official.vision.utils.object_detection import visualization_utils from official.vision.utils.object_detection import visualization_utils
......
...@@ -19,6 +19,7 @@ from __future__ import division ...@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
from typing import Optional, Dict, List, Text, Callable, Union, Iterator, Any
from absl import flags from absl import flags
from absl import logging from absl import logging
...@@ -27,10 +28,9 @@ import numpy as np ...@@ -27,10 +28,9 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
# pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported # pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
from typing import Optional, Dict, List, Text, Callable, Union, Iterator, Any from official.common import distribute_utils
from official.modeling.hyperparams import params_dict from official.modeling.hyperparams import params_dict
from official.utils import hyperparams_flags from official.utils import hyperparams_flags
from official.common import distribute_utils
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -23,16 +23,16 @@ from absl import logging ...@@ -23,16 +23,16 @@ from absl import logging
import tensorflow as tf import tensorflow as tf
from official.common import distribute_utils from official.common import distribute_utils
from official.legacy.detection.configs import factory as config_factory
from official.legacy.detection.dataloader import input_reader
from official.legacy.detection.dataloader import mode_keys as ModeKeys
from official.legacy.detection.executor import distributed_executor as executor
from official.legacy.detection.executor.detection_executor import DetectionDistributedExecutor
from official.legacy.detection.modeling import factory as model_factory
from official.modeling.hyperparams import params_dict from official.modeling.hyperparams import params_dict
from official.utils import hyperparams_flags from official.utils import hyperparams_flags
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.vision.detection.configs import factory as config_factory
from official.vision.detection.dataloader import input_reader
from official.vision.detection.dataloader import mode_keys as ModeKeys
from official.vision.detection.executor import distributed_executor as executor
from official.vision.detection.executor.detection_executor import DetectionDistributedExecutor
from official.vision.detection.modeling import factory as model_factory
hyperparams_flags.initialize_common_flags() hyperparams_flags.initialize_common_flags()
flags_core.define_log_steps() flags_core.define_log_steps()
...@@ -173,6 +173,7 @@ def run_executor(params, ...@@ -173,6 +173,7 @@ def run_executor(params,
def run(callbacks=None): def run(callbacks=None):
"""Runs the experiment."""
keras_utils.set_session_config(enable_xla=FLAGS.enable_xla) keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)
params = config_factory.config_generator(FLAGS.model) params = config_factory.config_generator(FLAGS.model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment