Unverified commit 194a0846, authored by Kai Zhang and committed by GitHub

Add RegNet Architecture in TorchVision (#4403)

* initial code

* add SqueezeExcitation

* initial code

* add SqueezeExcitation

* add SqueezeExcitation

* regnet blocks, stems and model definition

* nit

* add fc layer

* use Callable instead of Enum for block, stem and activation

* add regnet_x and regnet_y model build functions, add docs

* remove unused depth

* use BN/activation constructor and ConvBNActivation

* add expected test pkl files

* allow custom activation in SqueezeExcitation

* use ReLU as the default activation

* reuse SqueezeExcitation from efficientnet

* refactor RegNetParams into BlockParams

* use nn.init, replace np with torch

* update README

* construct model with stem, block, classifier instances

* Revert "construct model with stem, block, classifier instances"

This reverts commit 850f5f3ed01a2a9b36fcbf8405afd6e41d2e58ef.

* remove unused blocks

* support scaled model

* fuse into ConvBNActivation

* make reset_parameters private

* fix type errors

* fix for unit test

* add pretrained weights for 6 variant models, update docs
@@ -37,6 +37,7 @@ architectures for image classification:
- `Wide ResNet`_
- `MNASNet`_
- `EfficientNet`_
- `RegNet`_
You can construct a model with random weights by calling its constructor:
@@ -65,6 +66,20 @@ You can construct a model with random weights by calling its constructor:
efficientnet_b5 = models.efficientnet_b5()
efficientnet_b6 = models.efficientnet_b6()
efficientnet_b7 = models.efficientnet_b7()
regnet_y_400mf = models.regnet_y_400mf()
regnet_y_800mf = models.regnet_y_800mf()
regnet_y_1_6gf = models.regnet_y_1_6gf()
regnet_y_3_2gf = models.regnet_y_3_2gf()
regnet_y_8gf = models.regnet_y_8gf()
regnet_y_16gf = models.regnet_y_16gf()
regnet_y_32gf = models.regnet_y_32gf()
regnet_x_400mf = models.regnet_x_400mf()
regnet_x_800mf = models.regnet_x_800mf()
regnet_x_1_6gf = models.regnet_x_1_6gf()
regnet_x_3_2gf = models.regnet_x_3_2gf()
regnet_x_8gf = models.regnet_x_8gf()
regnet_x_16gf = models.regnet_x_16gf()
regnet_x_32gf = models.regnet_x_32gf()
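As a quick sanity check (not part of the diff above), constructing one of the new variants with random weights and running a dummy forward pass might look like the sketch below; the 224x224 input size and the 1000-class output are the usual ImageNet defaults assumed here.

```
import torch
from torchvision import models

# Build a RegNetY-400MF with randomly initialized weights.
model = models.regnet_y_400mf()
model.eval()

# Dummy batch at the standard ImageNet input resolution.
dummy = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    out = model(dummy)

print(out.shape)  # expected: torch.Size([1, 1000])
```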
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
These can be constructed by passing ``pretrained=True``:
@@ -94,6 +109,12 @@ These can be constructed by passing ``pretrained=True``:
efficientnet_b5 = models.efficientnet_b5(pretrained=True)
efficientnet_b6 = models.efficientnet_b6(pretrained=True)
efficientnet_b7 = models.efficientnet_b7(pretrained=True)
regnet_y_400mf = models.regnet_y_400mf(pretrained=True)
regnet_y_800mf = models.regnet_y_800mf(pretrained=True)
regnet_y_8gf = models.regnet_y_8gf(pretrained=True)
regnet_x_400mf = models.regnet_x_400mf(pretrained=True)
regnet_x_800mf = models.regnet_x_800mf(pretrained=True)
regnet_x_8gf = models.regnet_x_8gf(pretrained=True)
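For illustration only (this example is not part of the diff), running inference with one of the newly released pretrained RegNet weights could look roughly like the sketch below; the normalization statistics mentioned in the comments are the standard ones used for torchvision's ImageNet models, and the input tensor here is random stand-in data.

```
import torch
from torchvision import models

# Download the pretrained ImageNet weights and switch to inference mode.
model = models.regnet_y_800mf(pretrained=True)
model.eval()

# Stand-in input: in practice this should be a 224x224 image normalized with
# mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225).
batch = torch.rand(1, 3, 224, 224)

with torch.no_grad():
    logits = model(batch)

top5 = logits.topk(5).indices  # indices of the five highest-scoring classes
```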
Instancing a pre-trained model will download its weights to a cache directory.
This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See
@@ -188,6 +209,12 @@ EfficientNet-B4 83.384 96.594
EfficientNet-B5 83.444 96.628
EfficientNet-B6 84.008 96.916
EfficientNet-B7 84.122 96.908
regnet_x_400mf 72.834 90.950
regnet_x_800mf 75.190 92.418
regnet_x_8gf 79.324 94.694
regnet_y_400mf 74.024 91.680
regnet_y_800mf 76.420 93.136
regnet_y_8gf 79.966 95.100
================================ ============= =============
@@ -204,6 +231,7 @@ EfficientNet-B7 84.122 96.908
.. _ResNeXt: https://arxiv.org/abs/1611.05431
.. _MNASNet: https://arxiv.org/abs/1807.11626
.. _EfficientNet: https://arxiv.org/abs/1905.11946
.. _RegNet: https://arxiv.org/abs/2003.13678
.. currentmodule:: torchvision.models
@@ -317,6 +345,24 @@ EfficientNet
.. autofunction:: efficientnet_b6
.. autofunction:: efficientnet_b7
RegNet
------------
.. autofunction:: regnet_y_400mf
.. autofunction:: regnet_y_800mf
.. autofunction:: regnet_y_1_6gf
.. autofunction:: regnet_y_3_2gf
.. autofunction:: regnet_y_8gf
.. autofunction:: regnet_y_16gf
.. autofunction:: regnet_y_32gf
.. autofunction:: regnet_x_400mf
.. autofunction:: regnet_x_800mf
.. autofunction:: regnet_x_1_6gf
.. autofunction:: regnet_x_3_2gf
.. autofunction:: regnet_x_8gf
.. autofunction:: regnet_x_16gf
.. autofunction:: regnet_x_32gf
Quantized Models
----------------
@@ -17,6 +17,10 @@ from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
mnasnet1_3
from torchvision.models.efficientnet import efficientnet_b0, efficientnet_b1, efficientnet_b2, \
efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7
from torchvision.models.regnet import regnet_y_400mf, regnet_y_800mf, \
regnet_y_1_6gf, regnet_y_3_2gf, regnet_y_8gf, regnet_y_16gf, regnet_y_32gf, \
regnet_x_400mf, regnet_x_800mf, regnet_x_1_6gf, regnet_x_3_2gf, regnet_x_8gf, \
regnet_x_16gf, regnet_x_32gf
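Because the new constructors are re-exported here, they also become loadable through `torch.hub`. A minimal sketch, assuming the `pytorch/vision` hub entry point resolves to a build that already contains this commit:

```
import torch

# "pytorch/vision" points torch.hub at the torchvision repository's hubconf.py,
# which re-exports the RegNet constructors imported above.
model = torch.hub.load("pytorch/vision", "regnet_y_400mf", pretrained=True)
model.eval()
```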
# segmentation
from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \
@@ -79,6 +79,36 @@ The weights of the B0-B4 variants are ported from Ross Wightman's [timm repo](ht
The weights of the B5-B7 variants are ported from Luke Melas' [EfficientNet-PyTorch repo](https://github.com/lukemelas/EfficientNet-PyTorch/blob/1039e009545d9329ea026c9f7541341439712b96/efficientnet_pytorch/utils.py#L562-L564).
### RegNet
#### Small models
```
torchrun --nproc_per_node=8 train.py\
    --model $MODEL --epochs 100 --batch-size 128 --wd 0.00005 --lr=0.8\
    --lr-scheduler=cosineannealinglr --lr-warmup-method=linear\
    --lr-warmup-epochs=5 --lr-warmup-decay=0.1
```
Here `$MODEL` is one of `regnet_x_400mf`, `regnet_x_800mf`, `regnet_x_1_6gf`, `regnet_y_400mf`, `regnet_y_800mf` and `regnet_y_1_6gf`. Please note that we used a learning rate of 0.4 for `regnet_y_400mf` to get the same Acc@1 as [the paper](https://arxiv.org/abs/2003.13678).
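As a concrete illustration, substituting `regnet_y_400mf` into the template above with that adjusted learning rate gives the following command (all other flags unchanged):

```
torchrun --nproc_per_node=8 train.py\
    --model regnet_y_400mf --epochs 100 --batch-size 128 --wd 0.00005 --lr=0.4\
    --lr-scheduler=cosineannealinglr --lr-warmup-method=linear\
    --lr-warmup-epochs=5 --lr-warmup-decay=0.1
```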
#### Medium models
```
torchrun --nproc_per_node=8 train.py\
    --model $MODEL --epochs 100 --batch-size 64 --wd 0.00005 --lr=0.4\
    --lr-scheduler=cosineannealinglr --lr-warmup-method=linear\
    --lr-warmup-epochs=5 --lr-warmup-decay=0.1
```
Here `$MODEL` is one of `regnet_x_3_2gf`, `regnet_x_8gf`, `regnet_x_16gf`, `regnet_y_3_2gf` and `regnet_y_8gf`.
#### Large models
```
torchrun --nproc_per_node=8 train.py\
    --model $MODEL --epochs 100 --batch-size 32 --wd 0.00005 --lr=0.2\
    --lr-scheduler=cosineannealinglr --lr-warmup-method=linear\
    --lr-warmup-epochs=5 --lr-warmup-decay=0.1
```
Here `$MODEL` is one of `regnet_x_32gf`, `regnet_y_16gf` and `regnet_y_32gf`.
## Mixed precision training
Automatic Mixed Precision (AMP) training on GPU for Pytorch can be enabled with the [NVIDIA Apex extension](https://github.com/NVIDIA/apex).
14 files suppressed by a .gitattributes entry or unsupported encoding.
@@ -9,6 +9,7 @@ from .mobilenet import *
from .mnasnet import *
from .shufflenetv2 import *
from .efficientnet import *
from .regnet import *
from . import segmentation
from . import detection
from . import video