Commit b1d973ea authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 382422745
parent e515beb1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFDS factory functions."""
from official.vision.beta.dataloaders import decoder as base_decoder
from official.vision.beta.dataloaders import tfds_detection_decoders
from official.vision.beta.dataloaders import tfds_segmentation_decoders
from official.vision.beta.dataloaders import tfds_classification_decoders
def get_classification_decoder(tfds_name: str) -> base_decoder.Decoder:
"""Gets classification decoder.
Args:
tfds_name: `str`, name of the tfds classification decoder.
Returns:
`base_decoder.Decoder` instance.
Raises:
ValueError if the tfds_name doesn't exist in the available decoders.
"""
if tfds_name in tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP:
decoder = tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP[tfds_name]()
else:
raise ValueError(
f'TFDS Classification {tfds_name} is not supported')
return decoder
def get_detection_decoder(tfds_name: str) -> base_decoder.Decoder:
"""Gets detection decoder.
Args:
tfds_name: `str`, name of the tfds detection decoder.
Returns:
`base_decoder.Decoder` instance.
Raises:
ValueError if the tfds_name doesn't exist in the available decoders.
"""
if tfds_name in tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP:
decoder = tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP[tfds_name]()
else:
raise ValueError(f'TFDS Detection {tfds_name} is not supported')
return decoder
def get_segmentation_decoder(tfds_name: str) -> base_decoder.Decoder:
"""Gets segmentation decoder.
Args:
tfds_name: `str`, name of the tfds segmentation decoder.
Returns:
`base_decoder.Decoder` instance.
Raises:
ValueError if the tfds_name doesn't exist in the available decoders.
"""
if tfds_name in tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP:
decoder = tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP[tfds_name]()
else:
raise ValueError(f'TFDS Segmentation {tfds_name} is not supported')
return decoder
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tfds factory functions."""
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.dataloaders import decoder as base_decoder
from official.vision.beta.dataloaders import tfds_factory
class TFDSFactoryTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
('imagenet2012'),
('cifar10'),
('cifar100'),
)
def test_classification_decoder(self, tfds_name):
decoder = tfds_factory.get_classification_decoder(tfds_name)
self.assertIsInstance(decoder, base_decoder.Decoder)
@parameterized.parameters(
('flowers'),
('coco'),
)
def test_doesnt_exit_classification_decoder(self, tfds_name):
with self.assertRaises(ValueError):
_ = tfds_factory.get_classification_decoder(tfds_name)
@parameterized.parameters(
('coco'),
('coco/2014'),
('coco/2017'),
)
def test_detection_decoder(self, tfds_name):
decoder = tfds_factory.get_detection_decoder(tfds_name)
self.assertIsInstance(decoder, base_decoder.Decoder)
@parameterized.parameters(
('pascal'),
('cityscapes'),
)
def test_doesnt_exit_detection_decoder(self, tfds_name):
with self.assertRaises(ValueError):
_ = tfds_factory.get_detection_decoder(tfds_name)
@parameterized.parameters(
('cityscapes'),
('cityscapes/semantic_segmentation'),
('cityscapes/semantic_segmentation_extra'),
)
def test_segmentation_decoder(self, tfds_name):
decoder = tfds_factory.get_segmentation_decoder(tfds_name)
self.assertIsInstance(decoder, base_decoder.Decoder)
@parameterized.parameters(
('coco'),
('imagenet'),
)
def test_doesnt_exit_segmentation_decoder(self, tfds_name):
with self.assertRaises(ValueError):
_ = tfds_factory.get_segmentation_decoder(tfds_name)
if __name__ == '__main__':
tf.test.main()
...@@ -24,7 +24,7 @@ from official.modeling import tf_utils ...@@ -24,7 +24,7 @@ from official.modeling import tf_utils
from official.vision.beta.configs import image_classification as exp_cfg from official.vision.beta.configs import image_classification as exp_cfg
from official.vision.beta.dataloaders import classification_input from official.vision.beta.dataloaders import classification_input
from official.vision.beta.dataloaders import input_reader_factory from official.vision.beta.dataloaders import input_reader_factory
from official.vision.beta.dataloaders import tfds_classification_decoders from official.vision.beta.dataloaders import tfds_factory
from official.vision.beta.modeling import factory from official.vision.beta.modeling import factory
...@@ -89,11 +89,7 @@ class ImageClassificationTask(base_task.Task): ...@@ -89,11 +89,7 @@ class ImageClassificationTask(base_task.Task):
is_multilabel = self.task_config.train_data.is_multilabel is_multilabel = self.task_config.train_data.is_multilabel
if params.tfds_name: if params.tfds_name:
if params.tfds_name in tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP: decoder = tfds_factory.get_classification_decoder(params.tfds_name)
decoder = tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP[
params.tfds_name]()
else:
raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
else: else:
decoder = classification_input.Decoder( decoder = classification_input.Decoder(
image_field_key=image_field_key, label_field_key=label_field_key, image_field_key=image_field_key, label_field_key=label_field_key,
......
...@@ -25,7 +25,7 @@ from official.vision.beta.configs import retinanet as exp_cfg ...@@ -25,7 +25,7 @@ from official.vision.beta.configs import retinanet as exp_cfg
from official.vision.beta.dataloaders import input_reader_factory from official.vision.beta.dataloaders import input_reader_factory
from official.vision.beta.dataloaders import retinanet_input from official.vision.beta.dataloaders import retinanet_input
from official.vision.beta.dataloaders import tf_example_decoder from official.vision.beta.dataloaders import tf_example_decoder
from official.vision.beta.dataloaders import tfds_detection_decoders from official.vision.beta.dataloaders import tfds_factory
from official.vision.beta.dataloaders import tf_example_label_map_decoder from official.vision.beta.dataloaders import tf_example_label_map_decoder
from official.vision.beta.evaluation import coco_evaluator from official.vision.beta.evaluation import coco_evaluator
from official.vision.beta.modeling import factory from official.vision.beta.modeling import factory
...@@ -90,11 +90,7 @@ class RetinaNetTask(base_task.Task): ...@@ -90,11 +90,7 @@ class RetinaNetTask(base_task.Task):
"""Build input dataset.""" """Build input dataset."""
if params.tfds_name: if params.tfds_name:
if params.tfds_name in tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP: decoder = tfds_factory.get_detection_decoder(params.tfds_name)
decoder = tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP[
params.tfds_name]()
else:
raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
else: else:
decoder_cfg = params.decoder.get() decoder_cfg = params.decoder.get()
if params.decoder.type == 'simple_decoder': if params.decoder.type == 'simple_decoder':
......
...@@ -23,7 +23,7 @@ from official.core import task_factory ...@@ -23,7 +23,7 @@ from official.core import task_factory
from official.vision.beta.configs import semantic_segmentation as exp_cfg from official.vision.beta.configs import semantic_segmentation as exp_cfg
from official.vision.beta.dataloaders import input_reader_factory from official.vision.beta.dataloaders import input_reader_factory
from official.vision.beta.dataloaders import segmentation_input from official.vision.beta.dataloaders import segmentation_input
from official.vision.beta.dataloaders import tfds_segmentation_decoders from official.vision.beta.dataloaders import tfds_factory
from official.vision.beta.evaluation import segmentation_metrics from official.vision.beta.evaluation import segmentation_metrics
from official.vision.beta.losses import segmentation_losses from official.vision.beta.losses import segmentation_losses
from official.vision.beta.modeling import factory from official.vision.beta.modeling import factory
...@@ -87,11 +87,7 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -87,11 +87,7 @@ class SemanticSegmentationTask(base_task.Task):
ignore_label = self.task_config.losses.ignore_label ignore_label = self.task_config.losses.ignore_label
if params.tfds_name: if params.tfds_name:
if params.tfds_name in tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP: decoder = tfds_factory.get_segmentation_decoder(params.tfds_name)
decoder = tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP[
params.tfds_name]()
else:
raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
else: else:
decoder = segmentation_input.Decoder() decoder = segmentation_input.Decoder()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment