Unverified commit 69ba6789 authored by Hang Zhang, committed by GitHub

Add DeepLabV3 + ResNeSt-269 (#263)

parent 17be9e16
@@ -6,6 +6,8 @@
 [![Build Docs](https://github.com/zhanghang1989/PyTorch-Encoding/workflows/Build%20Docs/badge.svg)](https://github.com/zhanghang1989/PyTorch-Encoding/actions)
 [![Unit Test](https://github.com/zhanghang1989/PyTorch-Encoding/workflows/Unit%20Test/badge.svg)](https://github.com/zhanghang1989/PyTorch-Encoding/actions)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/resnest-split-attention-networks/semantic-segmentation-on-ade20k)](https://paperswithcode.com/sota/semantic-segmentation-on-ade20k?p=resnest-split-attention-networks)

 # PyTorch-Encoding

 created by [Hang Zhang](http://hangzh.com/)
...
@@ -71,7 +71,7 @@ Test Pretrained
 - The test script is in the ``experiments/recognition/`` folder. For evaluating the model (using MS),
   for example ``ResNeSt50``::

-    python test.py --dataset imagenet --model-zoo ResNeSt50 --crop-size 224 --eval
+    python verify.py --dataset imagenet --model ResNeSt50 --crop-size 224
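
A minimal sketch (not part of this commit) for loading the same pretrained classifier programmatically, assuming ``ResNeSt50`` resolves through the model zoo (zoo names are matched case-insensitively)::

    import torch
    import encoding

    model = encoding.models.get_model('ResNeSt50', pretrained=True).cuda().eval()
    # dummy batch; real evaluation applies the ImageNet resize/crop/normalize transforms
    logits = model(torch.randn(1, 3, 224, 224).cuda())
    print(logits.argmax(dim=1))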

 Train Your Own Model
 --------------------
@@ -81,3 +81,5 @@ Train Your Own Model

     python scripts/prepare_imagenet.py --data-dir ./

 - The training script is in the ``experiments/recognition/`` folder. Commands for reproducing pre-trained models can be found in the table.
@@ -35,23 +35,32 @@ ResNeSt Backbone Models
 ============================================================================== ============== ============== =========================================================================================================
 Model                                                                          pixAcc         mIoU           Command
 ============================================================================== ============== ============== =========================================================================================================
-FCN_ResNeSt50_ADE                                                              xx.xx%         xx.xx%         :raw-html:`<a href="javascript:toggleblock('cmd_fcn_nest50_ade')" class="toggleblock">cmd</a>`
+FCN_ResNeSt50_ADE                                                              80.18%         42.94%         :raw-html:`<a href="javascript:toggleblock('cmd_fcn_nest50_ade')" class="toggleblock">cmd</a>`
 DeepLabV3_ResNeSt50_ADE                                                        81.17%         45.12%         :raw-html:`<a href="javascript:toggleblock('cmd_deeplab_resnest50_ade')" class="toggleblock">cmd</a>`
 DeepLabV3_ResNeSt101_ADE                                                       82.07%         46.91%         :raw-html:`<a href="javascript:toggleblock('cmd_deeplab_resnest101_ade')" class="toggleblock">cmd</a>`
+DeepLabV3_ResNeSt269_ADE                                                       82.62%         47.60%         :raw-html:`<a href="javascript:toggleblock('cmd_deeplab_resnest269_ade')" class="toggleblock">cmd</a>`
 ============================================================================== ============== ============== =========================================================================================================

 .. raw:: html

     <code xml:space="preserve" id="cmd_fcn_nest50_ade" style="display: none; text-align: left; white-space: pre-wrap">
-    python train.py --dataset ade20k --model fcn --aux --backbone resnest50 --batch-size 2
+    python train_dist.py --dataset ADE20K --model fcn --aux --backbone resnest50
+    </code>
+    <code xml:space="preserve" id="cmd_enc_nest50_ade" style="display: none; text-align: left; white-space: pre-wrap">
+    python train_dist.py --dataset ADE20K --model EncNet --aux --se-loss --backbone resnest50
     </code>
     <code xml:space="preserve" id="cmd_deeplab_resnest50_ade" style="display: none; text-align: left; white-space: pre-wrap">
-    python train.py --dataset ADE20K --model deeplab --aux --backbone resnest50
+    python train_dist.py --dataset ADE20K --model deeplab --aux --backbone resnest50
     </code>
     <code xml:space="preserve" id="cmd_deeplab_resnest101_ade" style="display: none; text-align: left; white-space: pre-wrap">
-    python train.py --dataset ADE20K --model deeplab --aux --backbone resnest101
+    python train_dist.py --dataset ADE20K --model deeplab --aux --backbone resnest101
+    </code>
+    <code xml:space="preserve" id="cmd_deeplab_resnest269_ade" style="display: none; text-align: left; white-space: pre-wrap">
+    python train_dist.py --dataset ADE20K --model deeplab --aux --backbone resnest269
     </code>
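
As a sketch of what the new table entry provides (assuming the ``deeplab_resnest269_ade`` zoo key added in this commit, matched case-insensitively by ``encoding.models.get_model``)::

    import torch
    import encoding

    model = encoding.models.get_model('DeepLab_ResNeSt269_ADE', pretrained=True).cuda().eval()
    with torch.no_grad():
        # segmentation models in this repo return a tuple; the first element is the main score map
        output = model(torch.randn(1, 3, 480, 480).cuda())[0]
    print(output.shape)  # per-pixel scores over the 150 ADE20K categories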
@@ -73,19 +82,19 @@ EncNet_ResNet101s_ADE

 .. raw:: html

     <code xml:space="preserve" id="cmd_fcn50_ade" style="display: none; text-align: left; white-space: pre-wrap">
-    CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset ADE20K --model FCN
+    CUDA_VISIBLE_DEVICES=0,1,2,3 python train_dist.py --dataset ADE20K --model FCN
     </code>
     <code xml:space="preserve" id="cmd_psp50_ade" style="display: none; text-align: left; white-space: pre-wrap">
-    CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset ADE20K --model PSP --aux
+    CUDA_VISIBLE_DEVICES=0,1,2,3 python train_dist.py --dataset ADE20K --model PSP --aux
     </code>
     <code xml:space="preserve" id="cmd_enc50_ade" style="display: none; text-align: left; white-space: pre-wrap">
-    CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset ADE20K --model EncNet --aux --se-loss
+    CUDA_VISIBLE_DEVICES=0,1,2,3 python train_dist.py --dataset ADE20K --model EncNet --aux --se-loss
     </code>
     <code xml:space="preserve" id="cmd_enc101_ade" style="display: none; text-align: left; white-space: pre-wrap">
-    CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset ADE20K --model EncNet --aux --se-loss --backbone resnet101 --base-size 640 --crop-size 576
+    CUDA_VISIBLE_DEVICES=0,1,2,3 python train_dist.py --dataset ADE20K --model EncNet --aux --se-loss --backbone resnet101
     </code>

 Pascal Context Dataset
@@ -94,22 +103,22 @@ Pascal Context Dataset

 ============================================================================== ================= ============== =============================================================================================
 Model                                                                          pixAcc            mIoU           Command
 ============================================================================== ================= ============== =============================================================================================
-Encnet_ResNet50_PContext                                                       79.2%             51.0%          :raw-html:`<a href="javascript:toggleblock('cmd_enc50_pcont')" class="toggleblock">cmd</a>`
-EncNet_ResNet101_PContext                                                      80.7%             54.1%          :raw-html:`<a href="javascript:toggleblock('cmd_enc101_pcont')" class="toggleblock">cmd</a>`
+Encnet_ResNet50s_PContext                                                      79.2%             51.0%          :raw-html:`<a href="javascript:toggleblock('cmd_enc50_pcont')" class="toggleblock">cmd</a>`
+EncNet_ResNet101s_PContext                                                     80.7%             54.1%          :raw-html:`<a href="javascript:toggleblock('cmd_enc101_pcont')" class="toggleblock">cmd</a>`
 ============================================================================== ================= ============== =============================================================================================

 .. raw:: html

     <code xml:space="preserve" id="cmd_fcn50_pcont" style="display: none; text-align: left; white-space: pre-wrap">
-    CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset PContext --model FCN
+    CUDA_VISIBLE_DEVICES=0,1,2,3 python train_dist.py --dataset PContext --model FCN
     </code>
     <code xml:space="preserve" id="cmd_enc50_pcont" style="display: none; text-align: left; white-space: pre-wrap">
-    CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset PContext --model EncNet --aux --se-loss
+    CUDA_VISIBLE_DEVICES=0,1,2,3 python train_dist.py --dataset PContext --model EncNet --aux --se-loss
     </code>
     <code xml:space="preserve" id="cmd_enc101_pcont" style="display: none; text-align: left; white-space: pre-wrap">
-    CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset PContext --model EncNet --aux --se-loss --backbone resnet101
+    CUDA_VISIBLE_DEVICES=0,1,2,3 python train_dist.py --dataset PContext --model EncNet --aux --se-loss --backbone resnet101
     </code>

@@ -127,9 +136,9 @@ EncNet_ResNet101s_VOC

     <code xml:space="preserve" id="cmd_enc101_voc" style="display: none; text-align: left; white-space: pre-wrap">
     # First finetuning COCO dataset pretrained model on augmented set
     # You can also train from scratch on COCO by yourself
-    CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset Pascal_aug --model-zoo EncNet_Resnet101_COCO --aux --se-loss --lr 0.001 --syncbn --ngpus 4 --checkname res101 --ft
+    CUDA_VISIBLE_DEVICES=0,1,2,3 python train_dist.py --dataset Pascal_aug --model-zoo EncNet_Resnet101_COCO --aux --se-loss --lr 0.001 --syncbn --ngpus 4 --checkname res101 --ft
     # Finetuning on original set
-    CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset Pascal_voc --model encnet --aux --se-loss --backbone resnet101 --lr 0.0001 --syncbn --ngpus 4 --checkname res101 --resume runs/Pascal_aug/encnet/res101/checkpoint.params --ft
+    CUDA_VISIBLE_DEVICES=0,1,2,3 python train_dist.py --dataset Pascal_voc --model encnet --aux --se-loss --backbone resnet101 --lr 0.0001 --syncbn --ngpus 4 --checkname res101 --resume runs/Pascal_aug/encnet/res101/checkpoint.params --ft
     </code>

@@ -146,6 +155,28 @@ Test Pretrained

     python test.py --dataset ADE20K --model-zoo EncNet_ResNet50s_ADE --eval
     # pixAcc: 0.801, mIoU: 0.415: 100%|████████████████████████| 250/250

+Train Your Own Model
+--------------------
+- Prepare the datasets by running the scripts in the ``scripts/`` folder, for example preparing the ``ADE20K`` dataset::
+
+    python scripts/prepare_ade20k.py
+
+- The training script is in the ``experiments/segmentation/`` folder; an example training command::
+
+    python train_dist.py --dataset ade20k --model encnet --aux --se-loss
+
+- For detailed training options, please run ``python train_dist.py -h``. Commands for reproducing pre-trained models can be found in the table.
+
+.. hint::
+    The validation metrics during training use center-crop only and are meant for monitoring
+    training correctness. For evaluating a pretrained model on the validation set using MS,
+    please use the command::
+
+    python test.py --dataset pcontext --model encnet --aux --se-loss --resume mycheckpoint --eval

 Quick Demo
 ~~~~~~~~~~
@@ -155,7 +186,7 @@ Quick Demo

     import encoding
     # Get the model
-    model = encoding.models.get_model('Encnet_ResNet50_PContext', pretrained=True).cuda()
+    model = encoding.models.get_model('Encnet_ResNet50s_PContext', pretrained=True).cuda()
     model.eval()

     # Prepare the image
@@ -180,30 +211,21 @@ Quick Demo

 .. image:: https://raw.githubusercontent.com/zhanghang1989/image-data/master/encoding/segmentation/pcontext/2010_001829.png
     :width: 45%

-Train Your Own Model
---------------------
-- Prepare the datasets by runing the scripts in the ``scripts/`` folder, for example preparing ``ADE20K`` dataset::
-
-    python scripts/prepare_ade20k.py
-
-- The training script is in the ``experiments/segmentation/`` folder, example training command::
-
-    python train_dist.py --dataset ade20k --model encnet --aux --se-loss
-
-- Detail training options, please run ``python train.py -h``. Commands for reproducing pre-trained models can be found in the table.
-
-.. hint::
-    The validation metrics during the training only using center-crop is just for monitoring the
-    training correctness purpose. For evaluating the pretrained model on validation set using MS,
-    please use the command::
-
-    python test.py --dataset pcontext --model encnet --aux --se-loss --resume mycheckpoint --eval

 Citation
 --------
 .. note::
+    * Hang Zhang et al. "ResNeSt: Split-Attention Networks" *arXiv 2020*::
+
+        @article{zhang2020resnest,
+            title={ResNeSt: Split-Attention Networks},
+            author={Zhang, Hang and Wu, Chongruo and Zhang, Zhongyue and Zhu, Yi and Zhang, Zhi and Lin, Haibin and Sun, Yue and He, Tong and Muller, Jonas and Manmatha, R. and Li, Mu and Smola, Alexander},
+            journal={arXiv preprint arXiv:2004.08955},
+            year={2020}
+        }
+
     * Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, Amit Agrawal. "Context Encoding for Semantic Segmentation" *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*::

         @InProceedings{Zhang_2018_CVPR,
...
@@ -44,7 +44,14 @@ Citations
 ---------
 .. note::
-    If using the code in your research, please cite our papers.
+    * Hang Zhang et al. "ResNeSt: Split-Attention Networks" *arXiv 2020*::
+
+        @article{zhang2020resnest,
+            title={ResNeSt: Split-Attention Networks},
+            author={Zhang, Hang and Wu, Chongruo and Zhang, Zhongyue and Zhu, Yi and Zhang, Zhi and Lin, Haibin and Sun, Yue and He, Tong and Muller, Jonas and Manmatha, R. and Li, Mu and Smola, Alexander},
+            journal={arXiv preprint arXiv:2004.08955},
+            year={2020}
+        }
+
     * Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, Amit Agrawal. "Context Encoding for Semantic Segmentation" *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*::
...
@@ -23,7 +23,7 @@ def resnest50(pretrained=False, root='~/.encoding/models', **kwargs):
                    avd=True, avd_first=False, **kwargs)
     if pretrained:
         model.load_state_dict(torch.load(
-            get_model_file('resnest50', root=root)), strict=False)
+            get_model_file('resnest50', root=root)), strict=True)
     return model

 def resnest101(pretrained=False, root='~/.encoding/models', **kwargs):
@@ -33,7 +33,7 @@ def resnest101(pretrained=False, root='~/.encoding/models', **kwargs):
                    avd=True, avd_first=False, **kwargs)
     if pretrained:
         model.load_state_dict(torch.load(
-            get_model_file('resnest101', root=root)), strict=False)
+            get_model_file('resnest101', root=root)), strict=True)
     return model

 def resnest200(pretrained=False, root='~/.encoding/models', **kwargs):
@@ -53,7 +53,7 @@ def resnest269(pretrained=False, root='~/.encoding/models', **kwargs):
                    avd=True, avd_first=False, **kwargs)
     if pretrained:
         model.load_state_dict(torch.load(
-            get_model_file('resnest269', root=root)), strict=False)
+            get_model_file('resnest269', root=root)), strict=True)
     return model

 def resnest50_fast(pretrained=False, root='~/.encoding/models', **kwargs):
@@ -63,7 +63,7 @@ def resnest50_fast(pretrained=False, root='~/.encoding/models', **kwargs):
                    avd=True, avd_first=True, **kwargs)
     if pretrained:
         model.load_state_dict(torch.load(
-            get_model_file('resnest50fast', root=root)), strict=False)
+            get_model_file('resnest50fast', root=root)), strict=True)
     return model

 def resnest101_fast(pretrained=False, root='~/.encoding/models', **kwargs):
@@ -73,5 +73,5 @@ def resnest101_fast(pretrained=False, root='~/.encoding/models', **kwargs):
                    avd=True, avd_first=True, **kwargs)
     if pretrained:
         model.load_state_dict(torch.load(
-            get_model_file('resnest101fast', root=root)), strict=False)
+            get_model_file('resnest101fast', root=root)), strict=True)
     return model
@@ -13,6 +13,8 @@ _model_sha1 = {name: checksum for checksum, name in [
     ('966fb78c22323b0c68097c5c1242bd16d3e07fd5', 'resnest101'),
     ('d7fd712f5a1fcee5b3ce176026fbb6d0d278454a', 'resnest200'),
     ('51ae5f19032e22af4ec08e695496547acdba5ce5', 'resnest269'),
+    # rectified
+    #('9b5dc32b3b36ca1a6b41ecd4906830fc84dae8ed', 'resnet101_rt'),
     # resnet other variants
     ('a75c83cfc89a56a4e8ba71b14f1ec67e923787b3', 'resnet50s'),
     ('03a0f310d6447880f1b22a83bd7d1aa7fc702c6e', 'resnet101s'),
@@ -22,15 +24,17 @@ _model_sha1 = {name: checksum for checksum, name in [
     ('b41562160173ee2e979b795c551d3c7143b1e5b5', 'wideresnet50'),
     # deepten paper
     ('1225f149519c7a0113c43a056153c1bb15468ac0', 'deepten_resnet50_minc'),
-    # segmentation models
+    # segmentation resnet models
     ('662e979de25a389f11c65e9f1df7e06c2c356381', 'fcn_resnet50s_ade'),
     ('4de91d5922d4d3264f678b663f874da72e82db00', 'encnet_resnet50s_pcontext'),
     ('9f27ea13d514d7010e59988341bcbd4140fcc33d', 'encnet_resnet101s_pcontext'),
     ('07ac287cd77e53ea583f37454e17d30ce1509a4a', 'encnet_resnet50s_ade'),
     ('3f54fa3b67bac7619cd9b3673f5c8227cf8f4718', 'encnet_resnet101s_ade'),
     # resnest segmentation models
+    ('4aba491aaf8e4866a9c9981b210e3e3266ac1f2a', 'fcn_resnest50_ade'),
     ('2225f09d0f40b9a168d9091652194bc35ec2a5a9', 'deeplab_resnest50_ade'),
     ('06ca799c8cc148fe0fafb5b6d052052935aa3cc8', 'deeplab_resnest101_ade'),
+    ('0074dd10a6e6696f6f521653fb98224e75955496', 'deeplab_resnest269_ade'),
 ]}

 encoding_repo_url = 'https://hangzh.s3.amazonaws.com/'
...
@@ -29,7 +29,7 @@ models = {
     'wideresnet50': wideresnet50,
     # deepten paper
     'deepten_resnet50_minc': get_deepten_resnet50_minc,
-    # segmentation models
+    # segmentation resnet models
     'encnet_resnet101s_coco': get_encnet_resnet101_coco,
     'fcn_resnet50s_pcontext': get_fcn_resnet50_pcontext,
     'encnet_resnet50s_pcontext': get_encnet_resnet50_pcontext,
@@ -38,8 +38,12 @@ models = {
     'encnet_resnet101s_ade': get_encnet_resnet101_ade,
     'fcn_resnet50s_ade': get_fcn_resnet50_ade,
     'psp_resnet50s_ade': get_psp_resnet50_ade,
+    # segmentation resnest models
+    'fcn_resnest50_ade': get_fcn_resnest50_ade,
     'deeplab_resnest50_ade': get_deeplab_resnest50_ade,
     'deeplab_resnest101_ade': get_deeplab_resnest101_ade,
+    'deeplab_resnest200_ade': get_deeplab_resnest200_ade,
+    'deeplab_resnest269_ade': get_deeplab_resnest269_ade,
 }

 model_list = list(models.keys())
@@ -61,7 +65,6 @@ def get_model(name, **kwargs):
     Module:
         The model.
     """
     name = name.lower()
     if name not in models:
         raise ValueError('%s\n\t%s' % (str(name), '\n\t'.join(sorted(models.keys()))))
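
# Usage sketch (not part of the commit): zoo names are matched
# case-insensitively because of the name.lower() call above, so both of
# these resolve to the 'deeplab_resnest269_ade' entry registered here,
# while an unknown name raises ValueError listing all available keys.
#
#   import encoding
#   m1 = encoding.models.get_model('deeplab_resnest269_ade')
#   m2 = encoding.models.get_model('DeepLab_ResNeSt269_ADE')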
...
@@ -108,7 +108,7 @@ class ATTENHead(nn.Module):
         if with_enc:
             self.encmodule = EncModule(inter_channels+extended_channels, out_channels, ncodes=32,
                                        se_loss=se_loss, norm_layer=norm_layer)
-        self.conv6 = nn.Sequential(nn.Dropout2d(0.1, False),
+        self.conv6 = nn.Sequential(nn.Dropout(0.1, False),
                                    nn.Conv2d(inter_channels+extended_channels, out_channels, 1))

     def forward(self, *inputs):
...
@@ -66,7 +66,7 @@ class DeepLabV3Head(nn.Module):
                                    nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
                                    norm_layer(inter_channels),
                                    nn.ReLU(True),
-                                   nn.Dropout2d(0.1, False),
+                                   nn.Dropout(0.1, False),
                                    nn.Conv2d(inter_channels, out_channels, 1))

     def forward(self, x):
@@ -198,3 +198,42 @@ def get_deeplab_resnest101_ade(pretrained=False, root='~/.encoding/models', **kwargs):
     >>> print(model)
     """
     return get_deeplab('ade20k', 'resnest101', pretrained, aux=True, root=root, **kwargs)
+
+def get_deeplab_resnest200_ade(pretrained=False, root='~/.encoding/models', **kwargs):
+    r"""DeepLabV3 model from the paper `"Context Encoding for Semantic Segmentation"
+    <https://arxiv.org/pdf/1803.08904.pdf>`_
+
+    Parameters
+    ----------
+    pretrained : bool, default False
+        Whether to load the pretrained weights for model.
+    root : str, default '~/.encoding/models'
+        Location for keeping the model parameters.
+
+    Examples
+    --------
+    >>> model = get_deeplab_resnest200_ade(pretrained=True)
+    >>> print(model)
+    """
+    return get_deeplab('ade20k', 'resnest200', pretrained, aux=True, root=root, **kwargs)
+
+def get_deeplab_resnest269_ade(pretrained=False, root='~/.encoding/models', **kwargs):
+    r"""DeepLabV3 model from the paper `"Context Encoding for Semantic Segmentation"
+    <https://arxiv.org/pdf/1803.08904.pdf>`_
+
+    Parameters
+    ----------
+    pretrained : bool, default False
+        Whether to load the pretrained weights for model.
+    root : str, default '~/.encoding/models'
+        Location for keeping the model parameters.
+
+    Examples
+    --------
+    >>> model = get_deeplab_resnest269_ade(pretrained=True)
+    >>> print(model)
+    """
+    return get_deeplab('ade20k', 'resnest269', pretrained, aux=True, root=root, **kwargs)
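
# Inference sketch (not part of the commit): the new getters plug into the
# same pipeline as the resnest50/101 ones; image_batch below is a
# hypothetical (N, 3, H, W) CUDA tensor.
#
#   model = get_deeplab_resnest269_ade(pretrained=True).cuda().eval()
#   with torch.no_grad():
#       score_map = model(image_batch)[0]   # first output is the main prediction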
@@ -98,7 +98,7 @@ class EncHead(nn.Module):
                                    nn.ReLU(inplace=True))
         self.encmodule = EncModule(512, out_channels, ncodes=32,
                                    se_loss=se_loss, norm_layer=norm_layer)
-        self.conv6 = nn.Sequential(nn.Dropout2d(0.1, False),
+        self.conv6 = nn.Sequential(nn.Dropout(0.1, False),
                                    nn.Conv2d(512, out_channels, 1))

     def forward(self, *inputs):
...
@@ -84,7 +84,7 @@ class FCFPNHead(nn.Module):
         self.conv5 = nn.Sequential(nn.Conv2d(inter_channels, 512, 3, padding=1, bias=False),
                                    norm_layer(512),
                                    nn.ReLU(),
-                                   nn.Dropout2d(0.1, False),
+                                   nn.Dropout(0.1, False),
                                    nn.Conv2d(512, out_channels, 1))

     def forward(self, *inputs):
...
@@ -13,7 +13,8 @@ from ...nn import ConcurrentModule, SyncBatchNorm
 from .base import BaseNet

-__all__ = ['FCN', 'get_fcn', 'get_fcn_resnet50_pcontext', 'get_fcn_resnet50_ade']
+__all__ = ['FCN', 'get_fcn', 'get_fcn_resnet50_pcontext', 'get_fcn_resnet50_ade',
+           'get_fcn_resnest50_ade']

 class FCN(BaseNet):
     r"""Fully Convolutional Networks for Semantic Segmentation
@@ -97,13 +98,13 @@ class FCNHead(nn.Module):
                                        GlobalPooling(inter_channels, inter_channels,
                                                      norm_layer, self._up_kwargs),
                                        ]),
-                                   nn.Dropout2d(0.1, False),
+                                   nn.Dropout(0.1, False),
                                    nn.Conv2d(2*inter_channels, out_channels, 1))
         else:
             self.conv5 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
                                        norm_layer(inter_channels),
                                        nn.ReLU(),
-                                       nn.Dropout2d(0.1, False),
+                                       nn.Dropout(0.1, False),
                                        nn.Conv2d(inter_channels, out_channels, 1))

     def forward(self, x):
@@ -173,3 +174,23 @@ def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
     >>> print(model)
     """
     return get_fcn('ade20k', 'resnet50s', pretrained, root=root, **kwargs)
+
+def get_fcn_resnest50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
+    r"""FCN model with a ResNeSt-50 backbone from the paper `"Context Encoding for Semantic Segmentation"
+    <https://arxiv.org/pdf/1803.08904.pdf>`_
+
+    Parameters
+    ----------
+    pretrained : bool, default False
+        Whether to load the pretrained weights for model.
+    root : str, default '~/.encoding/models'
+        Location for keeping the model parameters.
+
+    Examples
+    --------
+    >>> model = get_fcn_resnest50_ade(pretrained=True)
+    >>> print(model)
+    """
+    kwargs['aux'] = True
+    return get_fcn('ade20k', 'resnest50', pretrained, root=root, **kwargs)
@@ -44,7 +44,7 @@ class PSPHead(nn.Module):
                                    nn.Conv2d(in_channels * 2, inter_channels, 3, padding=1, bias=False),
                                    norm_layer(inter_channels),
                                    nn.ReLU(True),
-                                   nn.Dropout2d(0.1, False),
+                                   nn.Dropout(0.1, False),
                                    nn.Conv2d(inter_channels, out_channels, 1))

     def forward(self, x):
...
@@ -13,13 +13,36 @@ import torch
 import torch.nn as nn
 from torch.nn import functional as F
 from torch.autograd import Variable
+
+from .splat import SplAtConv2d
+from .rectify import RFConv2d

 torch_ver = torch.__version__[:3]

-__all__ = ['GlobalAvgPool2d', 'GramMatrix',
+__all__ = ['ConvBnAct', 'GlobalAvgPool2d', 'GramMatrix',
            'View', 'Sum', 'Mean', 'Normalize', 'ConcurrentModule',
            'PyramidPooling']

+class ConvBnAct(nn.Sequential):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, radix=0, groups=1,
+                 bias=True, padding_mode='zeros',
+                 rectify=False, rectify_avg=False, act=True,
+                 norm_layer=nn.BatchNorm2d):
+        super().__init__()
+        if radix > 0:
+            conv_layer = SplAtConv2d
+            conv_kwargs = {'radix': radix, 'rectify': rectify, 'rectify_avg': rectify_avg, 'norm_layer': norm_layer}
+        else:
+            conv_layer = RFConv2d if rectify else nn.Conv2d
+            conv_kwargs = {'average_mode': rectify_avg} if rectify else {}
+        self.add_module("conv", conv_layer(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
+                                           padding=padding, dilation=dilation, groups=groups, bias=bias,
+                                           padding_mode=padding_mode, **conv_kwargs))
+        self.add_module("bn", nn.BatchNorm2d(out_channels))
+        if act:
+            self.add_module("relu", nn.ReLU())

 class GlobalAvgPool2d(nn.Module):
     def __init__(self):
         """Global average pooling over the input's spatial dimensions"""
@@ -28,6 +51,7 @@ class GlobalAvgPool2d(nn.Module):
     def forward(self, inputs):
         return F.adaptive_avg_pool2d(inputs, 1).view(inputs.size(0), -1)

 class GramMatrix(nn.Module):
     r""" Gram Matrix for a 4D convolutional featuremaps as a mini-batch
@@ -41,6 +65,7 @@ class GramMatrix(nn.Module):
         gram = features.bmm(features_t) / (ch * h * w)
         return gram

 class View(nn.Module):
     """Reshape the input into different size, an inplace operator, support
     SelfParallel mode.
...
@@ -58,9 +58,9 @@ class DropBlock2D(nn.Module):
         # sample mask and place on input device
         if self.share_channel:
-            mask = (torch.rand(x.shape[0], *x.shape[2:], device=x.device, dtype=x.dtype) < gamma).squeeze(1)
+            mask = (torch.rand(*x.shape[2:], device=x.device, dtype=x.dtype) < gamma).unsqueeze(0).unsqueeze(0)
         else:
-            mask = (torch.rand(*x.shape, device=x.device, dtype=x.dtype) < gamma)
+            mask = (torch.rand(*x.shape[1:], device=x.device, dtype=x.dtype) < gamma).unsqueeze(0)

         # compute block mask
         block_mask, keeped = self._compute_block_mask(mask)
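
# Shape sketch for the new mask sampling (x: (N, C, H, W)):
#   share_channel=True  -> mask of shape (1, 1, H, W), broadcast over both
#                          the batch and channel dimensions
#   share_channel=False -> mask of shape (1, C, H, W), broadcast over the
#                          batch dimension only
# so every sample in the batch now shares one DropBlock pattern.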
...
@@ -6,10 +6,10 @@ import torch.nn.functional as F
 from torch.nn import Conv2d, Module, Linear, BatchNorm2d, ReLU
 from torch.nn.modules.utils import _pair

-from ..nn import RFConv2d
+from .rectify import RFConv2d
 from .dropblock import DropBlock2D

-__all__ = ['SKConv2d']
+__all__ = ['SplAtConv2d']

 class SplAtConv2d(Module):
     """Split-Attention Conv2d
@@ -42,6 +42,7 @@ class SplAtConv2d(Module):
         self.fc2 = Conv2d(inter_channels, channels*radix, 1, groups=self.cardinality)
         if dropblock_prob > 0.0:
             self.dropblock = DropBlock2D(dropblock_prob, 3)
+        self.rsoftmax = rSoftMax(radix, groups)

     def forward(self, x):
         x = self.conv(x)
@@ -64,11 +65,8 @@ class SplAtConv2d(Module):
         gap = self.bn1(gap)
         gap = self.relu(gap)

-        atten = self.fc2(gap).view((batch, self.radix, self.channels))
-        if self.radix > 1:
-            atten = F.softmax(atten, dim=1).view(batch, -1, 1, 1)
-        else:
-            atten = F.sigmoid(atten, dim=1).view(batch, -1, 1, 1)
+        atten = self.fc2(gap)
+        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

         if self.radix > 1:
             atten = torch.split(atten, channel//self.radix, dim=1)
@@ -76,3 +74,19 @@ class SplAtConv2d(Module):
         else:
             out = atten * x
         return out.contiguous()
+
+class rSoftMax(nn.Module):
+    def __init__(self, radix, cardinality):
+        super().__init__()
+        self.radix = radix
+        self.cardinality = cardinality
+
+    def forward(self, x):
+        batch = x.size(0)
+        if self.radix > 1:
+            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
+            x = F.softmax(x, dim=1)
+            x = x.reshape(batch, -1)
+        else:
+            x = torch.sigmoid(x)
+        return x
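
# Sanity-check sketch (not part of the commit): for radix > 1 the weights
# produced by rSoftMax sum to 1 across the radix splits of each channel.
#
#   batch, radix, cardinality, channels = 2, 2, 1, 8
#   x = torch.randn(batch, channels * radix, 1, 1)
#   w = rSoftMax(radix, cardinality)(x)          # -> (batch, channels*radix)
#   s = w.view(batch, radix, -1).sum(dim=1)      # sum over the radix axis
#   assert torch.allclose(s, torch.ones_like(s))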
@@ -16,3 +16,4 @@ from .train_helper import *
 from .presets import load_image
 from .files import *
 from .misc import *
+from .dist_helper import *
+##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+## Created by: Hang Zhang
+## ECE Department, Rutgers University
+## Email: zhang.hang@rutgers.edu
+## Copyright (c) 2017
+##
+## This source code is licensed under the MIT-style license found in the
+## LICENSE file in the root directory of this source tree
+##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+import torch
+
+__all__ = ['torch_dist_sum']
+
+def torch_dist_sum(gpu, *args):
+    process_group = torch.distributed.group.WORLD
+    tensor_args = []
+    pending_res = []
+    for arg in args:
+        if isinstance(arg, torch.Tensor):
+            tensor_arg = arg.clone().reshape(-1).detach().cuda(gpu)
+        else:
+            tensor_arg = torch.tensor(arg).reshape(-1).cuda(gpu)
+        tensor_args.append(tensor_arg)
+        pending_res.append(torch.distributed.all_reduce(tensor_arg, group=process_group, async_op=True))
+    for res in pending_res:
+        res.wait()
+    return tensor_args
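
# Usage sketch (hypothetical names): sum per-process metric counters before
# computing global scores, assuming torch.distributed is initialized with
# one process per GPU.
#
#   correct, labeled, inter, union = metric.get_all()
#   correct, labeled, inter, union = torch_dist_sum(gpu, correct, labeled, inter, union)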
@@ -12,7 +12,8 @@ import threading
 import numpy as np
 import torch

-__all__ = ['accuracy', 'SegmentationMetric', 'batch_intersection_union', 'batch_pix_accuracy',
+__all__ = ['accuracy', 'get_pixacc_miou',
+           'SegmentationMetric', 'batch_intersection_union', 'batch_pix_accuracy',
            'pixel_accuracy', 'intersection_and_union']

 def accuracy(output, target, topk=(1,)):
@@ -31,6 +32,13 @@ def accuracy(output, target, topk=(1,)):
         res.append(correct_k.mul_(100.0 / batch_size))
     return res

+def get_pixacc_miou(total_correct, total_label, total_inter, total_union):
+    pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
+    IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
+    mIoU = IoU.mean()
+    return pixAcc, mIoU
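
# Worked example (illustrative numbers): 95 of 100 labeled pixels correct,
# two classes with intersection/union of 8/10 and 3/6:
#   pixAcc = 95 / 100 = 0.95
#   mIoU   = mean(0.8, 0.5) = 0.65
#
#   get_pixacc_miou(95, 100, np.array([8., 3.]), np.array([10., 6.]))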

 class SegmentationMetric(object):
     """Computes pixAcc and mIoU metric scores
     """
@@ -66,11 +74,11 @@ class SegmentationMetric(object):
         else:
             raise NotImplemented

+    def get_all(self):
+        return self.total_correct, self.total_label, self.total_inter, self.total_union
+
     def get(self):
-        pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)
-        IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
-        mIoU = IoU.mean()
-        return pixAcc, mIoU
+        return get_pixacc_miou(self.total_correct, self.total_label, self.total_inter, self.total_union)

     def reset(self):
         self.total_inter = 0
@@ -79,7 +87,6 @@ class SegmentationMetric(object):
         self.total_label = 0
         return

 def batch_pix_accuracy(output, target):
     """Batch Pixel Accuracy
     Args:
...
+# Adapted from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/precise_bn.py
+import itertools
+from typing import Any, Iterable, List, Tuple, Type
+
+import torch
+from torch import nn
+
+from ..nn import DistSyncBatchNorm, SyncBatchNorm
+
+BN_MODULE_TYPES: Tuple[Type[nn.Module]] = (
+    torch.nn.BatchNorm1d,
+    torch.nn.BatchNorm2d,
+    torch.nn.BatchNorm3d,
+    torch.nn.SyncBatchNorm,
+    DistSyncBatchNorm,
+    SyncBatchNorm,
+)
+
+@torch.no_grad()
+def update_bn_stats(
+    model: nn.Module, data_loader: Iterable[Any], num_iters: int = 200  # pyre-ignore
+) -> None:
+    """
+    Recompute and update the batch norm stats to make them more precise. During
+    training, both the BN stats and the weights change after every iteration, so
+    the running average cannot precisely reflect the actual stats of the
+    current model.
+    In this function, the BN stats are recomputed with fixed weights, to make
+    the running average more precise. Specifically, it computes the true average
+    of per-batch mean/variance instead of the running average.
+
+    Args:
+        model (nn.Module): the model whose bn stats will be recomputed.
+
+            Note that:
+
+            1. This function will not alter the training mode of the given model.
+               Users are responsible for setting the layers that need
+               precise-BN to training mode, prior to calling this function.
+            2. Be careful if your models contain other stateful layers in
+               addition to BN, i.e. layers whose state can change in forward
+               iterations. This function will alter their state. If you wish
+               them unchanged, you need to either pass in a submodule without
+               those layers, or backup the states.
+        data_loader (iterator): an iterator. Produces data as inputs to the model.
+        num_iters (int): number of iterations to compute the stats.
+    """
+    bn_layers = get_bn_modules(model)
+    if len(bn_layers) == 0:
+        return
+
+    # In order to make the running stats only reflect the current batch, the
+    # momentum is disabled.
+    # bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean
+    # Setting the momentum to 1.0 to compute the stats without momentum.
+    momentum_actual = [bn.momentum for bn in bn_layers]  # pyre-ignore
+    for bn in bn_layers:
+        bn.momentum = 1.0
+
+    # Note that running_var actually means "running average of variance"
+    running_mean = [
+        torch.zeros_like(bn.running_mean) for bn in bn_layers  # pyre-ignore
+    ]
+    running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]  # pyre-ignore
+
+    ind = -1
+    for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
+        inputs = inputs.cuda()
+        with torch.no_grad():  # No need to backward
+            model(inputs)
+        for i, bn in enumerate(bn_layers):
+            # Accumulates the bn stats.
+            running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
+            running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
+            # We compute the "average of variance" across iterations.
+    assert ind == num_iters - 1, (
+        "update_bn_stats is meant to run for {} iterations, "
+        "but the dataloader stops at {} iterations.".format(num_iters, ind)
+    )
+
+    for i, bn in enumerate(bn_layers):
+        # Sets the precise bn stats.
+        bn.running_mean = running_mean[i]
+        bn.running_var = running_var[i]
+        bn.momentum = momentum_actual[i]
+
+def get_bn_modules(model: nn.Module) -> List[nn.Module]:
+    """
+    Find all BatchNorm (BN) modules that are in training mode. See
+    fvcore.precise_bn.BN_MODULE_TYPES for a list of all modules that are
+    included in this search.
+
+    Args:
+        model (nn.Module): a model possibly containing BN modules.
+    Returns:
+        list[nn.Module]: all BN modules in the model.
+    """
+    # Finds all the bn layers.
+    bn_layers = [
+        m for m in model.modules() if m.training and isinstance(m, BN_MODULE_TYPES)
+    ]
+    return bn_layers
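
# Usage sketch (assuming a dataloader that yields CUDA-ready image batches):
# update_bn_stats does not change the training mode itself, so the BN layers
# must be put in training mode first, e.g.
#
#   model.train()
#   update_bn_stats(model, train_loader, num_iters=200)
#   model.eval()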