Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b1598e9e
"vscode:/vscode.git/clone" did not exist on "93b2af7f6e6aed2464a6ee7c5d9dd000d557d47e"
Commit
b1598e9e
authored
Sep 02, 2020
by
Zhenyu Tan
Committed by
A. Unique TensorFlower
Sep 02, 2020
Browse files
Internal change
PiperOrigin-RevId: 329763594
parent
cc748b2a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
181 additions
and
10 deletions
+181
-10
official/vision/beta/tasks/retinanet.py
official/vision/beta/tasks/retinanet.py
+15
-10
official/vision/keras_cv/__init__.py
official/vision/keras_cv/__init__.py
+17
-0
official/vision/keras_cv/losses/__init__.py
official/vision/keras_cv/losses/__init__.py
+17
-0
official/vision/keras_cv/losses/focal_loss.py
official/vision/keras_cv/losses/focal_loss.py
+89
-0
official/vision/keras_cv/losses/loss_utils.py
official/vision/keras_cv/losses/loss_utils.py
+43
-0
No files found.
official/vision/beta/tasks/retinanet.py
View file @
b1598e9e
...
@@ -20,12 +20,12 @@ import tensorflow as tf
...
@@ -20,12 +20,12 @@ import tensorflow as tf
from
official.core
import
base_task
from
official.core
import
base_task
from
official.core
import
input_reader
from
official.core
import
input_reader
from
official.core
import
task_factory
from
official.core
import
task_factory
from
official.vision
import
keras_cv
from
official.vision.beta.configs
import
retinanet
as
exp_cfg
from
official.vision.beta.configs
import
retinanet
as
exp_cfg
from
official.vision.beta.dataloaders
import
retinanet_input
from
official.vision.beta.dataloaders
import
retinanet_input
from
official.vision.beta.dataloaders
import
tf_example_decoder
from
official.vision.beta.dataloaders
import
tf_example_decoder
from
official.vision.beta.dataloaders
import
tf_example_label_map_decoder
from
official.vision.beta.dataloaders
import
tf_example_label_map_decoder
from
official.vision.beta.evaluation
import
coco_evaluator
from
official.vision.beta.evaluation
import
coco_evaluator
from
official.vision.beta.losses
import
retinanet_losses
from
official.vision.beta.modeling
import
factory
from
official.vision.beta.modeling
import
factory
...
@@ -131,12 +131,11 @@ class RetinaNetTask(base_task.Task):
...
@@ -131,12 +131,11 @@ class RetinaNetTask(base_task.Task):
def
build_losses
(
self
,
outputs
,
labels
,
aux_losses
=
None
):
def
build_losses
(
self
,
outputs
,
labels
,
aux_losses
=
None
):
"""Build RetinaNet losses."""
"""Build RetinaNet losses."""
params
=
self
.
task_config
params
=
self
.
task_config
cls_loss_fn
=
retinanet_losses
.
FocalLoss
(
cls_loss_fn
=
keras_cv
.
FocalLoss
(
alpha
=
params
.
losses
.
focal_loss_alpha
,
alpha
=
params
.
losses
.
focal_loss_alpha
,
gamma
=
params
.
losses
.
focal_loss_gamma
,
gamma
=
params
.
losses
.
focal_loss_gamma
,
num_classes
=
params
.
model
.
num_classes
,
reduction
=
tf
.
keras
.
losses
.
Reduction
.
SUM
)
reduction
=
tf
.
keras
.
losses
.
Reduction
.
SUM
)
box_loss_fn
=
retinanet_losses
.
RetinanetBoxLoss
(
box_loss_fn
=
tf
.
keras
.
losses
.
Huber
(
params
.
losses
.
huber_loss_delta
,
reduction
=
tf
.
keras
.
losses
.
Reduction
.
SUM
)
params
.
losses
.
huber_loss_delta
,
reduction
=
tf
.
keras
.
losses
.
Reduction
.
SUM
)
# Sums all positives in a batch for normalization and avoids zero
# Sums all positives in a batch for normalization and avoids zero
...
@@ -146,14 +145,20 @@ class RetinaNetTask(base_task.Task):
...
@@ -146,14 +145,20 @@ class RetinaNetTask(base_task.Task):
num_positives
=
tf
.
reduce_sum
(
box_sample_weight
)
+
1.0
num_positives
=
tf
.
reduce_sum
(
box_sample_weight
)
+
1.0
cls_sample_weight
=
cls_sample_weight
/
num_positives
cls_sample_weight
=
cls_sample_weight
/
num_positives
box_sample_weight
=
box_sample_weight
/
num_positives
box_sample_weight
=
box_sample_weight
/
num_positives
y_true_cls
=
keras_cv
.
multi_level_flatten
(
labels
[
'cls_targets'
],
last_dim
=
None
)
y_true_cls
=
tf
.
one_hot
(
y_true_cls
,
params
.
model
.
num_classes
)
y_pred_cls
=
keras_cv
.
multi_level_flatten
(
outputs
[
'cls_outputs'
],
last_dim
=
params
.
model
.
num_classes
)
y_true_box
=
keras_cv
.
multi_level_flatten
(
labels
[
'box_targets'
],
last_dim
=
4
)
y_pred_box
=
keras_cv
.
multi_level_flatten
(
outputs
[
'box_outputs'
],
last_dim
=
4
)
cls_loss
=
cls_loss_fn
(
cls_loss
=
cls_loss_fn
(
y_true
=
labels
[
'cls_targets'
],
y_true
=
y_true_cls
,
y_pred
=
y_pred_cls
,
sample_weight
=
cls_sample_weight
)
y_pred
=
outputs
[
'cls_outputs'
],
sample_weight
=
cls_sample_weight
)
box_loss
=
box_loss_fn
(
box_loss
=
box_loss_fn
(
y_true
=
labels
[
'box_targets'
],
y_true
=
y_true_box
,
y_pred
=
y_pred_box
,
sample_weight
=
box_sample_weight
)
y_pred
=
outputs
[
'box_outputs'
],
sample_weight
=
box_sample_weight
)
model_loss
=
cls_loss
+
params
.
losses
.
box_loss_weight
*
box_loss
model_loss
=
cls_loss
+
params
.
losses
.
box_loss_weight
*
box_loss
...
...
official/vision/keras_cv/__init__.py
0 → 100644
View file @
b1598e9e
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-CV package definition."""
# Re-export the losses subpackage at the top level so callers can write
# `keras_cv.FocalLoss` / `keras_cv.multi_level_flatten`.
# pylint: disable=wildcard-import
from official.vision.keras_cv.losses import *
official/vision/keras_cv/losses/__init__.py
0 → 100644
View file @
b1598e9e
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-CV losses package definition."""
from official.vision.keras_cv.losses.focal_loss import FocalLoss
from official.vision.keras_cv.losses.loss_utils import *
official/vision/keras_cv/losses/focal_loss.py
0 → 100644
View file @
b1598e9e
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Losses used for detection models."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
# Import libraries
import
tensorflow
as
tf
class FocalLoss(tf.keras.losses.Loss):
  """Implements a Focal loss for classification problems.

  Reference:
    [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002).
  """

  def __init__(self,
               alpha,
               gamma,
               reduction=tf.keras.losses.Reduction.AUTO,
               name=None):
    """Initializes `FocalLoss`.

    Arguments:
      alpha: The `alpha` weight factor for binary class imbalance.
      gamma: The `gamma` focusing parameter to re-weight loss.
      reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
        loss. Default value is `AUTO`. `AUTO` indicates that the reduction
        option will be determined by the usage context. For almost all cases
        this defaults to `SUM_OVER_BATCH_SIZE`. When used with
        `tf.distribute.Strategy`, outside of built-in training loops such as
        `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
        will raise an error. Please see this custom training [tutorial](
        https://www.tensorflow.org/tutorials/distribute/custom_training) for
        more details.
      name: Optional name for the op. Defaults to `None`, in which case the
        Keras `Loss` base class derives a name from the class name.
    """
    self._alpha = alpha
    self._gamma = gamma
    super(FocalLoss, self).__init__(reduction=reduction, name=name)

  def call(self, y_true, y_pred):
    """Invokes the `FocalLoss`.

    Arguments:
      y_true: A tensor of size [batch, num_anchors, num_classes].
      y_pred: A tensor of size [batch, num_anchors, num_classes].

    Returns:
      Summed loss float `Tensor`.
    """
    with tf.name_scope('focal_loss'):
      y_true = tf.cast(y_true, dtype=tf.float32)
      y_pred = tf.cast(y_pred, dtype=tf.float32)
      positive_label_mask = tf.equal(y_true, 1.0)
      # Per-element sigmoid cross entropy; y_pred is expected to be logits.
      cross_entropy = (
          tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true,
                                                  logits=y_pred))
      probs = tf.sigmoid(y_pred)
      # Probability assigned to the ground-truth class of each element.
      probs_gt = tf.where(positive_label_mask, probs, 1.0 - probs)
      # With small gamma, the implementation could produce NaN during back
      # prop.
      modulator = tf.pow(1.0 - probs_gt, self._gamma)
      loss = modulator * cross_entropy
      # Balance positive vs. negative examples with the alpha factor.
      weighted_loss = tf.where(positive_label_mask, self._alpha * loss,
                               (1.0 - self._alpha) * loss)
    return weighted_loss

  def get_config(self):
    """Returns the config dictionary for `FocalLoss`, for serialization."""
    config = {
        'alpha': self._alpha,
        'gamma': self._gamma,
    }
    base_config = super(FocalLoss, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
official/vision/keras_cv/losses/loss_utils.py
0 → 100644
View file @
b1598e9e
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Losses utilities for detection models."""
# Import libraries
import
tensorflow
as
tf
def multi_level_flatten(multi_level_inputs, last_dim=None):
  """Flattens a multi-level input.

  Arguments:
    multi_level_inputs: Ordered Dict with level to [batch, d1, ..., dm].
    last_dim: Whether the output should be [batch_size, None], or [batch_size,
      None, last_dim]. Defaults to `None`.

  Returns:
    Concatenated output [batch_size, None], or [batch_size, None, dm]
  """
  flat_levels = []
  batch = None
  for tensor in multi_level_inputs.values():
    if batch is None:
      # Prefer the static batch dimension; fall back to the dynamic one when
      # the static shape is unknown.
      batch = tensor.shape[0] or tf.shape(tensor)[0]
    target_shape = [batch, -1] if last_dim is None else [batch, -1, last_dim]
    flat_levels.append(tf.reshape(tensor, target_shape))
  # Concatenate the per-level flattened tensors along the anchor axis.
  return tf.concat(flat_levels, axis=1)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment