Commit 6cd536dc authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 405448944
parent 88eac22a
......@@ -17,9 +17,7 @@
from unittest import mock
import tensorflow as tf
from official.core import exp_factory
from official.projects.edgetpu.vision.serving import tflite_imagenet_evaluator
from official.projects.edgetpu.vision.tasks import image_classification
class TfliteImagenetEvaluatorTest(tf.test.TestCase):
......@@ -28,16 +26,13 @@ class TfliteImagenetEvaluatorTest(tf.test.TestCase):
def test_evaluate_all(self):
batch_size = 8
num_threads = 4
global_batch_size = num_threads * batch_size
config = exp_factory.get_exp_config('mobilenet_edgetpu_v2_xs')
config.task.validation_data.global_batch_size = global_batch_size
config.task.validation_data.dtype = 'float32'
num_batches = 5
task = image_classification.EdgeTPUTask(config.task)
dataset = task.build_inputs(config.task.validation_data)
labels = tf.data.Dataset.range(batch_size * num_threads * num_batches)
images = tf.data.Dataset.range(batch_size * num_threads * num_batches)
dataset = tf.data.Dataset.zip((images, labels))
dataset = dataset.batch(batch_size)
num_batches = 5
with mock.patch.object(
tflite_imagenet_evaluator.AccuracyEvaluator,
'evaluate_single_image',
......@@ -45,7 +40,7 @@ class TfliteImagenetEvaluatorTest(tf.test.TestCase):
autospec=True):
evaluator = tflite_imagenet_evaluator.AccuracyEvaluator(
model_content='MockModelContent'.encode('utf-8'),
dataset=dataset.take(num_batches),
dataset=dataset,
num_threads=num_threads)
num_evals, num_corrects = evaluator.evaluate_all()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment