"src/torio/_extension/utils.py" did not exist on "c38229d447f55532a1bdffdbbe7832d16cd0de5e"
Commit 1bfe1df1 authored by Mark Sandler's avatar Mark Sandler Committed by Sergio Guadarrama
Browse files

PNasNet (#3736)

* PiperOrigin-RevId: 189857068

* PiperOrigin-RevId: 190089200

* Merge pull request #3702 from cclauss/from-six.moves-import-xrange-yet-again

from six.moves import xrange (en masse) YET AGAIN

PiperOrigin-RevId: 190255581

* I Fixes bunch of model tests that were using python2 functions.

II Updates mobilenet code:
1) Mobilenet usage example
2) Links to all checkpoints and updated README
3) Performance graphs

PiperOrigin-RevId: 190300379

* PiperOrigin-RevId: 190306214

* Updates notebook to reflect canonical repository location and fixes few
variable names.
parent 932364b6
...@@ -259,7 +259,8 @@ Model | TF-Slim File | Checkpoint | Top-1 Accuracy| Top-5 Accuracy | ...@@ -259,7 +259,8 @@ Model | TF-Slim File | Checkpoint | Top-1 Accuracy| Top-5 Accuracy |
[MobileNet_v1_1.0_224](https://arxiv.org/pdf/1704.04861.pdf)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py)|[mobilenet_v1_1.0_224.tgz](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz)|70.9|89.9| [MobileNet_v1_1.0_224](https://arxiv.org/pdf/1704.04861.pdf)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py)|[mobilenet_v1_1.0_224.tgz](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz)|70.9|89.9|
[MobileNet_v1_0.50_160](https://arxiv.org/pdf/1704.04861.pdf)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py)|[mobilenet_v1_0.50_160.tgz](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz)|59.1|81.9| [MobileNet_v1_0.50_160](https://arxiv.org/pdf/1704.04861.pdf)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py)|[mobilenet_v1_0.50_160.tgz](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz)|59.1|81.9|
[MobileNet_v1_0.25_128](https://arxiv.org/pdf/1704.04861.pdf)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py)|[mobilenet_v1_0.25_128.tgz](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz)|41.5|66.3| [MobileNet_v1_0.25_128](https://arxiv.org/pdf/1704.04861.pdf)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py)|[mobilenet_v1_0.25_128.tgz](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz)|41.5|66.3|
[MobileNet_v2_1.0_224^*](https://arxiv.org/abs/1801.04381)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py)|[Checkpoint TBA]()|72.2|91.0| [MobileNet_v2_1.4_224^*](https://arxiv.org/abs/1801.04381)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py)| [mobilenet_v2_1.4_224.tgz](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz) | 74.9 | 92.5|
[MobileNet_v2_1.0_224^*](https://arxiv.org/abs/1801.04381)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py)| [mobilenet_v2_1.0_224.tgz](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz) | 71.9 | 91.0
[NASNet-A_Mobile_224](https://arxiv.org/abs/1707.07012)#|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/nasnet/nasnet.py)|[nasnet-a_mobile_04_10_2017.tar.gz](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz)|74.0|91.6| [NASNet-A_Mobile_224](https://arxiv.org/abs/1707.07012)#|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/nasnet/nasnet.py)|[nasnet-a_mobile_04_10_2017.tar.gz](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz)|74.0|91.6|
[NASNet-A_Large_331](https://arxiv.org/abs/1707.07012)#|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/nasnet/nasnet.py)|[nasnet-a_large_04_10_2017.tar.gz](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz)|82.7|96.2| [NASNet-A_Large_331](https://arxiv.org/abs/1707.07012)#|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/nasnet/nasnet.py)|[nasnet-a_large_04_10_2017.tar.gz](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz)|82.7|96.2|
[PNASNet-5_Large_331](https://arxiv.org/abs/1712.00559)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/nasnet/pnasnet.py)|[pnasnet-5_large_2017_12_13.tar.gz](https://storage.googleapis.com/download.tensorflow.org/models/pnasnet-5_large_2017_12_13.tar.gz)|82.9|96.2| [PNASNet-5_Large_331](https://arxiv.org/abs/1712.00559)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/nasnet/pnasnet.py)|[pnasnet-5_large_2017_12_13.tar.gz](https://storage.googleapis.com/download.tensorflow.org/models/pnasnet-5_large_2017_12_13.tar.gz)|82.9|96.2|
...@@ -275,7 +276,7 @@ All 16 float MobileNet V1 models reported in the [MobileNet Paper](https://arxiv ...@@ -275,7 +276,7 @@ All 16 float MobileNet V1 models reported in the [MobileNet Paper](https://arxiv
16 quantized [TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/) compatible MobileNet V1 models can be found 16 quantized [TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/) compatible MobileNet V1 models can be found
[here](https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet_v1.md). [here](https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet_v1.md).
(^#) More details on Mobilenet V2 models can be found [here](nets/mobilenet/README.md). (^#) More details on MobileNetV2 models can be found [here](nets/mobilenet/README.md).
(\*): Results quoted from the [paper](https://arxiv.org/abs/1603.05027). (\*): Results quoted from the [paper](https://arxiv.org/abs/1603.05027).
......
...@@ -93,9 +93,8 @@ import sys ...@@ -93,9 +93,8 @@ import sys
import threading import threading
import numpy as np import numpy as np
from six.moves import xrange from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf import tensorflow as tf
from six.moves import xrange
tf.app.flags.DEFINE_string('train_directory', '/tmp/', tf.app.flags.DEFINE_string('train_directory', '/tmp/',
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Process the ImageNet Challenge bounding boxes for TensorFlow model training. r"""Process the ImageNet Challenge bounding boxes for TensorFlow model training.
Associate the ImageNet 2012 Challenge validation data set with labels. Associate the ImageNet 2012 Challenge validation data set with labels.
...@@ -51,7 +51,7 @@ from __future__ import print_function ...@@ -51,7 +51,7 @@ from __future__ import print_function
import os import os
import sys import sys
from six.moves import xrange from six.moves import xrange # pylint: disable=redefined-builtin
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -85,9 +85,7 @@ import glob ...@@ -85,9 +85,7 @@ import glob
import os.path import os.path
import sys import sys
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from six.moves import xrange from six.moves import xrange # pylint: disable=redefined-builtin
from six.moves import xrange
class BoundingBox(object): class BoundingBox(object):
......
...@@ -230,10 +230,11 @@ def _gather_clone_loss(clone, num_clones, regularization_losses): ...@@ -230,10 +230,11 @@ def _gather_clone_loss(clone, num_clones, regularization_losses):
sum_loss = tf.add_n(all_losses) sum_loss = tf.add_n(all_losses)
# Add the summaries out of the clone device block. # Add the summaries out of the clone device block.
if clone_loss is not None: if clone_loss is not None:
tf.summary.scalar(clone.scope + '/clone_loss', clone_loss, family='Losses') tf.summary.scalar('/'.join(filter(None,
['Losses', clone.scope, 'clone_loss'])),
clone_loss)
if regularization_loss is not None: if regularization_loss is not None:
tf.summary.scalar('regularization_loss', regularization_loss, tf.summary.scalar('Losses/regularization_loss', regularization_loss)
family='Losses')
return sum_loss return sum_loss
......
...@@ -18,9 +18,8 @@ from __future__ import division ...@@ -18,9 +18,8 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy as np
from six.moves import xrange from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf import tensorflow as tf
from six.moves import xrange
layers = tf.contrib.layers layers = tf.contrib.layers
......
...@@ -19,10 +19,9 @@ from __future__ import print_function ...@@ -19,10 +19,9 @@ from __future__ import print_function
from math import log from math import log
from six.moves import xrange from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf import tensorflow as tf
from six.moves import xrange
slim = tf.contrib.slim slim = tf.contrib.slim
......
...@@ -18,9 +18,9 @@ from __future__ import absolute_import ...@@ -18,9 +18,9 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from six.moves import xrange from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf import tensorflow as tf
from six.moves import xrange
from nets import dcgan from nets import dcgan
......
...@@ -127,7 +127,7 @@ class InceptionTest(tf.test.TestCase): ...@@ -127,7 +127,7 @@ class InceptionTest(tf.test.TestCase):
'Mixed_6e', 'Mixed_6f', 'Mixed_6g', 'Mixed_6h', 'Mixed_7a', 'Mixed_6e', 'Mixed_6f', 'Mixed_6g', 'Mixed_6h', 'Mixed_7a',
'Mixed_7b', 'Mixed_7c', 'Mixed_7d'] 'Mixed_7b', 'Mixed_7c', 'Mixed_7d']
self.assertItemsEqual(end_points.keys(), expected_endpoints) self.assertItemsEqual(end_points.keys(), expected_endpoints)
for name, op in end_points.iteritems(): for name, op in end_points.items():
self.assertTrue(op.name.startswith('InceptionV4/' + name)) self.assertTrue(op.name.startswith('InceptionV4/' + name))
def testBuildOnlyUpToFinalEndpoint(self): def testBuildOnlyUpToFinalEndpoint(self):
......
# Mobilenet V2 # MobileNetV2
This folder contains building code for Mobilenet V2, based on This folder contains building code for MobileNetV2, based on
[Inverted Residuals and Linear Bottlenecks: Mobile Networks for Classification, Detection and Segmentation] [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381)
(https://arxiv.org/abs/1801.04381)
# Pretrained model # Performance
TODO ## Latency
This is the timing of [MobileNetV1](../mobilenet_v1.md) vs MobileNetV2 using
TF-Lite on the large core of Pixel 1 phone.
![mnet_v1_vs_v2_pixel1_latency.png](mnet_v1_vs_v2_pixel1_latency.png)
## MACs
MACs, also sometimes known as MADDs - the number of multiply-accumulates needed
to compute an inference on a single image is a common metric to measure the efficiency of the model.
Below is the graph comparing V2 vs a few selected networks. The size
of each blob represents the number of parameters. Note for [ShuffleNet](https://arxiv.org/abs/1707.01083) there
are no published size numbers. We estimate it to be comparable to MobileNetV2 numbers.
![madds_top1_accuracy](madds_top1_accuracy.png)
# Pretrained models
## Imagenet Checkpoints
Classification Checkpoint | MACs (M)| Parameters (M)| Top 1 Accuracy| Top 5 Accuracy | Mobile CPU (ms) Pixel 1
---------------------------|---------|---------------|---------|----|-------------
| [mobilenet_v2_1.4_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz) | 582 | 6.06 | 75.0 | 92.5 | 138.0
| [mobilenet_v2_1.3_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.3_224.tgz) | 509 | 5.34 | 74.4 | 92.1 | 123.0
| [mobilenet_v2_1.0_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz) | 300 | 3.47 | 71.8 | 91.0 | 73.8
| [mobilenet_v2_1.0_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_192.tgz) | 221 | 3.47 | 70.7 | 90.1 | 55.1
| [mobilenet_v2_1.0_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_160.tgz) | 154 | 3.47 | 68.8 | 89.0 | 40.2
| [mobilenet_v2_1.0_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_128.tgz) | 99 | 3.47 | 65.3 | 86.9 | 27.6
| [mobilenet_v2_1.0_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_96.tgz) | 56 | 3.47 | 60.3 | 83.2 | 17.6
| [mobilenet_v2_0.75_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_224.tgz) | 209 | 2.61 | 69.8 | 89.6 | 55.8
| [mobilenet_v2_0.75_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_192.tgz) | 153 | 2.61 | 68.7 | 88.9 | 41.6
| [mobilenet_v2_0.75_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_160.tgz) | 107 | 2.61 | 66.4 | 87.3 | 30.4
| [mobilenet_v2_0.75_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_128.tgz) | 69 | 2.61 | 63.2 | 85.3 | 21.9
| [mobilenet_v2_0.75_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_96.tgz) | 39 | 2.61 | 58.8 | 81.6 | 14.2
| [mobilenet_v2_0.5_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_224.tgz) | 97 | 1.95 | 65.4 | 86.4 | 28.7
| [mobilenet_v2_0.5_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_192.tgz) | 71 | 1.95 | 63.9 | 85.4 | 21.1
| [mobilenet_v2_0.5_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_160.tgz) | 50 | 1.95 | 61.0 | 83.2 | 14.9
| [mobilenet_v2_0.5_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_128.tgz) | 32 | 1.95 | 57.7 | 80.8 | 9.9
| [mobilenet_v2_0.5_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_96.tgz) | 18 | 1.95 | 51.2 | 75.8 | 6.4
| [mobilenet_v2_0.35_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_224.tgz) | 59 | 1.66 | 60.3 | 82.9 | 19.7
| [mobilenet_v2_0.35_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_192.tgz) | 43 | 1.66 | 58.2 | 81.2 | 14.6
| [mobilenet_v2_0.35_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_160.tgz) | 30 | 1.66 | 55.7 | 79.1 | 10.5
| [mobilenet_v2_0.35_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_128.tgz) | 20 | 1.66 | 50.8 | 75.0 | 6.9
| [mobilenet_v2_0.35_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_96.tgz) | 11 | 1.66 | 45.5 | 70.4 | 4.5
# Example # Example
TODO
See this [ipython notebook](mobilenet_example.ipynb) or open and run the network directly in [Colaboratory](https://colab.research.google.com/github/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_example.ipynb).
...@@ -81,7 +81,7 @@ def _set_arg_scope_defaults(defaults): ...@@ -81,7 +81,7 @@ def _set_arg_scope_defaults(defaults):
context manager where all defaults are set. context manager where all defaults are set.
""" """
if hasattr(defaults, 'items'): if hasattr(defaults, 'items'):
items = defaults.items() items = list(defaults.items())
else: else:
items = defaults items = defaults
if not items: if not items:
......
This diff is collapsed.
...@@ -117,7 +117,7 @@ def mobilenet(input_tensor, ...@@ -117,7 +117,7 @@ def mobilenet(input_tensor,
divisible_by: If provided will ensure that all layers # channels divisible_by: If provided will ensure that all layers # channels
will be divisible by this number. will be divisible by this number.
**kwargs: passed directly to mobilenet.mobilenet: **kwargs: passed directly to mobilenet.mobilenet:
prediciton_fn- what prediction function to use. prediction_fn- what prediction function to use.
reuse-: whether to reuse variables (if reuse set to true, scope reuse-: whether to reuse variables (if reuse set to true, scope
must be given). must be given).
Returns: Returns:
......
# Mobilenet_v2
For Mobilenet V2 see this file [mobilenet/README.md]
# MobileNet_v1 # MobileNet_v1
[MobileNets](https://arxiv.org/abs/1704.04861) are small, low-latency, low-power models parameterized to meet the resource constraints of a variety of use cases. They can be built upon for classification, detection, embeddings and segmentation similar to how other popular large scale models, such as Inception, are used. MobileNets can be run efficiently on mobile devices with [TensorFlow Mobile](https://www.tensorflow.org/mobile/). [MobileNets](https://arxiv.org/abs/1704.04861) are small, low-latency, low-power models parameterized to meet the resource constraints of a variety of use cases. They can be built upon for classification, detection, embeddings and segmentation similar to how other popular large scale models, such as Inception, are used. MobileNets can be run efficiently on mobile devices with [TensorFlow Mobile](https://www.tensorflow.org/mobile/).
......
...@@ -168,13 +168,13 @@ class MobilenetV1Test(tf.test.TestCase): ...@@ -168,13 +168,13 @@ class MobilenetV1Test(tf.test.TestCase):
'Conv2d_13_depthwise': [batch_size, 7, 7, 1024], 'Conv2d_13_depthwise': [batch_size, 7, 7, 1024],
'Conv2d_13_pointwise': [batch_size, 7, 7, 1024]} 'Conv2d_13_pointwise': [batch_size, 7, 7, 1024]}
self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys()) self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
for endpoint_name, expected_shape in endpoints_shapes.iteritems(): for endpoint_name, expected_shape in endpoints_shapes.items():
self.assertTrue(endpoint_name in end_points) self.assertTrue(endpoint_name in end_points)
self.assertListEqual(end_points[endpoint_name].get_shape().as_list(), self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
expected_shape) expected_shape)
self.assertItemsEqual(endpoints_shapes.keys(), self.assertItemsEqual(endpoints_shapes.keys(),
explicit_padding_end_points.keys()) explicit_padding_end_points.keys())
for endpoint_name, expected_shape in endpoints_shapes.iteritems(): for endpoint_name, expected_shape in endpoints_shapes.items():
self.assertTrue(endpoint_name in explicit_padding_end_points) self.assertTrue(endpoint_name in explicit_padding_end_points)
self.assertListEqual( self.assertListEqual(
explicit_padding_end_points[endpoint_name].get_shape().as_list(), explicit_padding_end_points[endpoint_name].get_shape().as_list(),
...@@ -222,13 +222,13 @@ class MobilenetV1Test(tf.test.TestCase): ...@@ -222,13 +222,13 @@ class MobilenetV1Test(tf.test.TestCase):
'Conv2d_13_depthwise': [batch_size, 14, 14, 1024], 'Conv2d_13_depthwise': [batch_size, 14, 14, 1024],
'Conv2d_13_pointwise': [batch_size, 14, 14, 1024]} 'Conv2d_13_pointwise': [batch_size, 14, 14, 1024]}
self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys()) self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
for endpoint_name, expected_shape in endpoints_shapes.iteritems(): for endpoint_name, expected_shape in endpoints_shapes.items():
self.assertTrue(endpoint_name in end_points) self.assertTrue(endpoint_name in end_points)
self.assertListEqual(end_points[endpoint_name].get_shape().as_list(), self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
expected_shape) expected_shape)
self.assertItemsEqual(endpoints_shapes.keys(), self.assertItemsEqual(endpoints_shapes.keys(),
explicit_padding_end_points.keys()) explicit_padding_end_points.keys())
for endpoint_name, expected_shape in endpoints_shapes.iteritems(): for endpoint_name, expected_shape in endpoints_shapes.items():
self.assertTrue(endpoint_name in explicit_padding_end_points) self.assertTrue(endpoint_name in explicit_padding_end_points)
self.assertListEqual( self.assertListEqual(
explicit_padding_end_points[endpoint_name].get_shape().as_list(), explicit_padding_end_points[endpoint_name].get_shape().as_list(),
...@@ -276,13 +276,13 @@ class MobilenetV1Test(tf.test.TestCase): ...@@ -276,13 +276,13 @@ class MobilenetV1Test(tf.test.TestCase):
'Conv2d_13_depthwise': [batch_size, 28, 28, 1024], 'Conv2d_13_depthwise': [batch_size, 28, 28, 1024],
'Conv2d_13_pointwise': [batch_size, 28, 28, 1024]} 'Conv2d_13_pointwise': [batch_size, 28, 28, 1024]}
self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys()) self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
for endpoint_name, expected_shape in endpoints_shapes.iteritems(): for endpoint_name, expected_shape in endpoints_shapes.items():
self.assertTrue(endpoint_name in end_points) self.assertTrue(endpoint_name in end_points)
self.assertListEqual(end_points[endpoint_name].get_shape().as_list(), self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
expected_shape) expected_shape)
self.assertItemsEqual(endpoints_shapes.keys(), self.assertItemsEqual(endpoints_shapes.keys(),
explicit_padding_end_points.keys()) explicit_padding_end_points.keys())
for endpoint_name, expected_shape in endpoints_shapes.iteritems(): for endpoint_name, expected_shape in endpoints_shapes.items():
self.assertTrue(endpoint_name in explicit_padding_end_points) self.assertTrue(endpoint_name in explicit_padding_end_points)
self.assertListEqual( self.assertListEqual(
explicit_padding_end_points[endpoint_name].get_shape().as_list(), explicit_padding_end_points[endpoint_name].get_shape().as_list(),
...@@ -329,13 +329,13 @@ class MobilenetV1Test(tf.test.TestCase): ...@@ -329,13 +329,13 @@ class MobilenetV1Test(tf.test.TestCase):
'Conv2d_13_depthwise': [batch_size, 4, 4, 768], 'Conv2d_13_depthwise': [batch_size, 4, 4, 768],
'Conv2d_13_pointwise': [batch_size, 4, 4, 768]} 'Conv2d_13_pointwise': [batch_size, 4, 4, 768]}
self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys()) self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
for endpoint_name, expected_shape in endpoints_shapes.iteritems(): for endpoint_name, expected_shape in endpoints_shapes.items():
self.assertTrue(endpoint_name in end_points) self.assertTrue(endpoint_name in end_points)
self.assertListEqual(end_points[endpoint_name].get_shape().as_list(), self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
expected_shape) expected_shape)
self.assertItemsEqual(endpoints_shapes.keys(), self.assertItemsEqual(endpoints_shapes.keys(),
explicit_padding_end_points.keys()) explicit_padding_end_points.keys())
for endpoint_name, expected_shape in endpoints_shapes.iteritems(): for endpoint_name, expected_shape in endpoints_shapes.items():
self.assertTrue(endpoint_name in explicit_padding_end_points) self.assertTrue(endpoint_name in explicit_padding_end_points)
self.assertListEqual( self.assertListEqual(
explicit_padding_end_points[endpoint_name].get_shape().as_list(), explicit_padding_end_points[endpoint_name].get_shape().as_list(),
......
...@@ -30,7 +30,7 @@ class NetworksTest(tf.test.TestCase): ...@@ -30,7 +30,7 @@ class NetworksTest(tf.test.TestCase):
def testGetNetworkFnFirstHalf(self): def testGetNetworkFnFirstHalf(self):
batch_size = 5 batch_size = 5
num_classes = 1000 num_classes = 1000
for net in nets_factory.networks_map.keys()[:10]: for net in list(nets_factory.networks_map.keys())[:10]:
with tf.Graph().as_default() as g, self.test_session(g): with tf.Graph().as_default() as g, self.test_session(g):
net_fn = nets_factory.get_network_fn(net, num_classes) net_fn = nets_factory.get_network_fn(net, num_classes)
# Most networks use 224 as their default_image_size # Most networks use 224 as their default_image_size
...@@ -45,7 +45,7 @@ class NetworksTest(tf.test.TestCase): ...@@ -45,7 +45,7 @@ class NetworksTest(tf.test.TestCase):
def testGetNetworkFnSecondHalf(self): def testGetNetworkFnSecondHalf(self):
batch_size = 5 batch_size = 5
num_classes = 1000 num_classes = 1000
for net in nets_factory.networks_map.keys()[10:]: for net in list(nets_factory.networks_map.keys())[10:]:
with tf.Graph().as_default() as g, self.test_session(g): with tf.Graph().as_default() as g, self.test_session(g):
net_fn = nets_factory.get_network_fn(net, num_classes) net_fn = nets_factory.get_network_fn(net, num_classes)
# Most networks use 224 as their default_image_size # Most networks use 224 as their default_image_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment