Commit 307a8194 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Use functions from util map.

PiperOrigin-RevId: 325621239
parent 69221551
...@@ -2583,6 +2583,9 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2583,6 +2583,9 @@ class CenterNetMetaArch(model.DetectionModel):
detections: a dictionary containing the following fields detections: a dictionary containing the following fields
detection_boxes - A tensor of shape [batch, max_detections, 4] detection_boxes - A tensor of shape [batch, max_detections, 4]
holding the predicted boxes. holding the predicted boxes.
detection_boxes_strided: A tensor of shape [batch_size, num_detections,
4] holding the predicted boxes in absolute coordinates of the
feature extractor's final layer output.
detection_scores: A tensor of shape [batch, max_detections] holding detection_scores: A tensor of shape [batch, max_detections] holding
the predicted score for each box. the predicted score for each box.
detection_classes: An integer tensor of shape [batch, max_detections] detection_classes: An integer tensor of shape [batch, max_detections]
...@@ -2626,6 +2629,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2626,6 +2629,7 @@ class CenterNetMetaArch(model.DetectionModel):
fields.DetectionResultFields.detection_scores: scores, fields.DetectionResultFields.detection_scores: scores,
fields.DetectionResultFields.detection_classes: classes, fields.DetectionResultFields.detection_classes: classes,
fields.DetectionResultFields.num_detections: num_detections, fields.DetectionResultFields.num_detections: num_detections,
'detection_boxes_strided': boxes_strided
} }
if self._kp_params_dict: if self._kp_params_dict:
......
...@@ -28,7 +28,6 @@ import tensorflow.compat.v2 as tf2 ...@@ -28,7 +28,6 @@ import tensorflow.compat.v2 as tf2
from object_detection import eval_util from object_detection import eval_util
from object_detection import inputs from object_detection import inputs
from object_detection import model_lib from object_detection import model_lib
from object_detection.builders import model_builder
from object_detection.builders import optimizer_builder from object_detection.builders import optimizer_builder
from object_detection.core import standard_fields as fields from object_detection.core import standard_fields as fields
from object_detection.protos import train_pb2 from object_detection.protos import train_pb2
...@@ -503,7 +502,7 @@ def train_loop( ...@@ -503,7 +502,7 @@ def train_loop(
# Build the model, optimizer, and training input # Build the model, optimizer, and training input
strategy = tf.compat.v2.distribute.get_strategy() strategy = tf.compat.v2.distribute.get_strategy()
with strategy.scope(): with strategy.scope():
detection_model = model_builder.build( detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
model_config=model_config, is_training=True) model_config=model_config, is_training=True)
def train_dataset_fn(input_context): def train_dataset_fn(input_context):
...@@ -939,7 +938,7 @@ def eval_continuously( ...@@ -939,7 +938,7 @@ def eval_continuously(
if kwargs['use_bfloat16']: if kwargs['use_bfloat16']:
tf.compat.v2.keras.mixed_precision.experimental.set_policy('mixed_bfloat16') tf.compat.v2.keras.mixed_precision.experimental.set_policy('mixed_bfloat16')
detection_model = model_builder.build( detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
model_config=model_config, is_training=True) model_config=model_config, is_training=True)
# Create the inputs. # Create the inputs.
......
...@@ -244,6 +244,42 @@ message CenterNet { ...@@ -244,6 +244,42 @@ message CenterNet {
optional ClassificationLoss classification_loss = 5; optional ClassificationLoss classification_loss = 5;
} }
optional TrackEstimation track_estimation_task = 10; optional TrackEstimation track_estimation_task = 10;
// BEGIN GOOGLE-INTERNAL
// Experimental Occupancy network head, use with caution.
message OccupancyNetMaskPrediction {
// The loss used for penalizing mask predictions.
optional ClassificationLoss classification_loss = 1;
// Number of points to sample within a box while training occupancy net.
optional int32 num_samples = 2 [default = 1000];
// The dimension of the occupancy embedding.
optional int32 dim = 3 [default = 256];
// Weight of occupancy embedding loss.
optional float task_loss_weight = 4 [default = 1.0];
// The stride in pixels at test time when computing the mask. THis is
// useful is computing the full mask is too expensive.
optional int32 mask_stride = 5 [default = 1];
// If set, concatenate the occupancy embedding features to (x, y)
// coordinates before feeding it to the occupancy network head.
optional bool concat_features = 6 [default = true];
// If set to a positive value, defines the length to which the embedding
// is clipped before concatenating to the (x, y) coordinates when
// concat_features=true.
optional int32 concat_clip = 7 [default = -1];
// The probability threshold to apply for masks to output a binary mask.
optional float mask_prob_threshold = 8 [default = 0.5];
}
optional OccupancyNetMaskPrediction occupancy_net_mask_prediction = 11;
// EBD GOOGLE-INTERNAL
} }
message CenterNetFeatureExtractor { message CenterNetFeatureExtractor {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment