Commit 37103aec authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 406239748
parent 067c893e
@@ -27,6 +27,19 @@ from official.projects.edgetpu.vision.configs import mobilenet_edgetpu_config
 from official.projects.edgetpu.vision.tasks import image_classification
 
 
+# Dummy ImageNet TF dataset.
+def dummy_imagenet_dataset():
+
+  def dummy_data(_):
+    dummy_image = tf.zeros((2, 224, 224, 3), dtype=tf.float32)
+    dummy_label = tf.zeros((2), dtype=tf.int32)
+    return (dummy_image, dummy_label)
+
+  dataset = tf.data.Dataset.range(1)
+  dataset = dataset.repeat()
+  dataset = dataset.map(
+      dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+  return dataset
+
+
 class ImageClassificationTaskTest(tf.test.TestCase, parameterized.TestCase):
 
   @parameterized.parameters(('mobilenet_edgetpu_v2_xs'),
@@ -45,10 +58,8 @@ class ImageClassificationTaskTest(tf.test.TestCase, parameterized.TestCase):
     task = image_classification.EdgeTPUTask(config.task)
     model = task.build_model()
     metrics = task.build_metrics()
-    strategy = tf.distribute.get_strategy()
-    dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
-                                                   config.task.train_data)
+    dataset = dummy_imagenet_dataset()
     iterator = iter(dataset)
     opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
...
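For reference, a minimal standalone sketch (not part of this commit) of how the new dummy_imagenet_dataset helper can be exercised in place of the orbit.utils.make_distributed_dataset pipeline removed above. The helper definition is copied from the diff; the iteration loop and shape assertions are illustrative assumptions, not repository code.

# Illustrative sketch only; assumes TensorFlow 2.x eager execution.
import tensorflow as tf


def dummy_imagenet_dataset():
  """Yields an endless stream of zero-filled (image, label) batches."""

  def dummy_data(_):
    dummy_image = tf.zeros((2, 224, 224, 3), dtype=tf.float32)
    dummy_label = tf.zeros((2), dtype=tf.int32)
    return (dummy_image, dummy_label)

  dataset = tf.data.Dataset.range(1).repeat()
  return dataset.map(
      dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)


iterator = iter(dummy_imagenet_dataset())
for _ in range(3):  # Each next() call yields one dummy batch of two images.
  images, labels = next(iterator)
  assert images.shape == (2, 224, 224, 3)
  assert labels.shape == (2,)

Because the dataset repeats indefinitely, the test's train step can pull as many batches as it needs without building a distributed input pipeline or reading real ImageNet data.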
@@ -28,6 +28,25 @@ from official.projects.edgetpu.vision.tasks import semantic_segmentation as img_
 from official.vision import beta
 
 
+# Dummy ADE20K TF dataset.
+def dummy_ade20k_dataset(image_width, image_height):
+
+  def dummy_data(_):
+    dummy_image = tf.zeros((1, image_width, image_height, 3), dtype=tf.float32)
+    dummy_masks = tf.zeros((1, image_width, image_height, 1), dtype=tf.float32)
+    dummy_valid_masks = dummy_masks
+    dummy_image_info = tf.zeros((1, 4, 2), dtype=tf.float32)
+    return (dummy_image, {
+        'masks': dummy_masks,
+        'valid_masks': dummy_valid_masks,
+        'image_info': dummy_image_info,
+    })
+
+  dataset = tf.data.Dataset.range(1)
+  dataset = dataset.repeat()
+  dataset = dataset.map(
+      dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+  return dataset
+
+
 class SemanticSegmentationTaskTest(tf.test.TestCase, parameterized.TestCase):
 
   @parameterized.parameters(('deeplabv3plus_mobilenet_edgetpuv2_xs_ade20k_32',),
@@ -57,10 +76,8 @@ class SemanticSegmentationTaskTest(tf.test.TestCase, parameterized.TestCase):
     task = img_seg_task.CustomSemanticSegmentationTask(config.task)
     model = task.build_model()
     metrics = task.build_metrics()
-    strategy = tf.distribute.get_strategy()
-    dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
-                                                   config.task.train_data)
+    dataset = dummy_ade20k_dataset(32, 32)
     iterator = iter(dataset)
     opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
...
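Similarly, a minimal standalone sketch (not part of this commit) showing the structure the dummy_ade20k_dataset helper yields: a zero-filled image batch paired with a dict carrying 'masks', 'valid_masks', and 'image_info', which stands in for the real ADE20K input pipeline in the segmentation train-step test. The assertions below are illustrative assumptions.

# Illustrative sketch only; mirrors the helper added in the diff above.
import tensorflow as tf


def dummy_ade20k_dataset(image_width, image_height):
  """Yields an endless stream of zero-filled segmentation batches."""

  def dummy_data(_):
    dummy_image = tf.zeros((1, image_width, image_height, 3), dtype=tf.float32)
    dummy_masks = tf.zeros((1, image_width, image_height, 1), dtype=tf.float32)
    dummy_image_info = tf.zeros((1, 4, 2), dtype=tf.float32)
    return (dummy_image, {
        'masks': dummy_masks,
        'valid_masks': dummy_masks,
        'image_info': dummy_image_info,
    })

  return tf.data.Dataset.range(1).repeat().map(
      dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)


features, labels = next(iter(dummy_ade20k_dataset(32, 32)))
assert features.shape == (1, 32, 32, 3)
assert sorted(labels.keys()) == ['image_info', 'masks', 'valid_masks']
assert labels['masks'].shape == (1, 32, 32, 1)

The 32x32 spatial size matches the dummy_ade20k_dataset(32, 32) call used in the test above, which keeps the smoke test small while still driving the train-step code path.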