Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
2d40f27a
Commit
2d40f27a
authored
Oct 20, 2021
by
A. Unique TensorFlower
Browse files
Adds a TFLite classification accuracy evaluator tool.
PiperOrigin-RevId: 404581314
parent
6ff62233
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
235 additions
and
0 deletions
+235
-0
official/projects/edgetpu/vision/serving/tflite_imagenet_evaluator.py
...jects/edgetpu/vision/serving/tflite_imagenet_evaluator.py
+105
-0
official/projects/edgetpu/vision/serving/tflite_imagenet_evaluator_run.py
...s/edgetpu/vision/serving/tflite_imagenet_evaluator_run.py
+71
-0
official/projects/edgetpu/vision/serving/tflite_imagenet_evaluator_test.py
.../edgetpu/vision/serving/tflite_imagenet_evaluator_test.py
+59
-0
No files found.
official/projects/edgetpu/vision/serving/tflite_imagenet_evaluator.py
0 → 100644
View file @
2d40f27a
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluates image classification accuracy using TFLite Interpreter."""
import
dataclasses
import
multiprocessing.pool
as
mp
from
typing
import
Tuple
from
absl
import
logging
import
numpy
as
np
import
tensorflow
as
tf
@
dataclasses
.
dataclass
class
EvaluationInput
():
"""Contains image and its label as evaluation input."""
image
:
tf
.
Tensor
label
:
tf
.
Tensor
class
AccuracyEvaluator
():
"""Evaluates image classification accuracy using TFLite Interpreter.
Attributes:
model_content: The contents of a TFLite model.
num_threads: Number of threads used to evaluate images.
thread_batch_size: Batch size assigned to each thread.
image_size: Width/Height of the images.
num_classes: Number of classes predicted by the model.
resize_method: Resize method to use during image preprocessing.
"""
def
__init__
(
self
,
model_content
:
bytes
,
dataset
:
tf
.
data
.
Dataset
,
num_threads
:
int
=
16
):
self
.
_model_content
:
bytes
=
model_content
self
.
_dataset
=
dataset
self
.
_num_threads
:
int
=
num_threads
def
evaluate_single_image
(
self
,
eval_input
:
EvaluationInput
)
->
bool
:
"""Evaluates a given single input.
Args:
eval_input: EvaluationInput holding image and label.
Returns:
Whether the estimation is correct.
"""
interpreter
=
tf
.
lite
.
Interpreter
(
model_content
=
self
.
_model_content
,
num_threads
=
1
)
interpreter
.
allocate_tensors
()
# Get input and output tensors and quantization details.
input_details
=
interpreter
.
get_input_details
()
output_details
=
interpreter
.
get_output_details
()
image_tensor
=
interpreter
.
tensor
(
input_details
[
0
][
'index'
])
logits_tensor
=
interpreter
.
tensor
(
output_details
[
0
][
'index'
])
# Handle quantization.
scale
=
1.0
zero_point
=
0.0
input_dtype
=
tf
.
as_dtype
(
input_details
[
0
][
'dtype'
])
if
input_dtype
.
is_quantized
or
input_dtype
.
is_integer
:
input_quantization
=
input_details
[
0
][
'quantization'
]
scale
=
input_quantization
[
0
]
zero_point
=
input_quantization
[
1
]
image_tensor
()[
0
,
:]
=
(
eval_input
.
image
.
numpy
()
/
scale
)
+
zero_point
interpreter
.
invoke
()
return
eval_input
.
label
.
numpy
()
==
np
.
argmax
(
logits_tensor
()[
0
])
def
evaluate_all
(
self
)
->
Tuple
[
int
,
int
]:
"""Evaluates all of images in the default dataset.
Returns:
Total number of evaluations and correct predictions as tuple of ints.
"""
num_evals
=
0
num_corrects
=
0
for
image_batch
,
label_batch
in
self
.
_dataset
:
inputs
=
[
EvaluationInput
(
image
,
label
)
for
image
,
label
in
zip
(
image_batch
,
label_batch
)
]
pool
=
mp
.
ThreadPool
(
self
.
_num_threads
)
results
=
pool
.
map
(
self
.
evaluate_single_image
,
inputs
)
pool
.
close
()
pool
.
join
()
num_evals
+=
len
(
results
)
num_corrects
+=
results
.
count
(
True
)
accuracy
=
100.0
*
num_corrects
/
num_evals
if
num_evals
>
0
else
0
logging
.
info
(
'Evaluated: %d, Correct: %d, Accuracy: %f'
,
num_evals
,
num_corrects
,
accuracy
)
return
(
num_evals
,
num_corrects
)
official/projects/edgetpu/vision/serving/tflite_imagenet_evaluator_run.py
0 → 100644
View file @
2d40f27a
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r
"""Evaluates image classification accuracy using tflite_imagenet_evaluator.
Usage:
tflite_imagenet_evaluator_run --tflite_model_path=/PATH/TO/MODEL.tflite
"""
from
typing
import
Sequence
from
absl
import
app
from
absl
import
flags
import
tensorflow
as
tf
from
official.core
import
exp_factory
from
official.projects.edgetpu.vision.serving
import
tflite_imagenet_evaluator
from
official.projects.edgetpu.vision.tasks
import
image_classification
flags
.
DEFINE_string
(
'tflite_model_path'
,
None
,
'Path to the tflite file to be evaluated.'
)
flags
.
DEFINE_integer
(
'num_threads'
,
16
,
'Number of local threads.'
)
flags
.
DEFINE_integer
(
'batch_size'
,
256
,
'Batch size per thread.'
)
flags
.
DEFINE_string
(
'model_name'
,
'mobilenet_edgetpu_v2_xs'
,
'Model name to identify a registered data pipeline setup and use as the '
'validation dataset.'
)
FLAGS
=
flags
.
FLAGS
def
main
(
argv
:
Sequence
[
str
]):
if
len
(
argv
)
>
1
:
raise
app
.
UsageError
(
'Too many command-line arguments.'
)
with
tf
.
io
.
gfile
.
GFile
(
FLAGS
.
tflite_model_path
,
'rb'
)
as
f
:
model_content
=
f
.
read
()
config
=
exp_factory
.
get_exp_config
(
FLAGS
.
model_name
)
global_batch_size
=
FLAGS
.
num_threads
*
FLAGS
.
batch_size
config
.
task
.
validation_data
.
global_batch_size
=
global_batch_size
config
.
task
.
validation_data
.
dtype
=
'float32'
task
=
image_classification
.
EdgeTPUTask
(
config
.
task
)
dataset
=
task
.
build_inputs
(
config
.
task
.
validation_data
)
evaluator
=
tflite_imagenet_evaluator
.
AccuracyEvaluator
(
model_content
=
model_content
,
dataset
=
dataset
,
num_threads
=
FLAGS
.
num_threads
)
evals
,
corrects
=
evaluator
.
evaluate_all
()
accuracy
=
100.0
*
corrects
/
evals
if
evals
>
0
else
0
print
(
'Final accuracy: {}, Evaluated: {}, Correct: {} '
.
format
(
accuracy
,
evals
,
corrects
))
if
__name__
==
'__main__'
:
flags
.
mark_flag_as_required
(
'tflite_model_path'
)
app
.
run
(
main
)
official/projects/edgetpu/vision/serving/tflite_imagenet_evaluator_test.py
0 → 100644
View file @
2d40f27a
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tflite_imagenet_evaluator."""
from
unittest
import
mock
import
tensorflow
as
tf
from
official.core
import
exp_factory
from
official.projects.edgetpu.vision.serving
import
tflite_imagenet_evaluator
from
official.projects.edgetpu.vision.tasks
import
image_classification
class
TfliteImagenetEvaluatorTest
(
tf
.
test
.
TestCase
):
# Only tests the parallelization aspect. Mocks image evaluation and dataset.
def
test_evaluate_all
(
self
):
batch_size
=
8
num_threads
=
4
global_batch_size
=
num_threads
*
batch_size
config
=
exp_factory
.
get_exp_config
(
'mobilenet_edgetpu_v2_xs'
)
config
.
task
.
validation_data
.
global_batch_size
=
global_batch_size
config
.
task
.
validation_data
.
dtype
=
'float32'
task
=
image_classification
.
EdgeTPUTask
(
config
.
task
)
dataset
=
task
.
build_inputs
(
config
.
task
.
validation_data
)
num_batches
=
5
with
mock
.
patch
.
object
(
tflite_imagenet_evaluator
.
AccuracyEvaluator
,
'evaluate_single_image'
,
return_value
=
True
,
autospec
=
True
):
evaluator
=
tflite_imagenet_evaluator
.
AccuracyEvaluator
(
model_content
=
'MockModelContent'
.
encode
(
'utf-8'
),
dataset
=
dataset
.
take
(
num_batches
),
num_threads
=
num_threads
)
num_evals
,
num_corrects
=
evaluator
.
evaluate_all
()
expected_evals
=
num_batches
*
num_threads
*
batch_size
self
.
assertEqual
(
num_evals
,
expected_evals
)
self
.
assertEqual
(
num_corrects
,
expected_evals
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment