Commit aa7749c5 authored by Zhenyu Tan's avatar Zhenyu Tan Committed by A. Unique TensorFlower
Browse files

cleanup for keras_cv.losses

PiperOrigin-RevId: 331571062
parent faf7a4f5
......@@ -131,7 +131,7 @@ class RetinaNetTask(base_task.Task):
def build_losses(self, outputs, labels, aux_losses=None):
"""Build RetinaNet losses."""
params = self.task_config
cls_loss_fn = keras_cv.FocalLoss(
cls_loss_fn = keras_cv.losses.FocalLoss(
alpha=params.losses.focal_loss_alpha,
gamma=params.losses.focal_loss_gamma,
reduction=tf.keras.losses.Reduction.SUM)
......@@ -145,14 +145,14 @@ class RetinaNetTask(base_task.Task):
num_positives = tf.reduce_sum(box_sample_weight) + 1.0
cls_sample_weight = cls_sample_weight / num_positives
box_sample_weight = box_sample_weight / num_positives
y_true_cls = keras_cv.multi_level_flatten(
y_true_cls = keras_cv.losses.multi_level_flatten(
labels['cls_targets'], last_dim=None)
y_true_cls = tf.one_hot(y_true_cls, params.model.num_classes)
y_pred_cls = keras_cv.multi_level_flatten(
y_pred_cls = keras_cv.losses.multi_level_flatten(
outputs['cls_outputs'], last_dim=params.model.num_classes)
y_true_box = keras_cv.multi_level_flatten(
y_true_box = keras_cv.losses.multi_level_flatten(
labels['box_targets'], last_dim=4)
y_pred_box = keras_cv.multi_level_flatten(
y_pred_box = keras_cv.losses.multi_level_flatten(
outputs['box_outputs'], last_dim=4)
cls_loss = cls_loss_fn(
......
......@@ -14,5 +14,5 @@
# ==============================================================================
"""Keras-CV package definition."""
# pylint: disable=wildcard-import
from official.vision.keras_cv.losses import *
from official.vision.keras_cv.ops import *
from official.vision.keras_cv import losses
from official.vision.keras_cv import ops
......@@ -14,4 +14,4 @@
# ==============================================================================
"""Keras-CV layers package definition."""
from official.vision.keras_cv.losses.focal_loss import FocalLoss
from official.vision.keras_cv.losses.loss_utils import *
from official.vision.keras_cv.losses.loss_utils import multi_level_flatten
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,11 +14,6 @@
# ==============================================================================
"""Losses used for detection models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Import libraries
import tensorflow as tf
......
......@@ -14,7 +14,6 @@
# ==============================================================================
"""Losses utilities for detection models."""
# Import libraries
import tensorflow as tf
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment