Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
6cd536dc
Commit
6cd536dc
authored
Oct 25, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 405448944
parent
88eac22a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
11 deletions
+6
-11
official/projects/edgetpu/vision/serving/tflite_imagenet_evaluator_test.py
.../edgetpu/vision/serving/tflite_imagenet_evaluator_test.py
+6
-11
No files found.
official/projects/edgetpu/vision/serving/tflite_imagenet_evaluator_test.py
View file @
6cd536dc
...
@@ -17,9 +17,7 @@
...
@@ -17,9 +17,7 @@
from
unittest
import
mock
from
unittest
import
mock
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.core
import
exp_factory
from
official.projects.edgetpu.vision.serving
import
tflite_imagenet_evaluator
from
official.projects.edgetpu.vision.serving
import
tflite_imagenet_evaluator
from
official.projects.edgetpu.vision.tasks
import
image_classification
class
TfliteImagenetEvaluatorTest
(
tf
.
test
.
TestCase
):
class
TfliteImagenetEvaluatorTest
(
tf
.
test
.
TestCase
):
...
@@ -28,16 +26,13 @@ class TfliteImagenetEvaluatorTest(tf.test.TestCase):
...
@@ -28,16 +26,13 @@ class TfliteImagenetEvaluatorTest(tf.test.TestCase):
def
test_evaluate_all
(
self
):
def
test_evaluate_all
(
self
):
batch_size
=
8
batch_size
=
8
num_threads
=
4
num_threads
=
4
global_batch_size
=
num_threads
*
batch_size
num_batches
=
5
config
=
exp_factory
.
get_exp_config
(
'mobilenet_edgetpu_v2_xs'
)
config
.
task
.
validation_data
.
global_batch_size
=
global_batch_size
config
.
task
.
validation_data
.
dtype
=
'float32'
task
=
image_classification
.
EdgeTPUTask
(
config
.
task
)
labels
=
tf
.
data
.
Dataset
.
range
(
batch_size
*
num_threads
*
num_batches
)
dataset
=
task
.
build_inputs
(
config
.
task
.
validation_data
)
images
=
tf
.
data
.
Dataset
.
range
(
batch_size
*
num_threads
*
num_batches
)
dataset
=
tf
.
data
.
Dataset
.
zip
((
images
,
labels
))
dataset
=
dataset
.
batch
(
batch_size
)
num_batches
=
5
with
mock
.
patch
.
object
(
with
mock
.
patch
.
object
(
tflite_imagenet_evaluator
.
AccuracyEvaluator
,
tflite_imagenet_evaluator
.
AccuracyEvaluator
,
'evaluate_single_image'
,
'evaluate_single_image'
,
...
@@ -45,7 +40,7 @@ class TfliteImagenetEvaluatorTest(tf.test.TestCase):
...
@@ -45,7 +40,7 @@ class TfliteImagenetEvaluatorTest(tf.test.TestCase):
autospec
=
True
):
autospec
=
True
):
evaluator
=
tflite_imagenet_evaluator
.
AccuracyEvaluator
(
evaluator
=
tflite_imagenet_evaluator
.
AccuracyEvaluator
(
model_content
=
'MockModelContent'
.
encode
(
'utf-8'
),
model_content
=
'MockModelContent'
.
encode
(
'utf-8'
),
dataset
=
dataset
.
take
(
num_batches
)
,
dataset
=
dataset
,
num_threads
=
num_threads
)
num_threads
=
num_threads
)
num_evals
,
num_corrects
=
evaluator
.
evaluate_all
()
num_evals
,
num_corrects
=
evaluator
.
evaluate_all
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment