Unverified Commit 440e0eec authored by Stephen Wu's avatar Stephen Wu Committed by GitHub
Browse files

Merge branch 'master' into RTESuperGLUE

parents 51364cdf 9815ea67
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.common import dataset_fn
from official.core import base_task from official.core import base_task
from official.core import input_reader from official.core import input_reader
from official.core import task_factory from official.core import task_factory
...@@ -24,7 +25,7 @@ from official.vision import keras_cv ...@@ -24,7 +25,7 @@ from official.vision import keras_cv
from official.vision.beta.configs import retinanet as exp_cfg from official.vision.beta.configs import retinanet as exp_cfg
from official.vision.beta.dataloaders import retinanet_input from official.vision.beta.dataloaders import retinanet_input
from official.vision.beta.dataloaders import tf_example_decoder from official.vision.beta.dataloaders import tf_example_decoder
from official.vision.beta.dataloaders import dataset_fn from official.vision.beta.dataloaders import tfds_detection_decoders
from official.vision.beta.dataloaders import tf_example_label_map_decoder from official.vision.beta.dataloaders import tf_example_label_map_decoder
from official.vision.beta.evaluation import coco_evaluator from official.vision.beta.evaluation import coco_evaluator
from official.vision.beta.modeling import factory from official.vision.beta.modeling import factory
...@@ -77,13 +78,22 @@ class RetinaNetTask(base_task.Task): ...@@ -77,13 +78,22 @@ class RetinaNetTask(base_task.Task):
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
else: else:
assert "Only 'all' or 'backbone' can be used to initialize the model." raise ValueError(
"Only 'all' or 'backbone' can be used to initialize the model.")
logging.info('Finished loading pretrained checkpoint from %s', logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file) ckpt_dir_or_file)
def build_inputs(self, params, input_context=None): def build_inputs(self, params, input_context=None):
"""Build input dataset.""" """Build input dataset."""
if params.tfds_name:
if params.tfds_name in tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP:
decoder = tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP[
params.tfds_name]()
else:
raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
else:
decoder_cfg = params.decoder.get() decoder_cfg = params.decoder.get()
if params.decoder.type == 'simple_decoder': if params.decoder.type == 'simple_decoder':
decoder = tf_example_decoder.TfExampleDecoder( decoder = tf_example_decoder.TfExampleDecoder(
...@@ -93,7 +103,8 @@ class RetinaNetTask(base_task.Task): ...@@ -93,7 +103,8 @@ class RetinaNetTask(base_task.Task):
label_map=decoder_cfg.label_map, label_map=decoder_cfg.label_map,
regenerate_source_id=decoder_cfg.regenerate_source_id) regenerate_source_id=decoder_cfg.regenerate_source_id)
else: else:
raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type)) raise ValueError('Unknown decoder type: {}!'.format(
params.decoder.type))
parser = retinanet_input.Parser( parser = retinanet_input.Parser(
output_size=self.task_config.model.input_size[:2], output_size=self.task_config.model.input_size[:2],
...@@ -169,10 +180,13 @@ class RetinaNetTask(base_task.Task): ...@@ -169,10 +180,13 @@ class RetinaNetTask(base_task.Task):
metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32)) metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
if not training: if not training:
if self.task_config.validation_data.tfds_name and self.task_config.annotation_file:
raise ValueError(
"Can't evaluate using annotation file when TFDS is used.")
self.coco_metric = coco_evaluator.COCOEvaluator( self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file=self._task_config.annotation_file, annotation_file=self.task_config.annotation_file,
include_mask=False, include_mask=False,
per_category_metrics=self._task_config.per_category_metrics) per_category_metrics=self.task_config.per_category_metrics)
return metrics return metrics
......
...@@ -17,13 +17,13 @@ ...@@ -17,13 +17,13 @@
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.common import dataset_fn
from official.core import base_task from official.core import base_task
from official.core import input_reader from official.core import input_reader
from official.core import task_factory from official.core import task_factory
from official.vision.beta.configs import semantic_segmentation as exp_cfg from official.vision.beta.configs import semantic_segmentation as exp_cfg
from official.vision.beta.dataloaders import segmentation_input from official.vision.beta.dataloaders import segmentation_input
from official.vision.beta.dataloaders import dataset_fn from official.vision.beta.dataloaders import tfds_segmentation_decoders
from official.vision.beta.evaluation import segmentation_metrics from official.vision.beta.evaluation import segmentation_metrics
from official.vision.beta.losses import segmentation_losses from official.vision.beta.losses import segmentation_losses
from official.vision.beta.modeling import factory from official.vision.beta.modeling import factory
...@@ -84,7 +84,15 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -84,7 +84,15 @@ class SemanticSegmentationTask(base_task.Task):
ignore_label = self.task_config.losses.ignore_label ignore_label = self.task_config.losses.ignore_label
if params.tfds_name:
if params.tfds_name in tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP:
decoder = tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP[
params.tfds_name]()
else:
raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
else:
decoder = segmentation_input.Decoder() decoder = segmentation_input.Decoder()
parser = segmentation_input.Parser( parser = segmentation_input.Parser(
output_size=params.output_size, output_size=params.output_size,
train_on_crops=params.train_on_crops, train_on_crops=params.train_on_crops,
......
...@@ -195,10 +195,7 @@ class VideoClassificationTask(base_task.Task): ...@@ -195,10 +195,7 @@ class VideoClassificationTask(base_task.Task):
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
if self.task_config.train_data.output_audio:
outputs = model(features, training=True) outputs = model(features, training=True)
else:
outputs = model(features['image'], training=True)
# Casting output layer as float32 is necessary when mixed_precision is # Casting output layer as float32 is necessary when mixed_precision is
# mixed_float16 or mixed_bfloat16 to ensure output is casted as float32. # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
outputs = tf.nest.map_structure( outputs = tf.nest.map_structure(
...@@ -267,10 +264,7 @@ class VideoClassificationTask(base_task.Task): ...@@ -267,10 +264,7 @@ class VideoClassificationTask(base_task.Task):
def inference_step(self, features, model): def inference_step(self, features, model):
"""Performs the forward step.""" """Performs the forward step."""
if self.task_config.train_data.output_audio:
outputs = model(features, training=False) outputs = model(features, training=False)
else:
outputs = model(features['image'], training=False)
if self.task_config.train_data.is_multilabel: if self.task_config.train_data.is_multilabel:
outputs = tf.math.sigmoid(outputs) outputs = tf.math.sigmoid(outputs)
else: else:
......
...@@ -259,7 +259,7 @@ class Controller: ...@@ -259,7 +259,7 @@ class Controller:
elapsed = time.time() - start elapsed = time.time() - start
_log(f" eval | step: {current_step: 6d} | " _log(f" eval | step: {current_step: 6d} | "
f"eval time: {elapsed: 6.1f} | " f"eval time: {elapsed: 6.1f} sec | "
f"output: {_format_output(eval_output)}") f"output: {_format_output(eval_output)}")
self.eval_summary_manager.write_summaries(eval_output) self.eval_summary_manager.write_summaries(eval_output)
......
...@@ -227,7 +227,7 @@ def _build_classification_loss(loss_config): ...@@ -227,7 +227,7 @@ def _build_classification_loss(loss_config):
if loss_type == 'weighted_sigmoid': if loss_type == 'weighted_sigmoid':
return losses.WeightedSigmoidClassificationLoss() return losses.WeightedSigmoidClassificationLoss()
if loss_type == 'weighted_sigmoid_focal': elif loss_type == 'weighted_sigmoid_focal':
config = loss_config.weighted_sigmoid_focal config = loss_config.weighted_sigmoid_focal
alpha = None alpha = None
if config.HasField('alpha'): if config.HasField('alpha'):
...@@ -236,25 +236,31 @@ def _build_classification_loss(loss_config): ...@@ -236,25 +236,31 @@ def _build_classification_loss(loss_config):
gamma=config.gamma, gamma=config.gamma,
alpha=alpha) alpha=alpha)
if loss_type == 'weighted_softmax': elif loss_type == 'weighted_softmax':
config = loss_config.weighted_softmax config = loss_config.weighted_softmax
return losses.WeightedSoftmaxClassificationLoss( return losses.WeightedSoftmaxClassificationLoss(
logit_scale=config.logit_scale) logit_scale=config.logit_scale)
if loss_type == 'weighted_logits_softmax': elif loss_type == 'weighted_logits_softmax':
config = loss_config.weighted_logits_softmax config = loss_config.weighted_logits_softmax
return losses.WeightedSoftmaxClassificationAgainstLogitsLoss( return losses.WeightedSoftmaxClassificationAgainstLogitsLoss(
logit_scale=config.logit_scale) logit_scale=config.logit_scale)
if loss_type == 'bootstrapped_sigmoid': elif loss_type == 'bootstrapped_sigmoid':
config = loss_config.bootstrapped_sigmoid config = loss_config.bootstrapped_sigmoid
return losses.BootstrappedSigmoidClassificationLoss( return losses.BootstrappedSigmoidClassificationLoss(
alpha=config.alpha, alpha=config.alpha,
bootstrap_type=('hard' if config.hard_bootstrap else 'soft')) bootstrap_type=('hard' if config.hard_bootstrap else 'soft'))
if loss_type == 'penalty_reduced_logistic_focal_loss': elif loss_type == 'penalty_reduced_logistic_focal_loss':
config = loss_config.penalty_reduced_logistic_focal_loss config = loss_config.penalty_reduced_logistic_focal_loss
return losses.PenaltyReducedLogisticFocalLoss( return losses.PenaltyReducedLogisticFocalLoss(
alpha=config.alpha, beta=config.beta) alpha=config.alpha, beta=config.beta)
elif loss_type == 'weighted_dice_classification_loss':
config = loss_config.weighted_dice_classification_loss
return losses.WeightedDiceClassificationLoss(
squared_normalization=config.squared_normalization)
else:
raise ValueError('Empty loss config.') raise ValueError('Empty loss config.')
...@@ -298,6 +298,45 @@ class ClassificationLossBuilderTest(tf.test.TestCase): ...@@ -298,6 +298,45 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
losses_builder.build(losses_proto) losses_builder.build(losses_proto)
def test_build_penalty_reduced_logistic_focal_loss(self):
  """Checks that the builder wires alpha/beta into the focal loss object."""
  config_text = """
classification_loss {
penalty_reduced_logistic_focal_loss {
alpha: 2.0
beta: 4.0
}
}
localization_loss {
l1_localization_loss {
}
}
"""
  loss_config = losses_pb2.Loss()
  text_format.Merge(config_text, loss_config)

  # build() is unpacked into exactly seven values; only the first one (the
  # classification loss) is inspected here.
  (built_cls_loss, _, _, _, _, _, _) = losses_builder.build(loss_config)

  self.assertIsInstance(built_cls_loss,
                        losses.PenaltyReducedLogisticFocalLoss)
  self.assertAlmostEqual(built_cls_loss._alpha, 2.0)
  self.assertAlmostEqual(built_cls_loss._beta, 4.0)
def test_build_dice_loss(self):
  """Checks that the builder creates a dice loss with squared normalization.

  The text proto enables `squared_normalization`; the built classification
  loss must be a `WeightedDiceClassificationLoss` carrying that flag.
  """
  losses_text_proto = """
classification_loss {
weighted_dice_classification_loss {
squared_normalization: true
}
}
localization_loss {
l1_localization_loss {
}
}
"""
  losses_proto = losses_pb2.Loss()
  text_format.Merge(losses_text_proto, losses_proto)
  classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
  self.assertIsInstance(classification_loss,
                        losses.WeightedDiceClassificationLoss)
  # Use the TestCase assertion rather than a bare `assert`: bare asserts are
  # removed when Python runs with -O and report no context on failure.
  self.assertTrue(classification_loss._squared_normalization)
class HardExampleMinerBuilderTest(tf.test.TestCase): class HardExampleMinerBuilderTest(tf.test.TestCase):
......
...@@ -278,6 +278,79 @@ class WeightedSigmoidClassificationLoss(Loss): ...@@ -278,6 +278,79 @@ class WeightedSigmoidClassificationLoss(Loss):
return per_entry_cross_ent * weights return per_entry_cross_ent * weights
class WeightedDiceClassificationLoss(Loss):
  """Dice loss for classification [1][2].

  [1]: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
  [2]: https://arxiv.org/abs/1606.04797
  """

  def __init__(self, squared_normalization):
    """Initializes the loss object.

    Args:
      squared_normalization: boolean. When set, the probabilities and the
        targets are squared before they are weighted and summed.
    """
    self._squared_normalization = squared_normalization
    super(WeightedDiceClassificationLoss, self).__init__()

  def _compute_loss(self,
                    prediction_tensor,
                    target_tensor,
                    weights,
                    class_indices=None):
    """Computes the per-class dice loss.

    The "area" of the prediction and of the ground truth is obtained by
    summing along the second (pixel / anchor) dimension.

    Args:
      prediction_tensor: A float tensor of shape [batch_size, num_pixels,
        num_classes] holding the predicted logits for each class, where
        num_pixels is the number of mask pixels after flattening the spatial
        dimensions.
      target_tensor: A float tensor of shape [batch_size, num_pixels,
        num_classes] holding one-hot encoded classification targets.
      weights: A float tensor of shape [batch_size, num_pixels, num_classes]
        or [batch_size, num_pixels, 1]. With the latter shape all classes are
        weighted equally.
      class_indices: (Optional) A 1-D integer tensor of class indices. If
        provided, the weights of every other class are zeroed out.

    Returns:
      loss: a float tensor of shape [batch_size, num_classes] holding
        1 - dice coefficient for each class.
    """
    if class_indices is not None:
      num_classes = tf.shape(prediction_tensor)[2]
      selected_classes = ops.indices_to_dense_vector(class_indices,
                                                     num_classes)
      # Broadcast the 0/1 class selector over the batch and pixel dimensions.
      weights *= tf.reshape(selected_classes, [1, 1, -1])

    # Logits -> probabilities.
    probabilities = tf.nn.sigmoid(prediction_tensor)

    if self._squared_normalization:
      probabilities = tf.pow(probabilities, 2)
      target_tensor = tf.pow(target_tensor, 2)

    weighted_probabilities = probabilities * weights
    target_tensor *= weights

    # Sum over the pixel dimension -> [batch_size, num_classes].
    overlap = tf.reduce_sum(weighted_probabilities * target_tensor, axis=1)
    predicted_area = tf.reduce_sum(weighted_probabilities, axis=1)
    groundtruth_area = tf.reduce_sum(target_tensor, axis=1)

    # The denominator is clamped to at least 1.0, which keeps the division
    # well defined when both areas are zero.
    denominator = tf.maximum(groundtruth_area + predicted_area, 1.0)
    return 1 - (2 * overlap) / denominator
class SigmoidFocalClassificationLoss(Loss): class SigmoidFocalClassificationLoss(Loss):
"""Sigmoid focal cross entropy loss. """Sigmoid focal cross entropy loss.
......
...@@ -1447,5 +1447,111 @@ class L1LocalizationLossTest(test_case.TestCase): ...@@ -1447,5 +1447,111 @@ class L1LocalizationLossTest(test_case.TestCase):
self.assertAllClose(computed_value, [[0.8, 0.0], [0.6, 0.1]], rtol=1e-6) self.assertAllClose(computed_value, [[0.8, 0.0], [0.6, 0.1]], rtol=1e-6)
class WeightedDiceClassificationLoss(test_case.TestCase):
  """Tests for `losses.WeightedDiceClassificationLoss._compute_loss`."""

  # ((batch, pixel, class), probability) pairs; each probability is converted
  # to a logit before being written into the prediction array.
  _PREDICTED_PROBABILITIES = (
      ((0, 1, 0), 0.9),
      ((0, 2, 0), 0.1),
      ((0, 2, 2), 0.5),
      ((0, 1, 3), 0.1),
      ((1, 2, 3), 0.2),
      ((1, 1, 1), 0.3),
      ((1, 0, 2), 0.1),
  )

  # ((batch, pixel, class), value) pairs written into the target array.
  _TARGET_VALUES = (
      ((0, 1, 0), 1.0),
      ((0, 2, 2), 1.0),
      ((0, 1, 3), 1.0),
      ((1, 2, 3), 1.0),
      ((1, 1, 1), 0.0),
      ((1, 0, 2), 0.0),
  )

  def _make_pred_and_target(self):
    """Returns the (pred, target) float32 arrays shared by all tests."""
    pred = np.zeros((2, 3, 4), dtype=np.float32)
    target = np.zeros((2, 3, 4), dtype=np.float32)
    for index, probability in self._PREDICTED_PROBABILITIES:
      pred[index] = _logit(probability)
    for index, value in self._TARGET_VALUES:
      target[index] = value
    return pred, target

  def test_compute_weights_1(self):
    """All weights are one: every class with a positive target contributes."""

    def graph_fn():
      dice_loss = losses.WeightedDiceClassificationLoss(
          squared_normalization=False)
      pred, target = self._make_pred_and_target()
      return dice_loss._compute_loss(pred, target, np.ones_like(target))

    expected_dice = np.zeros((2, 4))
    expected_dice[0, 0] = 2 * 0.9 / 2.5
    expected_dice[0, 2] = 2 * 0.5 / 2.5
    expected_dice[0, 3] = 2 * 0.1 / 2.1
    expected_dice[1, 3] = 2 * 0.2 / 2.2

    actual_loss = self.execute(graph_fn, [])
    self.assertAllClose(actual_loss, 1 - expected_dice, rtol=1e-6)

  def test_compute_weights_set(self):
    """Class 0 has zero weight, so its dice coefficient is zero."""

    def graph_fn():
      dice_loss = losses.WeightedDiceClassificationLoss(
          squared_normalization=False)
      pred, target = self._make_pred_and_target()
      class_weights = np.ones_like(target)
      class_weights[:, :, 0] = 0.0
      return dice_loss._compute_loss(pred, target, class_weights)

    expected_dice = np.zeros((2, 4))
    expected_dice[0, 2] = 2 * 0.5 / 2.5
    expected_dice[0, 3] = 2 * 0.1 / 2.1
    expected_dice[1, 3] = 2 * 0.2 / 2.2

    actual_loss = self.execute(graph_fn, [])
    self.assertAllClose(actual_loss, 1 - expected_dice, rtol=1e-6)

  def test_class_indices(self):
    """Only the class listed in `class_indices` contributes."""

    def graph_fn():
      dice_loss = losses.WeightedDiceClassificationLoss(
          squared_normalization=False)
      pred, target = self._make_pred_and_target()
      return dice_loss._compute_loss(
          pred, target, np.ones_like(target), class_indices=[0])

    expected_dice = np.zeros((2, 4))
    expected_dice[0, 0] = 2 * 0.9 / 2.5

    actual_loss = self.execute(graph_fn, [])
    self.assertAllClose(actual_loss, 1 - expected_dice, rtol=1e-6)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -1468,6 +1468,175 @@ class CenterNetKeypointTargetAssigner(object): ...@@ -1468,6 +1468,175 @@ class CenterNetKeypointTargetAssigner(object):
batch_offsets = tf.concat(batch_offsets, axis=0) batch_offsets = tf.concat(batch_offsets, axis=0)
return (batch_indices, batch_offsets, batch_weights) return (batch_indices, batch_offsets, batch_weights)
def assign_keypoints_depth_targets(self,
                                   height,
                                   width,
                                   gt_keypoints_list,
                                   gt_classes_list,
                                   gt_keypoint_depths_list,
                                   gt_keypoint_depth_weights_list,
                                   gt_keypoints_weights_list=None,
                                   gt_weights_list=None):
  """Returns the target depths of the keypoints.

  The returned values are the relative depth information of each keypoint.

  Args:
    height: int, height of input to the CenterNet model. This is used to
      determine the height of the output.
    width: int, width of the input to the CenterNet model. This is used to
      determine the width of the output.
    gt_keypoints_list: A list of tensors with shape [num_instances,
      num_total_keypoints, 2]. See class-level description for more detail.
    gt_classes_list: A list of tensors with shape [num_instances,
      num_classes]. See class-level description for more detail.
    gt_keypoint_depths_list: A list of tensors with shape [num_instances,
      num_total_keypoints] corresponding to the relative depth of the
      keypoints.
    gt_keypoint_depth_weights_list: A list of tensors with shape
      [num_instances, num_total_keypoints] corresponding to the weights of
      the relative depth.
    gt_keypoints_weights_list: A list of tensors with shape [num_instances,
      num_total_keypoints] corresponding to the weight of each keypoint.
    gt_weights_list: A list of float tensors with shape [num_instances]. See
      class-level description for more detail.

  Returns:
    batch_indices: an integer tensor of shape [num_total_instances, 3] (or
      [num_total_instances, 4] if 'per_keypoint_offset' is set True) holding
      the indices inside the predicted tensor which should be penalized. The
      first column indicates the index along the batch dimension and the
      second and third columns indicate the index along the y and x
      dimensions respectively. The fourth column corresponds to the channel
      dimension (if 'per_keypoint_offset' is set True).
    batch_depths: a float tensor of shape [num_total_instances, 1] indicating
      the target depth of each keypoint.
    batch_weights: a float tensor of shape [num_total_instances] indicating
      the weight of each prediction.
    Note that num_total_instances = batch_size * num_instances *
                                    num_keypoints * num_neighbors
  """
  if gt_keypoints_weights_list is None:
    gt_keypoints_weights_list = [None] * len(gt_keypoints_list)
  if gt_weights_list is None:
    gt_weights_list = [None] * len(gt_classes_list)
  if gt_keypoint_depths_list is None:
    gt_keypoint_depths_list = [None] * len(gt_classes_list)

  per_image_inputs = zip(gt_keypoints_list, gt_classes_list,
                         gt_keypoints_weights_list, gt_weights_list,
                         gt_keypoint_depths_list,
                         gt_keypoint_depth_weights_list)

  batch_indices = []
  batch_depths = []
  batch_weights = []
  for image_index, (image_keypoints, image_classes, image_kp_weights,
                    image_weights, image_depths,
                    image_depth_weights) in enumerate(per_image_inputs):
    out_height = height // self._stride
    out_width = width // self._stride

    keypoints_absolute, processed_kp_weights = (
        self._preprocess_keypoints_and_weights(
            out_height=out_height,
            out_width=out_width,
            keypoints=image_keypoints,
            class_onehot=image_classes,
            class_weights=image_weights,
            keypoint_weights=image_kp_weights))
    num_instances, num_keypoints, _ = (
        shape_utils.combined_static_and_dynamic_shape(keypoints_absolute))

    # [num_instances * num_keypoints]
    y_source = tf.keras.backend.flatten(keypoints_absolute[:, :, 0])
    x_source = tf.keras.backend.flatten(keypoints_absolute[:, :, 1])

    # All keypoint coordinates and their neighbors:
    # [num_instances * num_keypoints, num_neighbors]
    y_neighbors, x_neighbors, neighbor_is_valid = (
        ta_utils.get_surrounding_grids(out_height, out_width, y_source,
                                       x_source, self._peak_radius))
    _, num_neighbors = shape_utils.combined_static_and_dynamic_shape(
        y_neighbors)

    # Keypoint weight repeated for every neighbor and zeroed where the
    # neighbor is not valid.
    # [num_instances * num_keypoints, num_neighbors]
    flat_kp_weights = tf.keras.backend.flatten(processed_kp_weights)
    neighbor_kp_weights = (
        tf.cast(neighbor_is_valid, dtype=tf.float32) *
        tf.stack([flat_kp_weights] * num_neighbors, axis=-1))

    # Only the indices are needed here; the offsets are discarded.
    # [num_instances * num_keypoints, num_neighbors, 2]
    _, yx_indices = ta_utils.compute_floor_offsets_with_indices(
        y_source=y_neighbors,
        x_source=x_neighbors,
        y_target=y_source,
        x_target=x_source)
    # [num_instances * num_keypoints * num_neighbors, 2]
    yx_indices = tf.reshape(yx_indices, [-1, 2])

    neighbor_multiples = [1, 1, num_neighbors]

    # Select the depths of this assigner's keypoints and repeat them for the
    # surrounding pixels.
    # [num_instances, num_keypoints] -> [..., num_neighbors]
    selected_depths = tf.gather(
        image_depths, self._keypoint_indices, axis=1)
    depths_per_neighbor = tf.tile(
        tf.expand_dims(selected_depths, axis=-1),
        multiples=neighbor_multiples)

    # Same selection and tiling for the depth weights.
    selected_depth_weights = tf.gather(
        image_depth_weights, self._keypoint_indices, axis=1)
    depth_weights_per_neighbor = tf.tile(
        tf.expand_dims(selected_depth_weights, axis=-1),
        multiples=neighbor_multiples)

    # A keypoint depth target is valid only if its corresponding keypoint
    # target is also valid, so scale the depth weights by the keypoint
    # weights.
    # [num_instances, num_keypoints, num_neighbors]
    combined_weights = (
        tf.reshape(neighbor_kp_weights,
                   [num_instances, num_keypoints, num_neighbors]) *
        depth_weights_per_neighbor)

    # Entries whose weight or depth is NaN get zero depth and zero weight.
    is_nan_entry = tf.logical_or(
        tf.math.is_nan(combined_weights),
        tf.math.is_nan(depths_per_neighbor))
    clean_depths = tf.where(is_nan_entry,
                            tf.zeros_like(depths_per_neighbor),
                            depths_per_neighbor)
    clean_weights = tf.where(is_nan_entry,
                             tf.zeros_like(combined_weights),
                             combined_weights)

    # [num_instances * num_keypoints * num_neighbors, 1]
    batch_depths.append(tf.reshape(clean_depths, [-1, 1]))

    # The batch index column is prepended to the (y, x) indices; the keypoint
    # type column is appended when per-keypoint offsets are used.
    batch_column = tf.fill(
        [num_instances * num_keypoints * num_neighbors, 1], image_index)
    index_columns = [batch_column, yx_indices]
    if self._per_keypoint_offset:
      keypoint_types = self._get_keypoint_types(
          num_instances, num_keypoints, num_neighbors)
      index_columns.append(tf.reshape(keypoint_types, [-1, 1]))
    batch_indices.append(tf.concat(index_columns, axis=1))

    batch_weights.append(tf.keras.backend.flatten(clean_weights))

  # Concatenate the per-image tensors along the first dimension:
  # indices: [batch_size * num_instances * num_keypoints * num_neighbors, 3]
  #   (or 4 columns if 'per_keypoint_offset' is set to True).
  # weights: [batch_size * num_instances * num_keypoints * num_neighbors]
  # depths: [batch_size * num_instances * num_keypoints * num_neighbors, 1]
  batch_indices = tf.concat(batch_indices, axis=0)
  batch_weights = tf.concat(batch_weights, axis=0)
  batch_depths = tf.concat(batch_depths, axis=0)
  return (batch_indices, batch_depths, batch_weights)
def assign_joint_regression_targets(self, def assign_joint_regression_targets(self,
height, height,
width, width,
......
...@@ -1683,6 +1683,121 @@ class CenterNetKeypointTargetAssignerTest(test_case.TestCase): ...@@ -1683,6 +1683,121 @@ class CenterNetKeypointTargetAssignerTest(test_case.TestCase):
np.testing.assert_array_equal([0, 3, 2], indices[7, :]) np.testing.assert_array_equal([0, 3, 2], indices[7, :])
np.testing.assert_array_almost_equal([0.6, 0.4], offsets[7, :]) np.testing.assert_array_almost_equal([0.6, 0.4], offsets[7, :])
def test_assign_keypoint_depths_target(self):
  """Depth targets without per-keypoint offsets: indices have 3 columns."""
  nan = float('nan')

  def graph_fn():
    gt_classes_list = [tf.one_hot([0, 1, 0, 1], depth=4)]

    # The same value is used for the y and x coordinate of each keypoint.
    single_coordinate = tf.expand_dims(
        tf.constant(
            np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
                      [nan, 0.7, 0.7, 0.9, 0.4],
                      [0.4, 0.1, 0.4, 0.2, 0.0],
                      [nan, 0.0, 0.12, 0.7, 0.4]]),
            dtype=tf.float32),
        axis=2)
    gt_keypoints_list = [
        tf.concat([single_coordinate, single_coordinate], axis=2)
    ]

    gt_keypoint_depths_list = [
        tf.constant(
            np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
                      [nan, 0.7, nan, 0.9, 0.4],
                      [0.4, 0.1, 0.4, 0.2, 0.0],
                      [0.5, 0.0, 7.0, 0.7, 0.4]]),
            dtype=tf.float32)
    ]
    gt_keypoint_depth_weights_list = [
        tf.constant(
            np.array([[1.0, 1.0, 1.0, 1.0, 1.0],
                      [nan, 0.0, 1.0, 0.0, 0.0],
                      [1.0, 1.0, 1.0, 1.0, 1.0],
                      [1.0, 1.0, 0.5, 1.0, 1.0]]),
            dtype=tf.float32)
    ]

    assigner = targetassigner.CenterNetKeypointTargetAssigner(
        stride=4,
        class_id=1,
        keypoint_indices=[0, 2],
        peak_radius=1)
    target_indices, target_depths, target_weights = (
        assigner.assign_keypoints_depth_targets(
            height=120,
            width=80,
            gt_keypoints_list=gt_keypoints_list,
            gt_classes_list=gt_classes_list,
            gt_keypoint_depths_list=gt_keypoint_depths_list,
            gt_keypoint_depth_weights_list=gt_keypoint_depth_weights_list))
    return target_indices, target_depths, target_weights

  indices, depths, weights = self.execute(graph_fn, [])

  # Only the last 5 of the 40 elements have a positive weight.
  expected_weights = [0.0] * 35 + [0.5] * 5
  np.testing.assert_array_almost_equal(expected_weights, weights)
  # Validate the depth values of those last 5 elements.
  np.testing.assert_array_almost_equal([7.0] * 5, depths[35:, 0])
  self.assertEqual((40, 3), indices.shape)
  np.testing.assert_array_equal([0, 2, 2], indices[35, :])
def test_assign_keypoint_depths_per_keypoints(self):
  """Depth targets with per-keypoint offsets: indices have 4 columns."""
  nan = float('nan')

  def graph_fn():
    gt_classes_list = [tf.one_hot([0, 1, 0, 1], depth=4)]

    # The same value is used for the y and x coordinate of each keypoint.
    single_coordinate = tf.expand_dims(
        tf.constant(
            np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
                      [nan, 0.7, 0.7, 0.9, 0.4],
                      [0.4, 0.1, 0.4, 0.2, 0.0],
                      [nan, 0.0, 0.12, 0.7, 0.4]]),
            dtype=tf.float32),
        axis=2)
    gt_keypoints_list = [
        tf.concat([single_coordinate, single_coordinate], axis=2)
    ]

    gt_keypoint_depths_list = [
        tf.constant(
            np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
                      [nan, 0.7, nan, 0.9, 0.4],
                      [0.4, 0.1, 0.4, 0.2, 0.0],
                      [0.5, 0.0, 7.0, 0.7, 0.4]]),
            dtype=tf.float32)
    ]
    gt_keypoint_depth_weights_list = [
        tf.constant(
            np.array([[1.0, 1.0, 1.0, 1.0, 1.0],
                      [nan, 0.0, 1.0, 0.0, 0.0],
                      [1.0, 1.0, 1.0, 1.0, 1.0],
                      [1.0, 1.0, 0.5, 1.0, 1.0]]),
            dtype=tf.float32)
    ]

    assigner = targetassigner.CenterNetKeypointTargetAssigner(
        stride=4,
        class_id=1,
        keypoint_indices=[0, 2],
        peak_radius=1,
        per_keypoint_offset=True)
    target_indices, target_depths, target_weights = (
        assigner.assign_keypoints_depth_targets(
            height=120,
            width=80,
            gt_keypoints_list=gt_keypoints_list,
            gt_classes_list=gt_classes_list,
            gt_keypoint_depths_list=gt_keypoint_depths_list,
            gt_keypoint_depth_weights_list=gt_keypoint_depth_weights_list))
    return target_indices, target_depths, target_weights

  indices, depths, weights = self.execute(graph_fn, [])

  # Only the last 5 of the 40 elements have a positive weight.
  expected_weights = [0.0] * 35 + [0.5] * 5
  np.testing.assert_array_almost_equal(expected_weights, weights)
  # Validate the depth values of those last 5 elements.
  np.testing.assert_array_almost_equal([7.0] * 5, depths[35:, 0])
  self.assertEqual((40, 4), indices.shape)
  np.testing.assert_array_equal([0, 2, 2, 1], indices[35, :])
def test_assign_keypoints_offset_targets_radius(self): def test_assign_keypoints_offset_targets_radius(self):
def graph_fn(): def graph_fn():
gt_classes_list = [ gt_classes_list = [
......
...@@ -145,7 +145,7 @@ class SSDModule(tf.Module): ...@@ -145,7 +145,7 @@ class SSDModule(tf.Module):
scores = tf.constant(0.0, dtype=tf.float32, name='scores') scores = tf.constant(0.0, dtype=tf.float32, name='scores')
classes = tf.constant(0.0, dtype=tf.float32, name='classes') classes = tf.constant(0.0, dtype=tf.float32, name='classes')
num_detections = tf.constant(0.0, dtype=tf.float32, name='num_detections') num_detections = tf.constant(0.0, dtype=tf.float32, name='num_detections')
return boxes, scores, classes, num_detections return boxes, classes, scores, num_detections
return dummy_post_processing return dummy_post_processing
......
...@@ -68,7 +68,8 @@ def _multiclass_scores_or_one_hot_labels(multiclass_scores, ...@@ -68,7 +68,8 @@ def _multiclass_scores_or_one_hot_labels(multiclass_scores,
return tf.cond(tf.size(multiclass_scores) > 0, true_fn, false_fn) return tf.cond(tf.size(multiclass_scores) > 0, true_fn, false_fn)
def _convert_labeled_classes_to_k_hot(groundtruth_labeled_classes, num_classes, def convert_labeled_classes_to_k_hot(groundtruth_labeled_classes,
num_classes,
map_empty_to_ones=False): map_empty_to_ones=False):
"""Returns k-hot encoding of the labeled classes. """Returns k-hot encoding of the labeled classes.
...@@ -235,7 +236,7 @@ def transform_input_data(tensor_dict, ...@@ -235,7 +236,7 @@ def transform_input_data(tensor_dict,
if field in out_tensor_dict: if field in out_tensor_dict:
out_tensor_dict[field] = _remove_unrecognized_classes( out_tensor_dict[field] = _remove_unrecognized_classes(
out_tensor_dict[field], unrecognized_label=-1) out_tensor_dict[field], unrecognized_label=-1)
out_tensor_dict[field] = _convert_labeled_classes_to_k_hot( out_tensor_dict[field] = convert_labeled_classes_to_k_hot(
out_tensor_dict[field], num_classes, map_empty_to_ones) out_tensor_dict[field], num_classes, map_empty_to_ones)
if input_fields.multiclass_scores in out_tensor_dict: if input_fields.multiclass_scores in out_tensor_dict:
......
...@@ -152,6 +152,15 @@ def _prepare_groundtruth_for_eval(detection_model, class_agnostic, ...@@ -152,6 +152,15 @@ def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
groundtruth[input_data_fields.groundtruth_keypoints] = tf.stack( groundtruth[input_data_fields.groundtruth_keypoints] = tf.stack(
detection_model.groundtruth_lists(fields.BoxListFields.keypoints)) detection_model.groundtruth_lists(fields.BoxListFields.keypoints))
if detection_model.groundtruth_has_field(
fields.BoxListFields.keypoint_depths):
groundtruth[input_data_fields.groundtruth_keypoint_depths] = tf.stack(
detection_model.groundtruth_lists(fields.BoxListFields.keypoint_depths))
groundtruth[
input_data_fields.groundtruth_keypoint_depth_weights] = tf.stack(
detection_model.groundtruth_lists(
fields.BoxListFields.keypoint_depth_weights))
if detection_model.groundtruth_has_field( if detection_model.groundtruth_has_field(
fields.BoxListFields.keypoint_visibilities): fields.BoxListFields.keypoint_visibilities):
groundtruth[input_data_fields.groundtruth_keypoint_visibilities] = tf.stack( groundtruth[input_data_fields.groundtruth_keypoint_visibilities] = tf.stack(
...@@ -260,6 +269,8 @@ def unstack_batch(tensor_dict, unpad_groundtruth_tensors=True): ...@@ -260,6 +269,8 @@ def unstack_batch(tensor_dict, unpad_groundtruth_tensors=True):
fields.InputDataFields.groundtruth_classes, fields.InputDataFields.groundtruth_classes,
fields.InputDataFields.groundtruth_boxes, fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_keypoints, fields.InputDataFields.groundtruth_keypoints,
fields.InputDataFields.groundtruth_keypoint_depths,
fields.InputDataFields.groundtruth_keypoint_depth_weights,
fields.InputDataFields.groundtruth_keypoint_visibilities, fields.InputDataFields.groundtruth_keypoint_visibilities,
fields.InputDataFields.groundtruth_dp_num_points, fields.InputDataFields.groundtruth_dp_num_points,
fields.InputDataFields.groundtruth_dp_part_ids, fields.InputDataFields.groundtruth_dp_part_ids,
...@@ -311,6 +322,13 @@ def provide_groundtruth(model, labels): ...@@ -311,6 +322,13 @@ def provide_groundtruth(model, labels):
gt_keypoints_list = None gt_keypoints_list = None
if fields.InputDataFields.groundtruth_keypoints in labels: if fields.InputDataFields.groundtruth_keypoints in labels:
gt_keypoints_list = labels[fields.InputDataFields.groundtruth_keypoints] gt_keypoints_list = labels[fields.InputDataFields.groundtruth_keypoints]
gt_keypoint_depths_list = None
gt_keypoint_depth_weights_list = None
if fields.InputDataFields.groundtruth_keypoint_depths in labels:
gt_keypoint_depths_list = (
labels[fields.InputDataFields.groundtruth_keypoint_depths])
gt_keypoint_depth_weights_list = (
labels[fields.InputDataFields.groundtruth_keypoint_depth_weights])
gt_keypoint_visibilities_list = None gt_keypoint_visibilities_list = None
if fields.InputDataFields.groundtruth_keypoint_visibilities in labels: if fields.InputDataFields.groundtruth_keypoint_visibilities in labels:
gt_keypoint_visibilities_list = labels[ gt_keypoint_visibilities_list = labels[
...@@ -376,7 +394,9 @@ def provide_groundtruth(model, labels): ...@@ -376,7 +394,9 @@ def provide_groundtruth(model, labels):
groundtruth_area_list=gt_area_list, groundtruth_area_list=gt_area_list,
groundtruth_track_ids_list=gt_track_ids_list, groundtruth_track_ids_list=gt_track_ids_list,
groundtruth_verified_neg_classes=gt_verified_neg_classes, groundtruth_verified_neg_classes=gt_verified_neg_classes,
groundtruth_not_exhaustive_classes=gt_not_exhaustive_classes) groundtruth_not_exhaustive_classes=gt_not_exhaustive_classes,
groundtruth_keypoint_depths_list=gt_keypoint_depths_list,
groundtruth_keypoint_depth_weights_list=gt_keypoint_depth_weights_list)
def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False, def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False,
......
...@@ -99,6 +99,10 @@ def _compute_losses_and_predictions_dicts( ...@@ -99,6 +99,10 @@ def _compute_losses_and_predictions_dicts(
k-hot tensor of classes. k-hot tensor of classes.
labels[fields.InputDataFields.groundtruth_track_ids] is an int32 labels[fields.InputDataFields.groundtruth_track_ids] is an int32
tensor of track IDs. tensor of track IDs.
labels[fields.InputDataFields.groundtruth_keypoint_depths] is a
float32 tensor containing keypoint depths information.
labels[fields.InputDataFields.groundtruth_keypoint_depth_weights] is a
float32 tensor containing the weights of the keypoint depth feature.
add_regularization_loss: Whether or not to include the model's add_regularization_loss: Whether or not to include the model's
regularization loss in the losses dictionary. regularization loss in the losses dictionary.
...@@ -213,6 +217,10 @@ def eager_train_step(detection_model, ...@@ -213,6 +217,10 @@ def eager_train_step(detection_model,
k-hot tensor of classes. k-hot tensor of classes.
labels[fields.InputDataFields.groundtruth_track_ids] is an int32 labels[fields.InputDataFields.groundtruth_track_ids] is an int32
tensor of track IDs. tensor of track IDs.
labels[fields.InputDataFields.groundtruth_keypoint_depths] is a
float32 tensor containing keypoint depths information.
labels[fields.InputDataFields.groundtruth_keypoint_depth_weights] is a
float32 tensor containing the weights of the keypoint depth feature.
unpad_groundtruth_tensors: A parameter passed to unstack_batch. unpad_groundtruth_tensors: A parameter passed to unstack_batch.
optimizer: The training optimizer that will update the variables. optimizer: The training optimizer that will update the variables.
learning_rate: The learning rate tensor for the current training step. learning_rate: The learning rate tensor for the current training step.
......
...@@ -110,6 +110,7 @@ message ClassificationLoss { ...@@ -110,6 +110,7 @@ message ClassificationLoss {
BootstrappedSigmoidClassificationLoss bootstrapped_sigmoid = 3; BootstrappedSigmoidClassificationLoss bootstrapped_sigmoid = 3;
SigmoidFocalClassificationLoss weighted_sigmoid_focal = 4; SigmoidFocalClassificationLoss weighted_sigmoid_focal = 4;
PenaltyReducedLogisticFocalLoss penalty_reduced_logistic_focal_loss = 6; PenaltyReducedLogisticFocalLoss penalty_reduced_logistic_focal_loss = 6;
WeightedDiceClassificationLoss weighted_dice_classification_loss = 7;
} }
} }
...@@ -217,3 +218,14 @@ message RandomExampleSampler { ...@@ -217,3 +218,14 @@ message RandomExampleSampler {
// example sampling. // example sampling.
optional float positive_sample_fraction = 1 [default = 0.01]; optional float positive_sample_fraction = 1 [default = 0.01];
} }
// Dice loss for training instance masks [1][2].
//
// This message is referenced by the `weighted_dice_classification_loss`
// field (tag 7) of the ClassificationLoss message.
//
// [1]: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
// [2]: https://arxiv.org/abs/1606.04797
message WeightedDiceClassificationLoss {
// If set, we square the probabilities in the denominator term used for
// normalization. Defaults to false when the field is not specified.
optional bool squared_normalization = 1 [default=false];
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment