Unverified Commit 054d11f5 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'tensorflow:master' into panoptic-deeplab

parents 14f4c9df db3eead9
This contents of `beta` folder is going to be deprecated soon and most of the
content has been moved to[official/vision]
(https://github.com/tensorflow/models/tree/master/official/vision).
Contents of this `beta` folder is going to be deprecated soon and most of the
content has been moved to[official/vision](https://github.com/tensorflow/models/tree/master/official/vision).
......@@ -18,7 +18,7 @@ from typing import List, Optional
import tensorflow as tf
from official.vision.beta.modeling.layers import nn_blocks
from official.vision.modeling.layers import nn_blocks
def _apply_blocks(inputs, blocks):
......
......@@ -21,8 +21,8 @@ from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.modeling.layers import nn_blocks
from official.vision.beta.projects.centernet.modeling.layers import cn_nn_blocks
from official.vision.modeling.layers import nn_blocks
class HourglassBlockPyTorch(tf.keras.layers.Layer):
......
......@@ -25,9 +25,9 @@ from typing import Any, Mapping
import tensorflow as tf
from official.vision.beta.ops import box_ops
from official.vision.beta.projects.centernet.ops import loss_ops
from official.vision.beta.projects.centernet.ops import nms_ops
from official.vision.ops import box_ops
class CenterNetDetectionGenerator(tf.keras.layers.Layer):
......
......@@ -14,11 +14,11 @@
"""Functions used to load the ODAPI CenterNet checkpoint."""
from official.vision.beta.modeling.backbones import mobilenet
from official.vision.beta.modeling.layers import nn_blocks
from official.vision.beta.projects.centernet.modeling.layers import cn_nn_blocks
from official.vision.beta.projects.centernet.utils.checkpoints import config_classes
from official.vision.beta.projects.centernet.utils.checkpoints import config_data
from official.vision.modeling.backbones import mobilenet
from official.vision.modeling.layers import nn_blocks
Conv2DBNCFG = config_classes.Conv2DBNCFG
HeadConvCFG = config_classes.HeadConvCFG
......
......@@ -263,9 +263,11 @@ if tf_version.is_tf1():
def _check_feature_extractor_exists(feature_extractor_type):
feature_extractors = set().union(*FEATURE_EXTRACTOR_MAPS)
if feature_extractor_type not in feature_extractors:
raise ValueError('{} is not supported. See `model_builder.py` for features '
'extractors compatible with different versions of '
'Tensorflow'.format(feature_extractor_type))
tf_version_str = '2' if tf_version.is_tf2() else '1'
raise ValueError(
'{} is not supported for tf version {}. See `model_builder.py` for '
'features extractors compatible with different versions of '
'Tensorflow'.format(feature_extractor_type, tf_version_str))
def _build_ssd_feature_extractor(feature_extractor_config,
......@@ -1171,7 +1173,8 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
temporal_offset_params=temporal_offset_params,
use_depthwise=center_net_config.use_depthwise,
compute_heatmap_sparse=center_net_config.compute_heatmap_sparse,
non_max_suppression_fn=non_max_suppression_fn)
non_max_suppression_fn=non_max_suppression_fn,
output_prediction_dict=center_net_config.output_prediction_dict)
def _build_center_net_feature_extractor(feature_extractor_config, is_training):
......
......@@ -2676,7 +2676,8 @@ class CenterNetMetaArch(model.DetectionModel):
use_depthwise=False,
compute_heatmap_sparse=False,
non_max_suppression_fn=None,
unit_height_conv=False):
unit_height_conv=False,
output_prediction_dict=False):
"""Initializes a CenterNet model.
Args:
......@@ -2722,6 +2723,8 @@ class CenterNetMetaArch(model.DetectionModel):
non_max_suppression_fn: Optional Non Max Suppression function to apply.
unit_height_conv: If True, Conv2Ds in prediction heads have asymmetric
kernels with height=1.
output_prediction_dict: If true, combines all items from the dictionary
returned by predict() function into the output of postprocess().
"""
assert object_detection_params or keypoint_params_dict
# Shorten the name for convenience and better formatting.
......@@ -2747,6 +2750,7 @@ class CenterNetMetaArch(model.DetectionModel):
self._use_depthwise = use_depthwise
self._compute_heatmap_sparse = compute_heatmap_sparse
self._output_prediction_dict = output_prediction_dict
# subclasses may not implement the unit_height_conv arg, so only provide it
# as a kwarg if it is True.
......@@ -4110,6 +4114,10 @@ class CenterNetMetaArch(model.DetectionModel):
fields.DetectionResultFields.num_detections: num_detections,
}
if self._output_prediction_dict:
postprocess_dict.update(prediction_dict)
postprocess_dict['true_image_shapes'] = true_image_shapes
boxes_strided = None
if self._od_params:
boxes_strided = (
......@@ -4122,7 +4130,7 @@ class CenterNetMetaArch(model.DetectionModel):
postprocess_dict.update({
fields.DetectionResultFields.detection_boxes: boxes,
'detection_boxes_strided': boxes_strided
'detection_boxes_strided': boxes_strided,
})
if self._kp_params_dict:
......
......@@ -11,7 +11,7 @@ import "object_detection/protos/preprocessor.proto";
// Points" paper [1]
// [1]: https://arxiv.org/abs/1904.07850
// Next Id = 16
// Next Id = 26
message CenterNet {
// Number of classes to predict.
optional int32 num_classes = 1;
......@@ -504,6 +504,7 @@ message CenterNet {
// within error bars.
optional bool use_only_last_stage = 24 [default = false];
}
optional DeepMACMaskEstimation deepmac_mask_estimation = 14;
......@@ -514,6 +515,10 @@ message CenterNet {
// from CenterNet. Use this optional parameter to apply traditional non max
// suppression and score thresholding.
optional PostProcessing post_processing = 24;
// If set, dictionary items returned by the predict() function
// are appended to the output of postprocess().
optional bool output_prediction_dict = 25 [default = false];
}
enum LossNormalize {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment