Commit 3fc55e9e authored by Xianzhi Du's avatar Xianzhi Du Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 377216027
parent 0a770680
......@@ -133,12 +133,54 @@ class RetinaNetTask(base_task.Task):
return dataset
def build_attribute_loss(self,
attribute_heads: List[exp_cfg.AttributeHead],
outputs: Mapping[str, Any],
labels: Mapping[str, Any],
box_sample_weight: tf.Tensor) -> float:
"""Computes attribute loss.
Args:
attribute_heads: a list of attribute head configs.
outputs: RetinaNet model outputs.
labels: RetinaNet labels.
box_sample_weight: normalized bounding box sample weights.
Returns:
Attribute loss of all attribute heads.
"""
attribute_loss = 0.0
for head in attribute_heads:
if head.name not in labels['attribute_targets']:
raise ValueError(f'Attribute {head.name} not found in label targets.')
if head.name not in outputs['attribute_outputs']:
raise ValueError(f'Attribute {head.name} not found in model outputs.')
y_true_att = keras_cv.losses.multi_level_flatten(
labels['attribute_targets'][head.name], last_dim=head.size)
y_pred_att = keras_cv.losses.multi_level_flatten(
outputs['attribute_outputs'][head.name], last_dim=head.size)
if head.type == 'regression':
att_loss_fn = tf.keras.losses.Huber(
1.0, reduction=tf.keras.losses.Reduction.SUM)
att_loss = att_loss_fn(
y_true=y_true_att,
y_pred=y_pred_att,
sample_weight=box_sample_weight)
else:
raise ValueError(f'Attribute type {head.type} not supported.')
attribute_loss += att_loss
return attribute_loss
def build_losses(self,
outputs: Mapping[str, Any],
labels: Mapping[str, Any],
aux_losses: Optional[Any] = None):
"""Build RetinaNet losses."""
params = self.task_config
attribute_heads = self.task_config.model.head.attribute_heads
cls_loss_fn = keras_cv.losses.FocalLoss(
alpha=params.losses.focal_loss_alpha,
gamma=params.losses.focal_loss_gamma,
......@@ -170,6 +212,10 @@ class RetinaNetTask(base_task.Task):
model_loss = cls_loss + params.losses.box_loss_weight * box_loss
if attribute_heads:
model_loss += self.build_attribute_loss(attribute_heads, outputs, labels,
box_sample_weight)
total_loss = model_loss
if aux_losses:
reg_loss = tf.reduce_sum(aux_losses)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment