Commit 0ba83cf0 authored by pkulzc's avatar pkulzc Committed by Sergio Guadarrama
Browse files

Release MobileNet V3 models and SSDLite models with MobileNet V3 backbone. (#7678)

* Merged commit includes the following changes:
275131829  by Sergio Guadarrama:

    updates mobilenet/README.md to be github compatible adds V2+ reference to mobilenet_v1.md file and fixes invalid markdown

--
274908068  by Sergio Guadarrama:

    Opensource MobilenetV3 detection models.

--
274697808  by Sergio Guadarrama:

    Fixed cases where tf.TensorShape was constructed with float dimensions

    This is a prerequisite for making TensorShape and Dimension more strict
    about the types of their arguments.

--
273577462  by Sergio Guadarrama:

    Fixing `conv_defs['defaults']` override issue.

--
272801298  by Sergio Guadarrama:

    Adds links to trained models for Moblienet V3, adds a version of minimalistic mobilenet-v3 to the definitions.

--
268928503  by Sergio Guadarrama:

    Mobilenet v2 with group normalization.

--
263492735  by Sergio Guadarrama:

    Internal change

260037126  by Sergio Guadarrama:

    Adds an option of using a custom depthwise operation in `expanded_conv`.

--
259997001  by Sergio Guadarrama:

    Explicitly mark Python binaries/tests with python_version = "PY2".

--
252697685  by Sergio Guadarrama:

    Internal change

251918746  by Sergio Guadarrama:

    Internal change

251909704  by Sergio Guadarrama:

    Mobilenet V3 backbone implementation.

--
247510236  by Sergio Guadarrama:

    Internal change

246196802  by Sergio Guadarrama:

    Internal change

246014539  by Sergio Guadarrama:

    Internal change

245891435  by Sergio Guadarrama:

    Internal change

245834925  by Sergio Guadarrama:

    n/a

--

PiperOrigin-RevId: 275131829

* Merged commit includes the following changes:
274959989  by Zhichao Lu:

    Update detection model zoo with MobilenetV3 SSD candidates.

--
274908068  by Zhichao Lu:

    Opensource MobilenetV3 detection models.

--
274695889  by richardmunoz:

    RandomPatchGaussian preprocessing step

    This step can be used during model training to randomly apply gaussian noise to a random image patch. Example addition to an Object Detection API pipeline config:

    train_config {
      ...
      data_augmentation_options {
        random_patch_gaussian {
          random_coef: 0.5
          min_patch_size: 1
          max_patch_size: 250
          min_gaussian_stddev: 0.0
          max_gaussian_stddev: 1.0
        }
      }
      ...
    }

--
274257872  by lzc:

    Internal change.

--
274114689  by Zhichao Lu:

    Pass native_resize flag to other FPN variants.

--
274112308  by lzc:

    Internal change.

--
274090763  by richardmunoz:

    Util function for getting a patch mask on an image for use with the Object Detection API

--
274069806  by Zhichao Lu:

    Adding functions which will help compute predictions and losses for CenterNet.

--
273860828  by lzc:

    Internal change.

--
273380069  by richardmunoz:

    RandomImageDownscaleToTargetPixels preprocessing step

    This step can be used during model training to randomly downscale an image to a random target number of pixels. If the image does not contain more than the target number of pixels, then downscaling is skipped. Example addition to an Object Detection API pipeline config:

    train_config {
      ...
      data_augmentation_options {
        random_downscale_to_target_pixels {
          random_coef: 0.5
          min_target_pixels: 300000
          max_target_pixels: 500000
        }
      }
      ...
    }

--
272987602  by Zhichao Lu:

    Avoid -inf when empty box list is passed.

--
272525836  by Zhichao Lu:

    Cleanup repeated resizing code in meta archs.

--
272458667  by richardmunoz:

    RandomJpegQuality preprocessing step

    This step can be used during model training to randomly encode the image into a jpeg with a random quality level. Example addition to an Object Detection API pipeline config:

    train_config {
      ...
      data_augmentation_options {
        random_jpeg_quality {
          random_coef: 0.5
          min_jpeg_quality: 80
          max_jpeg_quality: 100
        }
      }
      ...
    }

--
271412717  by Zhichao Lu:

    Enables TPU training with the V2 eager + tf.function Object Detection training loops.

--
270744153  by Zhichao Lu:

    Adding the offset and size target assigners for CenterNet.

--
269916081  by Zhichao Lu:

    Include basic installation in Object Detection API tutorial.
    Also:
     - Use TF2.0
     - Use saved_model

--
269376056  by Zhichao Lu:

    Fix to variable loading in RetinaNet w/ custom loops. (makes the code rely on the exact name scopes that are generated a little bit less)

--
269256251  by lzc:

    Add use_partitioned_nms field to config and update post_prossing_builder to honor that flag when building nms function.

--
268865295  by Zhichao Lu:

    Adding functionality for importing and merging back internal state of the metric.

--
268640984  by Zhichao Lu:

    Fix computation of gaussian sigma value to create CenterNet heatmap target.

--
267475576  by Zhichao Lu:

    Fix for exporter trying to export non-existent exponential moving averages.

--
267286768  by Zhichao Lu:

    Update mixed-precision policy.

--
266166879  by Zhichao Lu:

    Internal change

265860884  by Zhichao Lu:

    Apply floor function to center coordinates when creating heatmap for CenterNet target.

--
265702749  by Zhichao Lu:

    Internal change

--
264241949  by ronnyvotel:

    Updating Faster R-CNN 'final_anchors' to be in normalized coordinates.

--
264175192  by lzc:

    Update model_fn to only read hparams if it is not None.

--
264159328  by Zhichao Lu:

    Modify nearest neighbor upsampling to eliminate a multiply operation. For quantized models, the multiply operation gets unnecessarily quantized and reduces accuracy (simple stacking would work in place of the broadcast op which doesn't require quantization). Also removes an unnecessary reshape op.

--
263668306  by Zhichao Lu:

    Add the option to use dynamic map_fn for batch NMS

--
263031163  by Zhichao Lu:

    Mark outside compilation for NMS as optional.

--
263024916  by Zhichao Lu:

    Add an ExperimentalModel meta arch for experimenting with new model types.

--
262655894  by Zhichao Lu:

    Add the center heatmap target assigner for CenterNet

--
262431036  by Zhichao Lu:

    Adding add_eval_dict to allow for evaluation on model_v2

--
262035351  by ronnyvotel:

    Removing any non-Tensor predictions from the third stage of Mask R-CNN.

--
261953416  by Zhichao Lu:

    Internal change.

--
261834966  by Zhichao Lu:

    Fix the NMS OOM issue on TPU by forcing NMS to run outside of TPU.

--
261775941  by Zhichao Lu:

    Make Keras InputLayer compatible with both TF 1.x and TF 2.0.

--
261775633  by Zhichao Lu:

    Visualize additional channels with ground-truth bounding boxes.

--
261768117  by lzc:

    Internal change.

--
261766773  by ronnyvotel:

    Exposing `return_raw_detections_during_predict` in Faster R-CNN Proto.

--
260975089  by ronnyvotel:

    Moving calculation of batched prediction tensor names after all tensors in prediction dictionary are created.

--
259816913  by ronnyvotel:

    Adding raw detection boxes and feature map indices to SSD

--
259791955  by Zhichao Lu:

    Added a flag to control the use partitioned_non_max_suppression.

--
259580475  by Zhichao Lu:

    Tweak quantization-aware training re-writer to support NasFpn model architecture.

--
259579943  by rathodv:

    Add a meta target assigner proto and builders in OD API.

--
259577741  by Zhichao Lu:

    Internal change.

--
259366315  by lzc:

    Internal change.

--
259344310  by ronnyvotel:

    Updating faster rcnn so that raw_detection_boxes from predict() are in normalized coordinates.

--
259338670  by Zhichao Lu:

    Add support for use_native_resize_op to more feature extractors. Use dynamic shapes when static shapes are not available.

--
259083543  by ronnyvotel:

    Updating/fixing documentation.

--
259078937  by rathodv:

    Add prediction fields for tensors returned from detection_model.predict.

--
259044601  by Zhichao Lu:

    Add protocol buffer and builders for temperature scaling calibration.

--
259036770  by lzc:

    Internal changes.

--
259006223  by ronnyvotel:

    Adding detection anchor indices to Faster R-CNN Config. This is useful when one wishes to associate final detections and the anchors (or pre-nms boxes) from which they originated.

--
258872501  by Zhichao Lu:

    Run the training pipeline of ssd + resnet_v1_50 + fpn with a checkpoint.

--
258840686  by ronnyvotel:

    Adding standard outputs to DetectionModel.predict(). This CL only updates Faster R-CNN. Other meta architectures will be updated in future CLs.

--
258672969  by lzc:

    Internal change.

--
258649494  by lzc:

    Internal changes.

--
258630321  by ronnyvotel:

    Fixing documentation in shape_utils.flatten_dimensions().

--
258468145  by Zhichao Lu:

    Add additional output tensors parameter to Postprocess op.

--
258099219  by Zhichao Lu:

    Internal changes

--

PiperOrigin-RevId: 274959989
parent 9aed0ffb
......@@ -545,6 +545,7 @@ def inception_v2(inputs,
return net, end_points
# 1 x 1 x 1024
net = slim.dropout(net, keep_prob=dropout_keep_prob, scope='Dropout_1b')
end_points['PreLogits'] = net
logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
normalizer_fn=None, scope='Conv2d_1c_1x1')
if spatial_squeeze:
......
# MobileNetV2
This folder contains building code for MobileNetV2, based on
[MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381)
# Performance
## Latency
This is the timing of [MobileNetV1](../mobilenet_v1.md) vs MobileNetV2 using
TF-Lite on the large core of Pixel 1 phone.
![mnet_v1_vs_v2_pixel1_latency.png](mnet_v1_vs_v2_pixel1_latency.png)
## MACs
MACs, also sometimes known as MADDs - the number of multiply-accumulates needed
to compute an inference on a single image is a common metric to measure the efficiency of the model.
Below is the graph comparing V2 vs a few selected networks. The size
of each blob represents the number of parameters. Note for [ShuffleNet](https://arxiv.org/abs/1707.01083) there
are no published size numbers. We estimate it to be comparable to MobileNetV2 numbers.
![madds_top1_accuracy](madds_top1_accuracy.png)
# Pretrained models
## Imagenet Checkpoints
Classification Checkpoint | MACs (M)| Parameters (M)| Top 1 Accuracy| Top 5 Accuracy | Mobile CPU (ms) Pixel 1
---------------------------|---------|---------------|---------|----|-------------
| [mobilenet_v2_1.4_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz) | 582 | 6.06 | 75.0 | 92.5 | 138.0
| [mobilenet_v2_1.3_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.3_224.tgz) | 509 | 5.34 | 74.4 | 92.1 | 123.0
| [mobilenet_v2_1.0_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz) | 300 | 3.47 | 71.8 | 91.0 | 73.8
| [mobilenet_v2_1.0_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_192.tgz) | 221 | 3.47 | 70.7 | 90.1 | 55.1
| [mobilenet_v2_1.0_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_160.tgz) | 154 | 3.47 | 68.8 | 89.0 | 40.2
| [mobilenet_v2_1.0_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_128.tgz) | 99 | 3.47 | 65.3 | 86.9 | 27.6
| [mobilenet_v2_1.0_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_96.tgz) | 56 | 3.47 | 60.3 | 83.2 | 17.6
| [mobilenet_v2_0.75_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_224.tgz) | 209 | 2.61 | 69.8 | 89.6 | 55.8
| [mobilenet_v2_0.75_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_192.tgz) | 153 | 2.61 | 68.7 | 88.9 | 41.6
| [mobilenet_v2_0.75_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_160.tgz) | 107 | 2.61 | 66.4 | 87.3 | 30.4
| [mobilenet_v2_0.75_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_128.tgz) | 69 | 2.61 | 63.2 | 85.3 | 21.9
| [mobilenet_v2_0.75_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_96.tgz) | 39 | 2.61 | 58.8 | 81.6 | 14.2
| [mobilenet_v2_0.5_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_224.tgz) | 97 | 1.95 | 65.4 | 86.4 | 28.7
| [mobilenet_v2_0.5_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_192.tgz) | 71 | 1.95 | 63.9 | 85.4 | 21.1
| [mobilenet_v2_0.5_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_160.tgz) | 50 | 1.95 | 61.0 | 83.2 | 14.9
| [mobilenet_v2_0.5_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_128.tgz) | 32 | 1.95 | 57.7 | 80.8 | 9.9
| [mobilenet_v2_0.5_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_96.tgz) | 18 | 1.95 | 51.2 | 75.8 | 6.4
| [mobilenet_v2_0.35_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_224.tgz) | 59 | 1.66 | 60.3 | 82.9 | 19.7
| [mobilenet_v2_0.35_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_192.tgz) | 43 | 1.66 | 58.2 | 81.2 | 14.6
| [mobilenet_v2_0.35_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_160.tgz) | 30 | 1.66 | 55.7 | 79.1 | 10.5
| [mobilenet_v2_0.35_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_128.tgz) | 20 | 1.66 | 50.8 | 75.0 | 6.9
| [mobilenet_v2_0.35_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_96.tgz) | 11 | 1.66 | 45.5 | 70.4 | 4.5
# Training
The numbers above can be reproduced using slim's `train_image_classifier`.
Below is the set of parameters that achieves 72.0% for full size MobileNetV2, after about 700K when trained on 8 GPU.
If trained on a single GPU the full convergence is after 5.5M steps. Also note that learning rate and
num_epochs_per_decay both need to be adjusted depending on how many GPUs are being
used due to slim's internal averaging.
# MobilenNet
This folder contains building code for
[MobileNetV2](https://arxiv.org/abs/1801.04381) and
[MobilenetV3](https://arxiv.org/abs/1905.02244) networks. The architectural
definition for each model is located in [mobilenet_v2.py](mobilenet_v2.py) and
[mobilenet_v3.py](mobilenet_v3.py) respectively.
For MobilenetV1 please refer to this [page](../mobilenet_v1.md)
## Performance
### Mobilenet V3 latency
This is the timing of [MobileNetV2] vs [MobileNetV3] using TF-Lite on the large
core of Pixel 1 phone.
![Mobilenet V2 and V3 Latency for Pixel 1.png](g3doc/latency_pixel1.png)
### MACs
MACs, also sometimes known as MADDs - the number of multiply-accumulates needed
to compute an inference on a single image is a common metric to measure the
efficiency of the model. Full size Mobilenet V3 on image size 224 uses ~215
Million MADDs (MMadds) while achieving accuracy 75.1%, while Mobilenet V2 uses
~300MMadds and achieving accuracy 72%. By comparison ResNet-50 uses
approximately 3500 MMAdds while achieving 76% accuracy.
Below is the graph comparing Mobilenets and a few selected networks. The size of
each blob represents the number of parameters. Note for
[ShuffleNet](https://arxiv.org/abs/1707.01083) there are no published size
numbers. We estimate it to be comparable to MobileNetV2 numbers.
![madds_top1_accuracy](g3doc/madds_top1_accuracy.png)
## Pretrained models
### Mobilenet V3 Imagenet Checkpoints
All mobilenet V3 checkpoints were trained with image resolution 224x224. All
phone latencies are in milliseconds, measured on large core. In addition to
large and small models this page also contains so-called minimalistic models,
these models have the same per-layer dimensions characteristic as MobilenetV3
however, they don't utilize any of the advanced blocks (squeeze-and-excite
units, hard-swish, and 5x5 convolutions). While these models are less efficient
on CPU, we find that they are much more performant on GPU/DSP/EdgeTpu.
| Imagenet Checkpoint | MACs (M) | Params (M) | Top1 | Pixel 1 | Pixel 2 | Pixel 3 |
| ------------------ | -------- | ---------- | ---- | ------- | ------- | ------- |
| [Large dm=1 (float)] | 217 | 5.4 | 75.2 | 51.2 | 61 | 44 |
| [Large dm=1 (8-bit)] | 217 | 5.4 | 73.9 | 44 | 42.5 | 32 |
| [Large dm=0.75 (float)] | 155 | 4.0 | 73.3 | 39.8 | 48 | 34 |
| [Small dm=1 (float)] | 66 | 2.9 | 67.5 | 15.8 | 19.4 | 14.4 |
| [Small dm=1 (8-bit)] | 66 | 2.9 | 64.9 | 15.5 | 15 | 10.7 |
| [Small dm=0.75 (float)] | 44 | 2.4 | 65.4 | 12.8 | 15.9 | 11.6 |
#### Minimalistic checkpoints:
| Imagenet Checkpoint | MACs (M) | Params (M) | Top1 | Pixel 1 | Pixel 2 | Pixel 3 |
| -------------- | -------- | ---------- | ---- | ------- | ------- | ------- |
| [Large minimalistic (float)] | 209 | 3.9 | 72.3 | 44.1 | 51 | 35 |
| [Large minimalistic (8-bit)][lm8] | 209 | 3.9 | 71.3 | 37 | 35 | 27 |
| [Small minimalistic (float)] | 65 | 2.0 | 61.9 | 12.2 | 15.1 | 11 |
[Small minimalistic (float)]: https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-small-minimalistic_224_1.0_float.tgz
[Large minimalistic (float)]: https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large-minimalistic_224_1.0_float.tgz
[lm8]: https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large-minimalistic_224_1.0_uint8.tgz
[Large dm=1 (float)]: https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_float.tgz
[Small dm=1 (float)]: https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-small_224_1.0_float.tgz
[Large dm=1 (8-bit)]: https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_uint8.tgz
[Small dm=1 (8-bit)]: https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-small_224_1.0_uint8.tgz
[Large dm=0.75 (float)]: https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_0.75_float.tgz
[Small dm=0.75 (float)]: https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-small_224_0.75_float.tgz
### Mobilenet V2 Imagenet Checkpoints
Classification Checkpoint | MACs (M) | Parameters (M) | Top 1 Accuracy | Top 5 Accuracy | Mobile CPU (ms) Pixel 1
---------------------------------------------------------------------------------------------------------- | -------- | -------------- | -------------- | -------------- | -----------------------
[mobilenet_v2_1.4_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz) | 582 | 6.06 | 75.0 | 92.5 | 138.0
[mobilenet_v2_1.3_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.3_224.tgz) | 509 | 5.34 | 74.4 | 92.1 | 123.0
[mobilenet_v2_1.0_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz) | 300 | 3.47 | 71.8 | 91.0 | 73.8
[mobilenet_v2_1.0_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_192.tgz) | 221 | 3.47 | 70.7 | 90.1 | 55.1
[mobilenet_v2_1.0_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_160.tgz) | 154 | 3.47 | 68.8 | 89.0 | 40.2
[mobilenet_v2_1.0_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_128.tgz) | 99 | 3.47 | 65.3 | 86.9 | 27.6
[mobilenet_v2_1.0_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_96.tgz) | 56 | 3.47 | 60.3 | 83.2 | 17.6
[mobilenet_v2_0.75_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_224.tgz) | 209 | 2.61 | 69.8 | 89.6 | 55.8
[mobilenet_v2_0.75_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_192.tgz) | 153 | 2.61 | 68.7 | 88.9 | 41.6
[mobilenet_v2_0.75_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_160.tgz) | 107 | 2.61 | 66.4 | 87.3 | 30.4
[mobilenet_v2_0.75_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_128.tgz) | 69 | 2.61 | 63.2 | 85.3 | 21.9
[mobilenet_v2_0.75_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_96.tgz) | 39 | 2.61 | 58.8 | 81.6 | 14.2
[mobilenet_v2_0.5_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_224.tgz) | 97 | 1.95 | 65.4 | 86.4 | 28.7
[mobilenet_v2_0.5_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_192.tgz) | 71 | 1.95 | 63.9 | 85.4 | 21.1
[mobilenet_v2_0.5_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_160.tgz) | 50 | 1.95 | 61.0 | 83.2 | 14.9
[mobilenet_v2_0.5_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_128.tgz) | 32 | 1.95 | 57.7 | 80.8 | 9.9
[mobilenet_v2_0.5_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_96.tgz) | 18 | 1.95 | 51.2 | 75.8 | 6.4
[mobilenet_v2_0.35_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_224.tgz) | 59 | 1.66 | 60.3 | 82.9 | 19.7
[mobilenet_v2_0.35_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_192.tgz) | 43 | 1.66 | 58.2 | 81.2 | 14.6
[mobilenet_v2_0.35_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_160.tgz) | 30 | 1.66 | 55.7 | 79.1 | 10.5
[mobilenet_v2_0.35_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_128.tgz) | 20 | 1.66 | 50.8 | 75.0 | 6.9
[mobilenet_v2_0.35_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_96.tgz) | 11 | 1.66 | 45.5 | 70.4 | 4.5
## Training
### V3
TODO: Add V3 hyperparameters
### V2
The numbers above can be reproduced using slim's
[`train_image_classifier`](https://github.com/tensorflow/models/blob/master/research/slim/README.md#training-a-model-from-scratch).
Below is the set of parameters that achieves 72.0% for full size MobileNetV2,
after about 700K when trained on 8 GPU. If trained on a single GPU the full
convergence is after 5.5M steps. Also note that learning rate and
num_epochs_per_decay both need to be adjusted depending on how many GPUs are
being used due to slim's internal averaging.
```bash
--model_name="mobilenet_v2"
......@@ -68,6 +130,9 @@ used due to slim's internal averaging.
# Example
See this [ipython notebook](mobilenet_example.ipynb) or open and run the network
directly in
[Colaboratory](https://colab.research.google.com/github/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_example.ipynb).
See this [ipython notebook](mobilenet_example.ipynb) or open and run the network directly in [Colaboratory](https://colab.research.google.com/github/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_example.ipynb).
[MobilenetV2]: https://arxiv.org/abs/1801.04381
[MobilenetV3]: https://arxiv.org/abs/1905.02244
......@@ -159,6 +159,50 @@ def expand_input_by_factor(n, divisible_by=8):
return lambda num_inputs, **_: _make_divisible(num_inputs * n, divisible_by)
def split_conv(input_tensor,
num_outputs,
num_ways,
scope,
divisible_by=8,
**kwargs):
"""Creates a split convolution.
Split convolution splits the input and output into
'num_blocks' blocks of approximately the same size each,
and only connects $i$-th input to $i$ output.
Args:
input_tensor: input tensor
num_outputs: number of output filters
num_ways: num blocks to split by.
scope: scope for all the operators.
divisible_by: make sure that every part is divisiable by this.
**kwargs: will be passed directly into conv2d operator
Returns:
tensor
"""
b = input_tensor.get_shape().as_list()[3]
if num_ways == 1 or min(b // num_ways,
num_outputs // num_ways) < divisible_by:
# Don't do any splitting if we end up with less than 8 filters
# on either side.
return slim.conv2d(input_tensor, num_outputs, [1, 1], scope=scope, **kwargs)
outs = []
input_splits = _split_divisible(b, num_ways, divisible_by=divisible_by)
output_splits = _split_divisible(
num_outputs, num_ways, divisible_by=divisible_by)
inputs = tf.split(input_tensor, input_splits, axis=3, name='split_' + scope)
base = scope
for i, (input_tensor, out_size) in enumerate(zip(inputs, output_splits)):
scope = base + '_part_%d' % (i,)
n = slim.conv2d(input_tensor, out_size, [1, 1], scope=scope, **kwargs)
n = tf.identity(n, scope + '_output')
outs.append(n)
return tf.concat(outs, 3, name=scope + '_concat')
@slim.add_arg_scope
def expanded_conv(input_tensor,
num_outputs,
......@@ -168,7 +212,6 @@ def expanded_conv(input_tensor,
kernel_size=(3, 3),
residual=True,
normalizer_fn=None,
project_activation_fn=tf.identity,
split_projection=1,
split_expansion=1,
split_divisible_by=8,
......@@ -178,6 +221,12 @@ def expanded_conv(input_tensor,
endpoints=None,
use_explicit_padding=False,
padding='SAME',
inner_activation_fn=None,
depthwise_activation_fn=None,
project_activation_fn=tf.identity,
depthwise_fn=slim.separable_conv2d,
expansion_fn=split_conv,
projection_fn=split_conv,
scope=None):
"""Depthwise Convolution Block with expansion.
......@@ -197,7 +246,6 @@ def expanded_conv(input_tensor,
residual: whether to include residual connection between input
and output.
normalizer_fn: batchnorm or otherwise
project_activation_fn: activation function for the project layer
split_projection: how many ways to split projection operator
(that is conv expansion->bottleneck)
split_expansion: how many ways to split expansion op
......@@ -220,6 +268,20 @@ def expanded_conv(input_tensor,
inputs so that the output dimensions are the same as if 'SAME' padding
were used.
padding: Padding type to use if `use_explicit_padding` is not set.
inner_activation_fn: activation function to use in all inner convolutions.
If none, will rely on slim default scopes.
depthwise_activation_fn: activation function to use for deptwhise only.
If not provided will rely on slim default scopes. If both
inner_activation_fn and depthwise_activation_fn are provided,
depthwise_activation_fn takes precedence over inner_activation_fn.
project_activation_fn: activation function for the project layer.
(note this layer is not affected by inner_activation_fn)
depthwise_fn: Depthwise convolution function.
expansion_fn: Expansion convolution function. If use custom function then
"split_expansion" and "split_divisible_by" will be ignored.
projection_fn: Projection convolution function. If use custom function then
"split_projection" and "split_divisible_by" will be ignored.
scope: optional scope.
Returns:
......@@ -228,8 +290,18 @@ def expanded_conv(input_tensor,
Raises:
TypeError: on inval
"""
conv_defaults = {}
dw_defaults = {}
if inner_activation_fn is not None:
conv_defaults['activation_fn'] = inner_activation_fn
dw_defaults['activation_fn'] = inner_activation_fn
if depthwise_activation_fn is not None:
dw_defaults['activation_fn'] = depthwise_activation_fn
# pylint: disable=g-backslash-continuation
with tf.variable_scope(scope, default_name='expanded_conv') as s, \
tf.name_scope(s.original_name_scope):
tf.name_scope(s.original_name_scope), \
slim.arg_scope((slim.conv2d,), **conv_defaults), \
slim.arg_scope((slim.separable_conv2d,), **dw_defaults):
prev_depth = input_tensor.get_shape().as_list()[3]
if depthwise_location not in [None, 'input', 'output', 'expansion']:
raise TypeError('%r is unknown value for depthwise_location' %
......@@ -240,7 +312,7 @@ def expanded_conv(input_tensor,
'"SAME" padding.')
padding = 'VALID'
depthwise_func = functools.partial(
slim.separable_conv2d,
depthwise_fn,
num_outputs=None,
kernel_size=kernel_size,
depth_multiplier=depthwise_channel_multiplier,
......@@ -258,6 +330,9 @@ def expanded_conv(input_tensor,
if use_explicit_padding:
net = _fixed_padding(net, kernel_size, rate)
net = depthwise_func(net, activation_fn=None)
net = tf.identity(net, name='depthwise_output')
if endpoints is not None:
endpoints['depthwise_output'] = net
if callable(expansion_size):
inner_size = expansion_size(num_inputs=prev_depth)
......@@ -265,37 +340,43 @@ def expanded_conv(input_tensor,
inner_size = expansion_size
if inner_size > net.shape[3]:
net = split_conv(
if expansion_fn == split_conv:
expansion_fn = functools.partial(
expansion_fn,
num_ways=split_expansion,
divisible_by=split_divisible_by,
stride=1)
net = expansion_fn(
net,
inner_size,
num_ways=split_expansion,
scope='expand',
divisible_by=split_divisible_by,
stride=1,
normalizer_fn=normalizer_fn)
net = tf.identity(net, 'expansion_output')
if endpoints is not None:
endpoints['expansion_output'] = net
if endpoints is not None:
endpoints['expansion_output'] = net
if depthwise_location == 'expansion':
if use_explicit_padding:
net = _fixed_padding(net, kernel_size, rate)
net = depthwise_func(net)
net = tf.identity(net, name='depthwise_output')
if endpoints is not None:
endpoints['depthwise_output'] = net
net = tf.identity(net, name='depthwise_output')
if endpoints is not None:
endpoints['depthwise_output'] = net
if expansion_transform:
net = expansion_transform(expansion_tensor=net, input_tensor=input_tensor)
# Note in contrast with expansion, we always have
# projection to produce the desired output size.
net = split_conv(
if projection_fn == split_conv:
projection_fn = functools.partial(
projection_fn,
num_ways=split_projection,
divisible_by=split_divisible_by,
stride=1)
net = projection_fn(
net,
num_outputs,
num_ways=split_projection,
stride=1,
scope='project',
divisible_by=split_divisible_by,
normalizer_fn=normalizer_fn,
activation_fn=project_activation_fn)
if endpoints is not None:
......@@ -304,6 +385,9 @@ def expanded_conv(input_tensor,
if use_explicit_padding:
net = _fixed_padding(net, kernel_size, rate)
net = depthwise_func(net, activation_fn=None)
net = tf.identity(net, name='depthwise_output')
if endpoints is not None:
endpoints['depthwise_output'] = net
if callable(residual): # custom residual
net = residual(input_tensor=input_tensor, output_tensor=net)
......@@ -318,45 +402,65 @@ def expanded_conv(input_tensor,
return tf.identity(net, name='output')
def split_conv(input_tensor,
num_outputs,
num_ways,
scope,
divisible_by=8,
**kwargs):
"""Creates a split convolution.
Split convolution splits the input and output into
'num_blocks' blocks of approximately the same size each,
and only connects $i$-th input to $i$ output.
@slim.add_arg_scope
def squeeze_excite(input_tensor,
divisible_by=8,
squeeze_factor=3,
inner_activation_fn=tf.nn.relu,
gating_fn=tf.sigmoid,
squeeze_input_tensor=None,
pool=None):
"""Squeeze excite block for Mobilenet V3.
Args:
input_tensor: input tensor
num_outputs: number of output filters
num_ways: num blocks to split by.
scope: scope for all the operators.
divisible_by: make sure that every part is divisiable by this.
**kwargs: will be passed directly into conv2d operator
input_tensor: input tensor to apply SE block to.
divisible_by: ensures all inner dimensions are divisible by this number.
squeeze_factor: the factor of squeezing in the inner fully connected layer
inner_activation_fn: non-linearity to be used in inner layer.
gating_fn: non-linearity to be used for final gating function
squeeze_input_tensor: custom tensor to use for computing gating activation.
If provided the result will be input_tensor * SE(squeeze_input_tensor)
instead of input_tensor * SE(input_tensor).
pool: if number is provided will average pool with that kernel size
to compute inner tensor, followed by bilinear upsampling.
Returns:
tensor
Gated input_tensor. (e.g. X * SE(X))
"""
b = input_tensor.get_shape().as_list()[3]
if num_ways == 1 or min(b // num_ways,
num_outputs // num_ways) < divisible_by:
# Don't do any splitting if we end up with less than 8 filters
# on either side.
return slim.conv2d(input_tensor, num_outputs, [1, 1], scope=scope, **kwargs)
with tf.variable_scope('squeeze_excite'):
if squeeze_input_tensor is None:
squeeze_input_tensor = input_tensor
input_size = input_tensor.shape.as_list()[1:3]
pool_height, pool_width = squeeze_input_tensor.shape.as_list()[1:3]
stride = 1
if pool is not None and pool_height >= pool:
pool_height, pool_width, stride = pool, pool, pool
input_channels = squeeze_input_tensor.shape.as_list()[3]
output_channels = input_tensor.shape.as_list()[3]
squeeze_channels = _make_divisible(
input_channels / squeeze_factor, divisor=divisible_by)
pooled = tf.nn.avg_pool(squeeze_input_tensor,
(1, pool_height, pool_width, 1),
strides=(1, stride, stride, 1),
padding='VALID')
squeeze = slim.conv2d(
pooled,
kernel_size=(1, 1),
num_outputs=squeeze_channels,
normalizer_fn=None,
activation_fn=inner_activation_fn)
excite_outputs = output_channels
excite = slim.conv2d(squeeze, num_outputs=excite_outputs,
kernel_size=[1, 1],
normalizer_fn=None,
activation_fn=gating_fn)
if pool is not None:
# Note: As of 03/20/2019 only BILINEAR (the default) with
# align_corners=True has gradients implemented in TPU.
excite = tf.image.resize_images(
excite, input_size,
align_corners=True)
result = input_tensor * excite
return result
outs = []
input_splits = _split_divisible(b, num_ways, divisible_by=divisible_by)
output_splits = _split_divisible(
num_outputs, num_ways, divisible_by=divisible_by)
inputs = tf.split(input_tensor, input_splits, axis=3, name='split_' + scope)
base = scope
for i, (input_tensor, out_size) in enumerate(zip(inputs, output_splits)):
scope = base + '_part_%d' % (i,)
n = slim.conv2d(input_tensor, out_size, [1, 1], scope=scope, **kwargs)
n = tf.identity(n, scope + '_output')
outs.append(n)
return tf.concat(outs, 3, name=scope + '_concat')
......@@ -66,7 +66,7 @@ def _make_divisible(v, divisor, min_value=None):
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
return int(new_v)
@contextlib.contextmanager
......
......@@ -81,6 +81,25 @@ V2_DEF = dict(
)
# pyformat: enable
# Mobilenet v2 Definition with group normalization.
V2_DEF_GROUP_NORM = copy.deepcopy(V2_DEF)
V2_DEF_GROUP_NORM['defaults'] = {
(tf.contrib.slim.conv2d, tf.contrib.slim.fully_connected,
tf.contrib.slim.separable_conv2d): {
'normalizer_fn': tf.contrib.layers.group_norm, # pylint: disable=C0330
'activation_fn': tf.nn.relu6, # pylint: disable=C0330
}, # pylint: disable=C0330
(ops.expanded_conv,): {
'expansion_size': ops.expand_input_by_factor(6),
'split_expansion': 1,
'normalizer_fn': tf.contrib.layers.group_norm,
'residual': True
},
(tf.contrib.slim.conv2d, tf.contrib.slim.separable_conv2d): {
'padding': 'SAME'
}
}
@slim.add_arg_scope
def mobilenet(input_tensor,
......@@ -189,6 +208,19 @@ def mobilenet_base(input_tensor, depth_multiplier=1.0, **kwargs):
base_only=True, **kwargs)
@slim.add_arg_scope
def mobilenet_base_group_norm(input_tensor, depth_multiplier=1.0, **kwargs):
"""Creates base of the mobilenet (no pooling and no logits) ."""
kwargs['conv_defs'] = V2_DEF_GROUP_NORM
kwargs['conv_defs']['defaults'].update({
(tf.contrib.layers.group_norm,): {
'groups': kwargs.pop('groups', 8)
}
})
return mobilenet(
input_tensor, depth_multiplier=depth_multiplier, base_only=True, **kwargs)
def training_scope(**kwargs):
"""Defines MobilenetV2 training scope.
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mobilenet V3 conv defs and helper functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import functools
import numpy as np
import tensorflow as tf
from nets.mobilenet import conv_blocks as ops
from nets.mobilenet import mobilenet as lib
slim = tf.contrib.slim
op = lib.op
expand_input = ops.expand_input_by_factor
# Squeeze Excite with all parameters filled-in, we use hard-sigmoid
# for gating function and relu for inner activation function.
squeeze_excite = functools.partial(
ops.squeeze_excite, squeeze_factor=4,
inner_activation_fn=tf.nn.relu,
gating_fn=lambda x: tf.nn.relu6(x+3)*0.16667)
# Wrap squeeze excite op as expansion_transform that takes
# both expansion and input tensor.
_se4 = lambda expansion_tensor, input_tensor: squeeze_excite(expansion_tensor)
def hard_swish(x):
with tf.name_scope('hard_swish'):
return x * tf.nn.relu6(x + np.float32(3)) * np.float32(1. / 6.)
def reduce_to_1x1(input_tensor, default_size=7, **kwargs):
h, w = input_tensor.shape.as_list()[1:3]
if h is not None and w == h:
k = [h, h]
else:
k = [default_size, default_size]
return slim.avg_pool2d(input_tensor, kernel_size=k, **kwargs)
def mbv3_op(ef, n, k, s=1, act=tf.nn.relu, se=None):
"""Defines a single Mobilenet V3 convolution block.
Args:
ef: expansion factor
n: number of output channels
k: stride of depthwise
s: stride
act: activation function in inner layers
se: squeeze excite function.
Returns:
An object (lib._Op) for inserting in conv_def, representing this operation.
"""
return op(ops.expanded_conv, expansion_size=expand_input(ef),
kernel_size=(k, k), stride=s, num_outputs=n,
inner_activation_fn=act,
expansion_transform=se)
mbv3_op_se = functools.partial(mbv3_op, se=_se4)
DEFAULTS = {
(ops.expanded_conv,):
dict(
normalizer_fn=slim.batch_norm,
residual=True),
(slim.conv2d, slim.fully_connected, slim.separable_conv2d): {
'normalizer_fn': slim.batch_norm,
'activation_fn': tf.nn.relu,
},
(slim.batch_norm,): {
'center': True,
'scale': True
},
}
# Compatible checkpoint: http://mldash/5511169891790690458#scalars
V3_LARGE = dict(
defaults=dict(DEFAULTS),
spec=([
# stage 1
op(slim.conv2d, stride=2, num_outputs=16, kernel_size=(3, 3),
activation_fn=hard_swish),
mbv3_op(ef=1, n=16, k=3),
mbv3_op(ef=4, n=24, k=3, s=2),
mbv3_op(ef=3, n=24, k=3, s=1),
mbv3_op_se(ef=3, n=40, k=5, s=2),
mbv3_op_se(ef=3, n=40, k=5, s=1),
mbv3_op_se(ef=3, n=40, k=5, s=1),
mbv3_op(ef=6, n=80, k=3, s=2, act=hard_swish),
mbv3_op(ef=2.5, n=80, k=3, s=1, act=hard_swish),
mbv3_op(ef=184/80., n=80, k=3, s=1, act=hard_swish),
mbv3_op(ef=184/80., n=80, k=3, s=1, act=hard_swish),
mbv3_op_se(ef=6, n=112, k=3, s=1, act=hard_swish),
mbv3_op_se(ef=6, n=112, k=3, s=1, act=hard_swish),
mbv3_op_se(ef=6, n=160, k=5, s=2, act=hard_swish),
mbv3_op_se(ef=6, n=160, k=5, s=1, act=hard_swish),
mbv3_op_se(ef=6, n=160, k=5, s=1, act=hard_swish),
op(slim.conv2d, stride=1, kernel_size=[1, 1], num_outputs=960,
activation_fn=hard_swish),
op(reduce_to_1x1, default_size=7, stride=1, padding='VALID'),
op(slim.conv2d, stride=1, kernel_size=[1, 1], num_outputs=1280,
normalizer_fn=None, activation_fn=hard_swish)
]))
# 72.2% accuracy.
V3_LARGE_MINIMALISTIC = dict(
defaults=dict(DEFAULTS),
spec=([
# stage 1
op(slim.conv2d, stride=2, num_outputs=16, kernel_size=(3, 3)),
mbv3_op(ef=1, n=16, k=3),
mbv3_op(ef=4, n=24, k=3, s=2),
mbv3_op(ef=3, n=24, k=3, s=1),
mbv3_op(ef=3, n=40, k=3, s=2),
mbv3_op(ef=3, n=40, k=3, s=1),
mbv3_op(ef=3, n=40, k=3, s=1),
mbv3_op(ef=6, n=80, k=3, s=2),
mbv3_op(ef=2.5, n=80, k=3, s=1),
mbv3_op(ef=184 / 80., n=80, k=3, s=1),
mbv3_op(ef=184 / 80., n=80, k=3, s=1),
mbv3_op(ef=6, n=112, k=3, s=1),
mbv3_op(ef=6, n=112, k=3, s=1),
mbv3_op(ef=6, n=160, k=3, s=2),
mbv3_op(ef=6, n=160, k=3, s=1),
mbv3_op(ef=6, n=160, k=3, s=1),
op(slim.conv2d, stride=1, kernel_size=[1, 1], num_outputs=960),
op(reduce_to_1x1, default_size=7, stride=1, padding='VALID'),
op(slim.conv2d,
stride=1,
kernel_size=[1, 1],
num_outputs=1280,
normalizer_fn=None)
]))
# Compatible run: http://mldash/2023283040014348118#scalars
V3_SMALL = dict(
defaults=dict(DEFAULTS),
spec=([
# stage 1
op(slim.conv2d, stride=2, num_outputs=16, kernel_size=(3, 3),
activation_fn=hard_swish),
mbv3_op_se(ef=1, n=16, k=3, s=2),
mbv3_op(ef=72./16, n=24, k=3, s=2),
mbv3_op(ef=(88./24), n=24, k=3, s=1),
mbv3_op_se(ef=4, n=40, k=5, s=2, act=hard_swish),
mbv3_op_se(ef=6, n=40, k=5, s=1, act=hard_swish),
mbv3_op_se(ef=6, n=40, k=5, s=1, act=hard_swish),
mbv3_op_se(ef=3, n=48, k=5, s=1, act=hard_swish),
mbv3_op_se(ef=3, n=48, k=5, s=1, act=hard_swish),
mbv3_op_se(ef=6, n=96, k=5, s=2, act=hard_swish),
mbv3_op_se(ef=6, n=96, k=5, s=1, act=hard_swish),
mbv3_op_se(ef=6, n=96, k=5, s=1, act=hard_swish),
op(slim.conv2d, stride=1, kernel_size=[1, 1], num_outputs=576,
activation_fn=hard_swish),
op(reduce_to_1x1, default_size=7, stride=1, padding='VALID'),
op(slim.conv2d, stride=1, kernel_size=[1, 1], num_outputs=1024,
normalizer_fn=None, activation_fn=hard_swish)
]))
# 62% accuracy.
V3_SMALL_MINIMALISTIC = dict(
defaults=dict(DEFAULTS),
spec=([
# stage 1
op(slim.conv2d, stride=2, num_outputs=16, kernel_size=(3, 3)),
mbv3_op(ef=1, n=16, k=3, s=2),
mbv3_op(ef=72. / 16, n=24, k=3, s=2),
mbv3_op(ef=(88. / 24), n=24, k=3, s=1),
mbv3_op(ef=4, n=40, k=3, s=2),
mbv3_op(ef=6, n=40, k=3, s=1),
mbv3_op(ef=6, n=40, k=3, s=1),
mbv3_op(ef=3, n=48, k=3, s=1),
mbv3_op(ef=3, n=48, k=3, s=1),
mbv3_op(ef=6, n=96, k=3, s=2),
mbv3_op(ef=6, n=96, k=3, s=1),
mbv3_op(ef=6, n=96, k=3, s=1),
op(slim.conv2d, stride=1, kernel_size=[1, 1], num_outputs=576),
op(reduce_to_1x1, default_size=7, stride=1, padding='VALID'),
op(slim.conv2d,
stride=1,
kernel_size=[1, 1],
num_outputs=1024,
normalizer_fn=None)
]))
@slim.add_arg_scope
def mobilenet(input_tensor,
num_classes=1001,
depth_multiplier=1.0,
scope='MobilenetV3',
conv_defs=None,
finegrain_classification_mode=False,
**kwargs):
"""Creates mobilenet V3 network.
Inference mode is created by default. To create training use training_scope
below.
with tf.contrib.slim.arg_scope(mobilenet_v3.training_scope()):
logits, endpoints = mobilenet_v3.mobilenet(input_tensor)
Args:
input_tensor: The input tensor
num_classes: number of classes
depth_multiplier: The multiplier applied to scale number of
channels in each layer.
scope: Scope of the operator
conv_defs: Which version to create. Could be large/small or
any conv_def (see mobilenet_v3.py for examples).
finegrain_classification_mode: When set to True, the model
will keep the last layer large even for small multipliers. Following
https://arxiv.org/abs/1801.04381
it improves performance for ImageNet-type of problems.
*Note* ignored if final_endpoint makes the builder exit earlier.
**kwargs: passed directly to mobilenet.mobilenet:
prediction_fn- what prediction function to use.
reuse-: whether to reuse variables (if reuse set to true, scope
must be given).
Returns:
logits/endpoints pair
Raises:
ValueError: On invalid arguments
"""
if conv_defs is None:
conv_defs = V3_LARGE
if 'multiplier' in kwargs:
raise ValueError('mobilenetv2 doesn\'t support generic '
'multiplier parameter use "depth_multiplier" instead.')
if finegrain_classification_mode:
conv_defs = copy.deepcopy(conv_defs)
conv_defs['spec'][-1] = conv_defs['spec'][-1]._replace(
multiplier_func=lambda params, multiplier: params)
depth_args = {}
with slim.arg_scope((lib.depth_multiplier,), **depth_args):
return lib.mobilenet(
input_tensor,
num_classes=num_classes,
conv_defs=conv_defs,
scope=scope,
multiplier=depth_multiplier,
**kwargs)
mobilenet.default_image_size = 224
training_scope = lib.training_scope
@slim.add_arg_scope
def mobilenet_base(input_tensor, depth_multiplier=1.0, **kwargs):
"""Creates base of the mobilenet (no pooling and no logits) ."""
return mobilenet(
input_tensor, depth_multiplier=depth_multiplier, base_only=True, **kwargs)
def wrapped_partial(func, *args, **kwargs):
partial_func = functools.partial(func, *args, **kwargs)
functools.update_wrapper(partial_func, func)
return partial_func
large = wrapped_partial(mobilenet, conv_defs=V3_LARGE)
small = wrapped_partial(mobilenet, conv_defs=V3_SMALL)
# Minimalistic model that does not have Squeeze Excite blocks,
# Hardswish, or 5x5 depthwise convolution.
# This makes the model very friendly for a wide range of hardware
large_minimalistic = wrapped_partial(mobilenet, conv_defs=V3_LARGE_MINIMALISTIC)
small_minimalistic = wrapped_partial(mobilenet, conv_defs=V3_SMALL_MINIMALISTIC)
def _reduce_consecutive_layers(conv_defs, start_id, end_id, multiplier=0.5):
"""Reduce the outputs of consecutive layers with multiplier.
Args:
conv_defs: Mobilenet conv_defs.
start_id: 0-based index of the starting conv_def to be reduced.
end_id: 0-based index of the last conv_def to be reduced.
multiplier: The multiplier by which to reduce the conv_defs.
Returns:
Mobilenet conv_defs where the output sizes from layers [start_id, end_id],
inclusive, are reduced by multiplier.
Raises:
ValueError if any layer to be reduced does not have the 'num_outputs'
attribute.
"""
defs = copy.deepcopy(conv_defs)
for d in defs['spec'][start_id:end_id+1]:
d.params.update({
'num_outputs': np.int(np.round(d.params['num_outputs'] * multiplier))
})
return defs
V3_LARGE_DETECTION = _reduce_consecutive_layers(V3_LARGE, 13, 16)
V3_SMALL_DETECTION = _reduce_consecutive_layers(V3_SMALL, 9, 12)
__all__ = ['training_scope', 'mobilenet', 'V3_LARGE', 'V3_SMALL', 'large',
'small', 'V3_LARGE_DETECTION', 'V3_SMALL_DETECTION']
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for google3.third_party.tensorflow_models.slim.nets.mobilenet.mobilenet_v3."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import tensorflow as tf
from nets.mobilenet import mobilenet_v3
class MobilenetV3Test(absltest.TestCase):
def setUp(self):
super(MobilenetV3Test, self).setUp()
tf.reset_default_graph()
def testMobilenetV3Large(self):
logits, endpoints = mobilenet_v3.mobilenet(
tf.placeholder(tf.float32, (1, 224, 224, 3)))
self.assertEqual(endpoints['layer_19'].shape, [1, 1, 1, 1280])
self.assertEqual(logits.shape, [1, 1001])
def testMobilenetV3Small(self):
_, endpoints = mobilenet_v3.mobilenet(
tf.placeholder(tf.float32, (1, 224, 224, 3)),
conv_defs=mobilenet_v3.V3_SMALL)
self.assertEqual(endpoints['layer_15'].shape, [1, 1, 1, 1024])
def testMobilenetV3BaseOnly(self):
result, endpoints = mobilenet_v3.mobilenet(
tf.placeholder(tf.float32, (1, 224, 224, 3)),
conv_defs=mobilenet_v3.V3_LARGE,
base_only=True,
final_endpoint='layer_17')
# Get the latest layer before average pool.
self.assertEqual(endpoints['layer_17'].shape, [1, 7, 7, 960])
self.assertEqual(result, endpoints['layer_17'])
if __name__ == '__main__':
absltest.main()
# Mobilenet_v2
For Mobilenet V2 see this file [mobilenet/README.md](mobilenet/README.md).
# MobilenetV2 and above
For MobilenetV2+ see this file [mobilenet/README.md](mobilenet/README_md)
# MobileNet_v1
# MobileNetV1
[MobileNets](https://arxiv.org/abs/1704.04861) are small, low-latency, low-power models parameterized to meet the resource constraints of a variety of use cases. They can be built upon for classification, detection, embeddings and segmentation similar to how other popular large scale models, such as Inception, are used. MobileNets can be run efficiently on mobile devices with [TensorFlow Mobile](https://www.tensorflow.org/mobile/).
......
......@@ -68,6 +68,7 @@ def vgg_a(inputs,
is_training=True,
dropout_keep_prob=0.5,
spatial_squeeze=True,
reuse=None,
scope='vgg_a',
fc_conv_padding='VALID',
global_pool=False):
......@@ -85,6 +86,8 @@ def vgg_a(inputs,
layers during training.
spatial_squeeze: whether or not should squeeze the spatial dimensions of the
outputs. Useful to remove unnecessary dimensions for classification.
reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given.
scope: Optional scope for the variables.
fc_conv_padding: the type of padding to use for the fully connected layer
that is implemented as a convolutional layer. Use 'SAME' padding if you
......@@ -101,7 +104,7 @@ def vgg_a(inputs,
or the input to the logits layer (if num_classes is 0 or None).
end_points: a dict of tensors with intermediate activations.
"""
with tf.variable_scope(scope, 'vgg_a', [inputs]) as sc:
with tf.variable_scope(scope, 'vgg_a', [inputs], reuse=reuse) as sc:
end_points_collection = sc.original_name_scope + '_end_points'
# Collect outputs for conv2d, fully_connected and max_pool2d.
with slim.arg_scope([slim.conv2d, slim.max_pool2d],
......@@ -146,6 +149,7 @@ def vgg_16(inputs,
is_training=True,
dropout_keep_prob=0.5,
spatial_squeeze=True,
reuse=None,
scope='vgg_16',
fc_conv_padding='VALID',
global_pool=False):
......@@ -163,6 +167,8 @@ def vgg_16(inputs,
layers during training.
spatial_squeeze: whether or not should squeeze the spatial dimensions of the
outputs. Useful to remove unnecessary dimensions for classification.
reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given.
scope: Optional scope for the variables.
fc_conv_padding: the type of padding to use for the fully connected layer
that is implemented as a convolutional layer. Use 'SAME' padding if you
......@@ -179,7 +185,7 @@ def vgg_16(inputs,
or the input to the logits layer (if num_classes is 0 or None).
end_points: a dict of tensors with intermediate activations.
"""
with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc:
with tf.variable_scope(scope, 'vgg_16', [inputs], reuse=reuse) as sc:
end_points_collection = sc.original_name_scope + '_end_points'
# Collect outputs for conv2d, fully_connected and max_pool2d.
with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
......@@ -224,6 +230,7 @@ def vgg_19(inputs,
is_training=True,
dropout_keep_prob=0.5,
spatial_squeeze=True,
reuse=None,
scope='vgg_19',
fc_conv_padding='VALID',
global_pool=False):
......@@ -241,6 +248,8 @@ def vgg_19(inputs,
layers during training.
spatial_squeeze: whether or not should squeeze the spatial dimensions of the
outputs. Useful to remove unnecessary dimensions for classification.
reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given.
scope: Optional scope for the variables.
fc_conv_padding: the type of padding to use for the fully connected layer
that is implemented as a convolutional layer. Use 'SAME' padding if you
......@@ -258,7 +267,7 @@ def vgg_19(inputs,
None).
end_points: a dict of tensors with intermediate activations.
"""
with tf.variable_scope(scope, 'vgg_19', [inputs]) as sc:
with tf.variable_scope(scope, 'vgg_19', [inputs], reuse=reuse) as sc:
end_points_collection = sc.original_name_scope + '_end_points'
# Collect outputs for conv2d, fully_connected and max_pool2d.
with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment