Commit 27347c2d authored by Zhenyu Tan's avatar Zhenyu Tan Committed by A. Unique TensorFlower
Browse files

cleanup for keras_cv.losses

PiperOrigin-RevId: 331571062
parent 63ef92aa
...@@ -131,7 +131,7 @@ class RetinaNetTask(base_task.Task): ...@@ -131,7 +131,7 @@ class RetinaNetTask(base_task.Task):
def build_losses(self, outputs, labels, aux_losses=None): def build_losses(self, outputs, labels, aux_losses=None):
"""Build RetinaNet losses.""" """Build RetinaNet losses."""
params = self.task_config params = self.task_config
cls_loss_fn = keras_cv.FocalLoss( cls_loss_fn = keras_cv.losses.FocalLoss(
alpha=params.losses.focal_loss_alpha, alpha=params.losses.focal_loss_alpha,
gamma=params.losses.focal_loss_gamma, gamma=params.losses.focal_loss_gamma,
reduction=tf.keras.losses.Reduction.SUM) reduction=tf.keras.losses.Reduction.SUM)
...@@ -145,14 +145,14 @@ class RetinaNetTask(base_task.Task): ...@@ -145,14 +145,14 @@ class RetinaNetTask(base_task.Task):
num_positives = tf.reduce_sum(box_sample_weight) + 1.0 num_positives = tf.reduce_sum(box_sample_weight) + 1.0
cls_sample_weight = cls_sample_weight / num_positives cls_sample_weight = cls_sample_weight / num_positives
box_sample_weight = box_sample_weight / num_positives box_sample_weight = box_sample_weight / num_positives
y_true_cls = keras_cv.multi_level_flatten( y_true_cls = keras_cv.losses.multi_level_flatten(
labels['cls_targets'], last_dim=None) labels['cls_targets'], last_dim=None)
y_true_cls = tf.one_hot(y_true_cls, params.model.num_classes) y_true_cls = tf.one_hot(y_true_cls, params.model.num_classes)
y_pred_cls = keras_cv.multi_level_flatten( y_pred_cls = keras_cv.losses.multi_level_flatten(
outputs['cls_outputs'], last_dim=params.model.num_classes) outputs['cls_outputs'], last_dim=params.model.num_classes)
y_true_box = keras_cv.multi_level_flatten( y_true_box = keras_cv.losses.multi_level_flatten(
labels['box_targets'], last_dim=4) labels['box_targets'], last_dim=4)
y_pred_box = keras_cv.multi_level_flatten( y_pred_box = keras_cv.losses.multi_level_flatten(
outputs['box_outputs'], last_dim=4) outputs['box_outputs'], last_dim=4)
cls_loss = cls_loss_fn( cls_loss = cls_loss_fn(
......
...@@ -14,5 +14,5 @@ ...@@ -14,5 +14,5 @@
# ============================================================================== # ==============================================================================
"""Keras-CV package definition.""" """Keras-CV package definition."""
# pylint: disable=wildcard-import # pylint: disable=wildcard-import
from official.vision.keras_cv.losses import * from official.vision.keras_cv import losses
from official.vision.keras_cv.ops import * from official.vision.keras_cv import ops
...@@ -14,4 +14,4 @@ ...@@ -14,4 +14,4 @@
# ============================================================================== # ==============================================================================
"""Keras-CV layers package definition.""" """Keras-CV layers package definition."""
from official.vision.keras_cv.losses.focal_loss import FocalLoss from official.vision.keras_cv.losses.focal_loss import FocalLoss
from official.vision.keras_cv.losses.loss_utils import * from official.vision.keras_cv.losses.loss_utils import multi_level_flatten
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,11 +14,6 @@ ...@@ -14,11 +14,6 @@
# ============================================================================== # ==============================================================================
"""Losses used for detection models.""" """Losses used for detection models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Import libraries
import tensorflow as tf import tensorflow as tf
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# ============================================================================== # ==============================================================================
"""Losses utilities for detection models.""" """Losses utilities for detection models."""
# Import libraries
import tensorflow as tf import tensorflow as tf
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment