"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "d88972ea48cfec20ebba6e0a86a825fca3ecb193"
Commit 2d40f27a authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Adds a TFLite classification accuracy evaluator tool.

PiperOrigin-RevId: 404581314
parent 6ff62233
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluates image classification accuracy using TFLite Interpreter."""
import dataclasses
import multiprocessing.pool as mp
from typing import Tuple
from absl import logging
import numpy as np
import tensorflow as tf
@dataclasses.dataclass
class EvaluationInput():
  """Contains image and its label as evaluation input."""
  # Preprocessed input image tensor; fed to the TFLite interpreter's input.
  image: tf.Tensor
  # Ground-truth class id tensor; compared against the argmax of the logits.
  label: tf.Tensor
class AccuracyEvaluator():
  """Evaluates image classification accuracy using TFLite Interpreter.

  Each worker thread instantiates its own single-threaded TFLite interpreter
  so that concurrent evaluations do not share interpreter state.

  Attributes:
    model_content: The contents of a TFLite model.
    dataset: Dataset yielding (image_batch, label_batch) pairs to evaluate.
    num_threads: Number of threads used to evaluate images.
  """

  def __init__(self,
               model_content: bytes,
               dataset: tf.data.Dataset,
               num_threads: int = 16):
    self._model_content: bytes = model_content
    self._dataset = dataset
    self._num_threads: int = num_threads

  def evaluate_single_image(self, eval_input: EvaluationInput) -> bool:
    """Evaluates a given single input.

    Args:
      eval_input: EvaluationInput holding image and label.

    Returns:
      Whether the estimation is correct.
    """
    # A fresh interpreter per call: tf.lite.Interpreter is not thread-safe,
    # and this method runs concurrently from a thread pool.
    interpreter = tf.lite.Interpreter(
        model_content=self._model_content, num_threads=1)
    interpreter.allocate_tensors()

    # Get input and output tensors and quantization details.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    image_tensor = interpreter.tensor(input_details[0]['index'])
    logits_tensor = interpreter.tensor(output_details[0]['index'])

    # Handle quantization: float images must be mapped to the integer input
    # domain via q = image / scale + zero_point.
    scale = 1.0
    zero_point = 0.0
    input_dtype = tf.as_dtype(input_details[0]['dtype'])
    if input_dtype.is_quantized or input_dtype.is_integer:
      scale, zero_point = input_details[0]['quantization']
      if not scale:
        # TFLite reports (0.0, 0) when no quantization parameters are set;
        # fall back to identity scaling to avoid a division by zero.
        scale = 1.0

    image_tensor()[0, :] = (eval_input.image.numpy() / scale) + zero_point
    interpreter.invoke()
    return eval_input.label.numpy() == np.argmax(logits_tensor()[0])

  def evaluate_all(self) -> Tuple[int, int]:
    """Evaluates all of images in the default dataset.

    Returns:
      Total number of evaluations and correct predictions as tuple of ints.
    """
    num_evals = 0
    num_corrects = 0
    # One pool reused across batches; the `with` block guarantees worker
    # threads are torn down even if dataset iteration raises.
    with mp.ThreadPool(self._num_threads) as pool:
      for image_batch, label_batch in self._dataset:
        inputs = [
            EvaluationInput(image, label)
            for image, label in zip(image_batch, label_batch)
        ]
        results = pool.map(self.evaluate_single_image, inputs)
        num_evals += len(results)
        num_corrects += results.count(True)
        accuracy = 100.0 * num_corrects / num_evals if num_evals > 0 else 0
        logging.info('Evaluated: %d, Correct: %d, Accuracy: %f', num_evals,
                     num_corrects, accuracy)
    return (num_evals, num_corrects)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Evaluates image classification accuracy using tflite_imagenet_evaluator.
Usage:
tflite_imagenet_evaluator_run --tflite_model_path=/PATH/TO/MODEL.tflite
"""
from typing import Sequence
from absl import app
from absl import flags
import tensorflow as tf
from official.core import exp_factory
from official.projects.edgetpu.vision.serving import tflite_imagenet_evaluator
from official.projects.edgetpu.vision.tasks import image_classification
# Command-line flags. `tflite_model_path` is required (enforced in __main__).
flags.DEFINE_string('tflite_model_path', None,
                    'Path to the tflite file to be evaluated.')
flags.DEFINE_integer('num_threads', 16, 'Number of local threads.')
flags.DEFINE_integer('batch_size', 256, 'Batch size per thread.')
flags.DEFINE_string(
    'model_name', 'mobilenet_edgetpu_v2_xs',
    'Model name to identify a registered data pipeline setup and use as the '
    'validation dataset.')

FLAGS = flags.FLAGS
def main(argv: Sequence[str]):
  """Loads the TFLite model, runs full-dataset evaluation, prints accuracy."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  with tf.io.gfile.GFile(FLAGS.tflite_model_path, 'rb') as model_file:
    tflite_content = model_file.read()

  # The validation pipeline is configured from the registered experiment;
  # the global batch covers all worker threads at once.
  experiment_config = exp_factory.get_exp_config(FLAGS.model_name)
  experiment_config.task.validation_data.global_batch_size = (
      FLAGS.num_threads * FLAGS.batch_size)
  experiment_config.task.validation_data.dtype = 'float32'

  classification_task = image_classification.EdgeTPUTask(
      experiment_config.task)
  validation_dataset = classification_task.build_inputs(
      experiment_config.task.validation_data)

  evaluator = tflite_imagenet_evaluator.AccuracyEvaluator(
      model_content=tflite_content,
      dataset=validation_dataset,
      num_threads=FLAGS.num_threads)

  evals, corrects = evaluator.evaluate_all()
  accuracy = 100.0 * corrects / evals if evals > 0 else 0
  print('Final accuracy: {}, Evaluated: {}, Correct: {} '.format(
      accuracy, evals, corrects))
if __name__ == '__main__':
  # The model path has no usable default, so require it before parsing.
  flags.mark_flag_as_required('tflite_model_path')
  app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tflite_imagenet_evaluator."""
from unittest import mock
import tensorflow as tf
from official.core import exp_factory
from official.projects.edgetpu.vision.serving import tflite_imagenet_evaluator
from official.projects.edgetpu.vision.tasks import image_classification
class TfliteImagenetEvaluatorTest(tf.test.TestCase):
  # Exercises only the fan-out/aggregation logic of evaluate_all; the
  # per-image evaluation and model content are mocked out.

  def test_evaluate_all(self):
    per_thread_batch = 8
    worker_count = 4
    batch_count = 5

    # Build a real validation pipeline from the registered experiment config.
    config = exp_factory.get_exp_config('mobilenet_edgetpu_v2_xs')
    config.task.validation_data.global_batch_size = (
        worker_count * per_thread_batch)
    config.task.validation_data.dtype = 'float32'
    dataset = image_classification.EdgeTPUTask(config.task).build_inputs(
        config.task.validation_data)

    # Every mocked evaluation reports a correct prediction.
    with mock.patch.object(
        tflite_imagenet_evaluator.AccuracyEvaluator,
        'evaluate_single_image',
        return_value=True,
        autospec=True):
      evaluator = tflite_imagenet_evaluator.AccuracyEvaluator(
          model_content='MockModelContent'.encode('utf-8'),
          dataset=dataset.take(batch_count),
          num_threads=worker_count)
      num_evals, num_corrects = evaluator.evaluate_all()

    expected_evals = batch_count * worker_count * per_thread_batch
    self.assertEqual(num_evals, expected_evals)
    self.assertEqual(num_corrects, expected_evals)
if __name__ == '__main__':
  # Discovers and runs the TestCase methods defined above.
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment