"projects/web/vscode:/vscode.git/clone" did not exist on "3cc3f754111854aa36736173b1d3adc0d74b2208"
Commit cb62fdcc authored by A. Unique TensorFlower, committed by TF Object Detection Team

* Replaced the deprecated `tf.to_float`/`tf.to_int32` calls in `center_net_meta_arch.py` with `tf.cast()` (following the deprecation warning).
* Added a `maximum_normalized_coordinate` argument to the CenterNet code to control the maximum coordinate value still treated as normalized, allowing for the wrap-around special case in RxR data.

PiperOrigin-RevId: 391319766
parent 41aafda9
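
For context, the deprecation fix is mechanical. A minimal sketch of the pattern (tensor values are illustrative only):

```python
import tensorflow as tf

x = tf.constant([1, 2, 3])

# Deprecated TF1 helpers (emit warnings; removed from the TF2 namespace):
#   tf.to_float(x)      ->  tf.cast(x, tf.float32)
#   tf.to_int32(x > 1)  ->  tf.cast(x > 1, tf.int32)
y = tf.cast(x, tf.float32)    # [1.0, 2.0, 3.0]
n = tf.cast(x > 1, tf.int32)  # [0, 1, 1]
```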
@@ -961,7 +961,8 @@ class CenterNetCenterHeatmapTargetAssigner(object):
                                          width,
                                          gt_boxes_list,
                                          gt_classes_list,
-                                         gt_weights_list=None):
+                                         gt_weights_list=None,
+                                         maximum_normalized_coordinate=1.1):
     """Computes the object center heatmap target.
 
     Args:
@@ -977,6 +978,9 @@ class CenterNetCenterHeatmapTargetAssigner(object):
         in the gt_boxes_list.
       gt_weights_list: A list of float tensors with shape [num_boxes]
         representing the weight of each groundtruth detection box.
+      maximum_normalized_coordinate: Maximum coordinate value to be considered
+        as normalized; defaults to 1.1. This is used to check bounds when
+        converting normalized coordinates to absolute coordinates.
 
     Returns:
       heatmap: A Tensor of size [batch_size, output_height, output_width,
@@ -1002,7 +1006,8 @@ class CenterNetCenterHeatmapTargetAssigner(object):
     boxes = box_list_ops.to_absolute_coordinates(
         boxes,
         tf.maximum(height // self._stride, 1),
-        tf.maximum(width // self._stride, 1))
+        tf.maximum(width // self._stride, 1),
+        maximum_normalized_coordinate=maximum_normalized_coordinate)
 
     # Get the box center coordinates. Each returned tensor has the shape of
     # [num_instances].
     (y_center, x_center, boxes_height,
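The reason the argument is threaded through to `box_list_ops.to_absolute_coordinates` is its range check: coordinates above the threshold are rejected as not normalized. Below is a simplified sketch of that check (not the actual `box_list_ops` implementation), assuming boxes are `[y_min, x_min, y_max, x_max]` float tensors in normalized coordinates:

```python
import tensorflow as tf

def to_absolute_coordinates_sketch(boxes, height, width,
                                   maximum_normalized_coordinate=1.1):
  """Simplified sketch: scale normalized boxes to pixels after a range check.

  Raising maximum_normalized_coordinate above 1.0 tolerates coordinates
  slightly past the image edge, e.g. boxes that wrap around the seam of an
  RxR panorama.
  """
  box_maximum = tf.reduce_max(boxes)
  tf.debugging.assert_less_equal(
      box_maximum, maximum_normalized_coordinate,
      message='boxes are not in normalized coordinates')
  scale = tf.cast(tf.stack([height, width, height, width]), tf.float32)
  return boxes * scale
```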
@@ -2714,7 +2714,8 @@ class CenterNetMetaArch(model.DetectionModel):
     return target_assigners
 
   def _compute_object_center_loss(self, input_height, input_width,
-                                  object_center_predictions, per_pixel_weights):
+                                  object_center_predictions, per_pixel_weights,
+                                  maximum_normalized_coordinate=1.1):
     """Computes the object center loss.
 
     Args:
@@ -2726,6 +2727,9 @@ class CenterNetMetaArch(model.DetectionModel):
       per_pixel_weights: A float tensor of shape [batch_size,
         out_height * out_width, 1] with 1s in locations where the spatial
         coordinates fall within the height and width in true_image_shapes.
+      maximum_normalized_coordinate: Maximum coordinate value to be considered
+        as normalized; defaults to 1.1. This is used to check bounds when
+        converting normalized coordinates to absolute coordinates.
 
     Returns:
       A float scalar tensor representing the object center loss per instance.
@@ -2752,7 +2756,8 @@ class CenterNetMetaArch(model.DetectionModel):
           width=input_width,
           gt_classes_list=gt_classes_list,
           gt_keypoints_list=gt_keypoints_list,
-          gt_weights_list=gt_weights_list)
+          gt_weights_list=gt_weights_list,
+          maximum_normalized_coordinate=maximum_normalized_coordinate)
     else:
       gt_boxes_list = self.groundtruth_lists(fields.BoxListFields.boxes)
       heatmap_targets = assigner.assign_center_targets_from_boxes(
@@ -2760,7 +2765,8 @@ class CenterNetMetaArch(model.DetectionModel):
           width=input_width,
           gt_boxes_list=gt_boxes_list,
           gt_classes_list=gt_classes_list,
-          gt_weights_list=gt_weights_list)
+          gt_weights_list=gt_weights_list,
+          maximum_normalized_coordinate=maximum_normalized_coordinate)
 
     flattened_heatmap_targets = _flatten_spatial_dimensions(heatmap_targets)
     num_boxes = _to_float32(get_num_instances_from_weights(gt_weights_list))
@@ -3577,7 +3583,9 @@ class CenterNetMetaArch(model.DetectionModel):
     self._batched_prediction_tensor_names = predictions.keys()
     return predictions
 
-  def loss(self, prediction_dict, true_image_shapes, scope=None):
+  def loss(
+      self, prediction_dict, true_image_shapes, scope=None,
+      maximum_normalized_coordinate=1.1):
     """Computes scalar loss tensors with respect to provided groundtruth.
 
     This function implements the various CenterNet losses.
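With the argument exposed on `loss()`, callers can widen the bound when their groundtruth legitimately exceeds [0, 1]. A hypothetical training-step snippet, where `detection_model` is an already-built `CenterNetMetaArch` instance and `prediction_dict`/`true_image_shapes` come from the usual predict step:

```python
# Hypothetical caller; `detection_model`, `prediction_dict`, and
# `true_image_shapes` are assumed to exist from the usual train loop.
losses_dict = detection_model.loss(
    prediction_dict,
    true_image_shapes,
    # Tolerate normalized coordinates slightly past 1.0 (the wrap-around
    # case for RxR data mentioned in the commit message).
    maximum_normalized_coordinate=1.1)
```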
@@ -3589,6 +3597,9 @@ class CenterNetMetaArch(model.DetectionModel):
         the form [height, width, channels] indicating the shapes of true images
         in the resized images, as resized images can be padded with zeros.
       scope: Optional scope name.
+      maximum_normalized_coordinate: Maximum coordinate value to be considered
+        as normalized; defaults to 1.1. This is used to check bounds when
+        converting normalized coordinates to absolute coordinates.
 
     Returns:
       A dictionary mapping the keys [
@@ -3616,7 +3627,7 @@ class CenterNetMetaArch(model.DetectionModel):
     # TODO(vighneshb) Explore whether using floor here is safe.
     output_true_image_shapes = tf.ceil(
-        tf.to_float(true_image_shapes) / self._stride)
+        tf.cast(true_image_shapes, tf.float32) / self._stride)
     valid_anchor_weights = get_valid_anchor_weights_in_flattened_image(
         output_true_image_shapes, output_height, output_width)
     valid_anchor_weights = tf.expand_dims(valid_anchor_weights, 2)
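A quick sanity check of the rewritten shape computation, with an assumed stride of 4 and a hypothetical 510x638 true image shape:

```python
import tensorflow as tf

stride = 4
true_image_shapes = tf.constant([[510, 638, 3]])

# ceil(510 / 4) = 128 and ceil(638 / 4) = 160: a partially covered output
# cell still counts as valid, which is why the code uses ceil (see TODO).
output_true_image_shapes = tf.math.ceil(
    tf.cast(true_image_shapes, tf.float32) / stride)
# -> [[128., 160., 1.]]
```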
@@ -3625,7 +3636,8 @@ class CenterNetMetaArch(model.DetectionModel):
         object_center_predictions=prediction_dict[OBJECT_CENTER],
         input_height=input_height,
         input_width=input_width,
-        per_pixel_weights=valid_anchor_weights)
+        per_pixel_weights=valid_anchor_weights,
+        maximum_normalized_coordinate=maximum_normalized_coordinate)
 
     losses = {
         OBJECT_CENTER:
             self._center_params.object_center_loss_weight * object_center_loss
@@ -3765,8 +3777,8 @@ class CenterNetMetaArch(model.DetectionModel):
             k=self._center_params.max_box_predictions))
     multiclass_scores = tf.gather_nd(
         object_center_prob, tf.stack([y_indices, x_indices], -1), batch_dims=1)
-    num_detections = tf.reduce_sum(tf.to_int32(detection_scores > 0), axis=1)
+    num_detections = tf.reduce_sum(
+        tf.cast(detection_scores > 0, tf.int32), axis=1)
     postprocess_dict = {
         fields.DetectionResultFields.detection_scores: detection_scores,
         fields.DetectionResultFields.detection_multiclass_scores:
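The postprocessing change is the same `tf.cast` pattern, counting boxes with positive scores. A small standalone check with made-up scores:

```python
import tensorflow as tf

detection_scores = tf.constant([[0.9, 0.4, 0.0, 0.0],
                                [0.7, 0.0, 0.0, 0.0]])

# Boolean mask -> int32 so the per-row sum counts positive-score boxes.
num_detections = tf.reduce_sum(
    tf.cast(detection_scores > 0, tf.int32), axis=1)
# -> [2, 1]
```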