Commit 6cd536dc authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 405448944
parent 88eac22a
...@@ -17,9 +17,7 @@ ...@@ -17,9 +17,7 @@
from unittest import mock from unittest import mock
import tensorflow as tf import tensorflow as tf
from official.core import exp_factory
from official.projects.edgetpu.vision.serving import tflite_imagenet_evaluator from official.projects.edgetpu.vision.serving import tflite_imagenet_evaluator
from official.projects.edgetpu.vision.tasks import image_classification
class TfliteImagenetEvaluatorTest(tf.test.TestCase): class TfliteImagenetEvaluatorTest(tf.test.TestCase):
...@@ -28,16 +26,13 @@ class TfliteImagenetEvaluatorTest(tf.test.TestCase): ...@@ -28,16 +26,13 @@ class TfliteImagenetEvaluatorTest(tf.test.TestCase):
def test_evaluate_all(self): def test_evaluate_all(self):
batch_size = 8 batch_size = 8
num_threads = 4 num_threads = 4
global_batch_size = num_threads * batch_size num_batches = 5
config = exp_factory.get_exp_config('mobilenet_edgetpu_v2_xs')
config.task.validation_data.global_batch_size = global_batch_size
config.task.validation_data.dtype = 'float32'
task = image_classification.EdgeTPUTask(config.task) labels = tf.data.Dataset.range(batch_size * num_threads * num_batches)
dataset = task.build_inputs(config.task.validation_data) images = tf.data.Dataset.range(batch_size * num_threads * num_batches)
dataset = tf.data.Dataset.zip((images, labels))
dataset = dataset.batch(batch_size)
num_batches = 5
with mock.patch.object( with mock.patch.object(
tflite_imagenet_evaluator.AccuracyEvaluator, tflite_imagenet_evaluator.AccuracyEvaluator,
'evaluate_single_image', 'evaluate_single_image',
...@@ -45,7 +40,7 @@ class TfliteImagenetEvaluatorTest(tf.test.TestCase): ...@@ -45,7 +40,7 @@ class TfliteImagenetEvaluatorTest(tf.test.TestCase):
autospec=True): autospec=True):
evaluator = tflite_imagenet_evaluator.AccuracyEvaluator( evaluator = tflite_imagenet_evaluator.AccuracyEvaluator(
model_content='MockModelContent'.encode('utf-8'), model_content='MockModelContent'.encode('utf-8'),
dataset=dataset.take(num_batches), dataset=dataset,
num_threads=num_threads) num_threads=num_threads)
num_evals, num_corrects = evaluator.evaluate_all() num_evals, num_corrects = evaluator.evaluate_all()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment