Commit 580aa3f6 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change.

PiperOrigin-RevId: 439718605
parent 49b61fe4
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
"""Tests for image classification task.""" """Tests for image classification task."""
# pylint: disable=unused-import # pylint: disable=unused-import
import os
from absl.testing import parameterized from absl.testing import parameterized
import orbit import orbit
import tensorflow as tf import tensorflow as tf
...@@ -23,16 +25,31 @@ from official import vision ...@@ -23,16 +25,31 @@ from official import vision
from official.core import exp_factory from official.core import exp_factory
from official.modeling import optimization from official.modeling import optimization
from official.projects.qat.vision.tasks import image_classification as img_cls_task from official.projects.qat.vision.tasks import image_classification as img_cls_task
from official.vision.dataloaders import tfexample_utils
class ImageClassificationTaskTest(tf.test.TestCase, parameterized.TestCase): class ImageClassificationTaskTest(tf.test.TestCase, parameterized.TestCase):
def _create_test_tfrecord(self, tfrecord_file, example, num_samples):
examples = [example] * num_samples
tfexample_utils.dump_to_tfrecord(
record_file=tfrecord_file, tf_examples=examples)
@parameterized.parameters(('resnet_imagenet_qat'), @parameterized.parameters(('resnet_imagenet_qat'),
('mobilenet_imagenet_qat')) ('mobilenet_imagenet_qat'))
def test_task(self, config_name): def test_task(self, config_name):
input_image_size = [224, 224]
test_tfrecord_file = os.path.join(self.get_temp_dir(), 'cls_test.tfrecord')
example = tf.train.Example.FromString(
tfexample_utils.create_classification_example(
image_height=input_image_size[0], image_width=input_image_size[1]))
self._create_test_tfrecord(
tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
config = exp_factory.get_exp_config(config_name) config = exp_factory.get_exp_config(config_name)
config.task.train_data.global_batch_size = 2 config.task.train_data.global_batch_size = 2
config.task.validation_data.input_path = test_tfrecord_file
config.task.train_data.input_path = test_tfrecord_file
task = img_cls_task.ImageClassificationTask(config.task) task = img_cls_task.ImageClassificationTask(config.task)
model = task.build_model() model = task.build_model()
metrics = task.build_metrics() metrics = task.build_metrics()
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
"""Tests for RetinaNet task.""" """Tests for RetinaNet task."""
# pylint: disable=unused-import # pylint: disable=unused-import
import os
from absl.testing import parameterized from absl.testing import parameterized
import orbit import orbit
import tensorflow as tf import tensorflow as tf
...@@ -22,17 +24,32 @@ from official import vision ...@@ -22,17 +24,32 @@ from official import vision
from official.core import exp_factory from official.core import exp_factory
from official.modeling import optimization from official.modeling import optimization
from official.projects.qat.vision.tasks import retinanet from official.projects.qat.vision.tasks import retinanet
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.configs import retinanet as exp_cfg from official.vision.configs import retinanet as exp_cfg
class RetinaNetTaskTest(parameterized.TestCase, tf.test.TestCase): class RetinaNetTaskTest(parameterized.TestCase, tf.test.TestCase):
def _create_test_tfrecord(self, tfrecord_file, example, num_samples):
examples = [example] * num_samples
tfexample_utils.dump_to_tfrecord(
record_file=tfrecord_file, tf_examples=examples)
@parameterized.parameters( @parameterized.parameters(
('retinanet_spinenet_mobile_coco_qat', True), ('retinanet_spinenet_mobile_coco_qat', True),
('retinanet_spinenet_mobile_coco_qat', False), ('retinanet_spinenet_mobile_coco_qat', False),
) )
def test_retinanet_task(self, test_config, is_training): def test_retinanet_task(self, test_config, is_training):
"""RetinaNet task test for training and val using toy configs.""" """RetinaNet task test for training and val using toy configs."""
input_image_size = [384, 384]
test_tfrecord_file = os.path.join(self.get_temp_dir(), 'det_test.tfrecord')
example = tfexample_utils.create_detection_test_example(
image_height=input_image_size[0],
image_width=input_image_size[1],
image_channel=3,
num_instances=10)
self._create_test_tfrecord(
tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
config = exp_factory.get_exp_config(test_config) config = exp_factory.get_exp_config(test_config)
# modify config to suit local testing # modify config to suit local testing
config.task.model.input_size = [128, 128, 3] config.task.model.input_size = [128, 128, 3]
...@@ -41,6 +58,8 @@ class RetinaNetTaskTest(parameterized.TestCase, tf.test.TestCase): ...@@ -41,6 +58,8 @@ class RetinaNetTaskTest(parameterized.TestCase, tf.test.TestCase):
config.task.validation_data.global_batch_size = 1 config.task.validation_data.global_batch_size = 1
config.task.train_data.shuffle_buffer_size = 2 config.task.train_data.shuffle_buffer_size = 2
config.task.validation_data.shuffle_buffer_size = 2 config.task.validation_data.shuffle_buffer_size = 2
config.task.validation_data.input_path = test_tfrecord_file
config.task.train_data.input_path = test_tfrecord_file
config.train_steps = 1 config.train_steps = 1
task = retinanet.RetinaNetTask(config.task) task = retinanet.RetinaNetTask(config.task)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment