Unverified Commit ca552843 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'panoptic-segmentation' into panoptic-segmentation

parents 7e2f7a35 6b90e134
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for metrics.py."""
from absl.testing import parameterized
import tensorflow as tf
from official.projects.basnet.evaluation import metrics
class BASNetMetricTest(parameterized.TestCase, tf.test.TestCase):
def test_mae(self):
input_size = 224
inputs = (tf.random.uniform([2, input_size, input_size, 1]),)
labels = (tf.random.uniform([2, input_size, input_size, 1]),)
mae_obj = metrics.MAE()
mae_obj.reset_states()
mae_obj.update_state(labels, inputs)
output = mae_obj.result()
mae_tf = tf.keras.metrics.MeanAbsoluteError()
mae_tf.reset_state()
mae_tf.update_state(labels[0], inputs[0])
compare = mae_tf.result().numpy()
self.assertAlmostEqual(output, compare, places=4)
def test_max_f(self):
input_size = 224
beta = 0.3
inputs = (tf.random.uniform([2, input_size, input_size, 1]),)
labels = (tf.random.uniform([2, input_size, input_size, 1]),)
max_f_obj = metrics.MaxFscore()
max_f_obj.reset_states()
max_f_obj.update_state(labels, inputs)
output = max_f_obj.result()
pre_tf = tf.keras.metrics.Precision(thresholds=0.78)
rec_tf = tf.keras.metrics.Recall(thresholds=0.78)
pre_tf.reset_state()
rec_tf.reset_state()
pre_tf.update_state(labels[0], inputs[0])
rec_tf.update_state(labels[0], inputs[0])
pre_out_tf = pre_tf.result().numpy()
rec_out_tf = rec_tf.result().numpy()
compare = (1+beta)*pre_out_tf*rec_out_tf/(beta*pre_out_tf+rec_out_tf+1e-8)
self.assertAlmostEqual(output, compare, places=1)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Losses used for BASNet models."""
import tensorflow as tf
EPSILON = 1e-5
class BASNetLoss:
"""BASNet hybrid loss."""
def __init__(self):
self._binary_crossentropy = tf.keras.losses.BinaryCrossentropy(
reduction=tf.keras.losses.Reduction.SUM, from_logits=False)
self._ssim = tf.image.ssim
def __call__(self, sigmoids, labels):
levels = sorted(sigmoids.keys())
labels_bce = tf.squeeze(labels, axis=-1)
labels = tf.cast(labels, tf.float32)
bce_losses = []
ssim_losses = []
iou_losses = []
for level in levels:
bce_losses.append(
self._binary_crossentropy(labels_bce, sigmoids[level]))
ssim_losses.append(
1 - self._ssim(sigmoids[level], labels, max_val=1.0))
iou_losses.append(
self._iou_loss(sigmoids[level], labels))
total_bce_loss = tf.math.add_n(bce_losses)
total_ssim_loss = tf.math.add_n(ssim_losses)
total_iou_loss = tf.math.add_n(iou_losses)
total_loss = total_bce_loss + total_ssim_loss + total_iou_loss
total_loss = total_loss / len(levels)
return total_loss
def _iou_loss(self, sigmoids, labels):
total_iou_loss = 0
intersection = tf.reduce_sum(sigmoids[:, :, :, :] * labels[:, :, :, :])
union = tf.reduce_sum(sigmoids[:, :, :, :]) + tf.reduce_sum(
labels[:, :, :, :]) - intersection
iou = intersection / union
total_iou_loss += 1-iou
return total_iou_loss
This diff is collapsed.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for basnet network."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.projects.basnet.modeling import basnet_model
from official.projects.basnet.modeling import refunet
class BASNetNetworkTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(256),
(512),
)
def test_basnet_network_creation(
self, input_size):
"""Test for creation of a segmentation network."""
inputs = np.random.rand(2, input_size, input_size, 3)
tf.keras.backend.set_image_data_format('channels_last')
backbone = basnet_model.BASNetEncoder()
decoder = basnet_model.BASNetDecoder()
refinement = refunet.RefUnet()
model = basnet_model.BASNetModel(
backbone=backbone,
decoder=decoder,
refinement=refinement
)
sigmoids = model(inputs)
levels = sorted(sigmoids.keys())
self.assertAllEqual(
[2, input_size, input_size, 1],
sigmoids[levels[-1]].numpy().shape)
def test_serialize_deserialize(self):
"""Validate the network can be serialized and deserialized."""
backbone = basnet_model.BASNetEncoder()
decoder = basnet_model.BASNetDecoder()
refinement = refunet.RefUnet()
model = basnet_model.BASNetModel(
backbone=backbone,
decoder=decoder,
refinement=refinement
)
config = model.get_config()
new_model = basnet_model.BASNetModel.from_config(config)
# Validate that the config can be forced to JSON.
_ = new_model.to_json()
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(model.get_config(), new_model.get_config())
if __name__ == '__main__':
tf.test.main()
This diff is collapsed.
This diff is collapsed.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export module for BASNet."""
import tensorflow as tf
from official.projects.basnet.tasks import basnet
from official.vision.beta.serving import semantic_segmentation
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class BASNetModule(semantic_segmentation.SegmentationModule):
"""BASNet Module."""
def _build_model(self):
input_specs = tf.keras.layers.InputSpec(
shape=[self._batch_size] + self._input_image_size + [3])
return basnet.build_basnet_model(
input_specs=input_specs,
model_config=self.params.task.model,
l2_regularizer=None)
def serve(self, images):
"""Cast image to float and run inference.
Args:
images: uint8 Tensor of shape [batch_size, None, None, 3]
Returns:
Tensor holding classification output logits.
"""
with tf.device('cpu:0'):
images = tf.cast(images, dtype=tf.float32)
images = tf.nest.map_structure(
tf.identity,
tf.map_fn(
self._build_inputs, elems=images,
fn_output_signature=tf.TensorSpec(
shape=self._input_image_size + [3], dtype=tf.float32),
parallel_iterations=32
)
)
masks = self.inference_step(images)
keys = sorted(masks.keys())
output = tf.image.resize(
masks[keys[-1]],
self._input_image_size, method='bilinear')
return dict(predicted_masks=output)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment