Unverified commit 17be9e16 authored by Hang Zhang, committed by GitHub

fix miscs (#258)

parent b872eb8c
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Unit Test

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

jobs:
  build:
    runs-on: self-hosted
    steps:
    - uses: actions/checkout@v2
    - uses: seanmiddleditch/gha-setup-ninja@master
    - name: Set up Python
      uses: actions/github-script@0.3.0
      with:
        github-token: ${{github.token}}
        script: |
          const core = require('@actions/core')
          core.exportVariable("PATH", "/home/ubuntu/anaconda3/bin:/usr/local/bin:/usr/bin/:/bin:$PATH")
    - name: Install package
      run: |
        python -m pip install --upgrade pip
        pip install -e .
    - name: Run pytest
      run: |
        pip install nose
        nosetests -v tests/unit_test/
......@@ -4,6 +4,8 @@
[![Downloads](http://pepy.tech/badge/torch-encoding)](http://pepy.tech/project/torch-encoding)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Build Docs](https://github.com/zhanghang1989/PyTorch-Encoding/workflows/Build%20Docs/badge.svg)](https://github.com/zhanghang1989/PyTorch-Encoding/actions)
[![Unit Test](https://github.com/zhanghang1989/PyTorch-Encoding/workflows/Unit%20Test/badge.svg)](https://github.com/zhanghang1989/PyTorch-Encoding/actions)
# PyTorch-Encoding
created by [Hang Zhang](http://hangzh.com/)
......
......@@ -14,14 +14,14 @@ Get Pre-trained Model
---------------------
.. hint::
The model names contain the training information. For instance ``FCN_ResNet50_PContext``:
- ``FCN`` indicate the algorithm is Fully Convolutional Network for Semantic Segmentation
The model names contain the training information. For instance ``EncNet_ResNet50s_ADE``:
- ``EncNet`` indicates that the algorithm is Context Encoding for Semantic Segmentation
- ``ResNet50`` is the name of backbone network.
- ``PContext`` means the PASCAL in Context dataset.
- ``ADE`` means the ADE20K dataset.
How to get pretrained model, for example ``FCN_ResNet50_PContext``::
How to get a pretrained model, for example ``EncNet_ResNet50s_ADE``::
model = encoding.models.get_model('FCN_ResNet50_PContext', pretrained=True)
model = encoding.models.get_model('EncNet_ResNet50s_ADE', pretrained=True)
After clicking ``cmd`` in the table, the command for training the model can be found below the table.
......@@ -64,10 +64,9 @@ ADE20K Dataset
============================================================================== ================= ============== =============================================================================================
Model pixAcc mIoU Command
============================================================================== ================= ============== =============================================================================================
FCN_ResNet50_ADE 78.7% 38.5% :raw-html:`<a href="javascript:toggleblock('cmd_fcn50_ade')" class="toggleblock">cmd</a>`
EncNet_ResNet50_ADE 80.1% 41.5% :raw-html:`<a href="javascript:toggleblock('cmd_enc50_ade')" class="toggleblock">cmd</a>`
EncNet_ResNet101_ADE 81.3% 44.4% :raw-html:`<a href="javascript:toggleblock('cmd_enc101_ade')" class="toggleblock">cmd</a>`
EncNet_ResNet101_VOC N/A 85.9% :raw-html:`<a href="javascript:toggleblock('cmd_enc101_voc')" class="toggleblock">cmd</a>`
FCN_ResNet50s_ADE 78.7% 38.5% :raw-html:`<a href="javascript:toggleblock('cmd_fcn50_ade')" class="toggleblock">cmd</a>`
EncNet_ResNet50s_ADE 80.1% 41.5% :raw-html:`<a href="javascript:toggleblock('cmd_enc50_ade')" class="toggleblock">cmd</a>`
EncNet_ResNet101s_ADE 81.3% 44.4% :raw-html:`<a href="javascript:toggleblock('cmd_enc101_ade')" class="toggleblock">cmd</a>`
============================================================================== ================= ============== =============================================================================================
......@@ -89,16 +88,6 @@ EncNet_ResNet101_VOC
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset ADE20K --model EncNet --aux --se-loss --backbone resnet101 --base-size 640 --crop-size 576
</code>
<code xml:space="preserve" id="cmd_enc101_voc" style="display: none; text-align: left; white-space: pre-wrap">
# First finetuning COCO dataset pretrained model on augmented set
# You can also train from scratch on COCO by yourself
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset Pascal_aug --model-zoo EncNet_Resnet101_COCO --aux --se-loss --lr 0.001 --syncbn --ngpus 4 --checkname res101 --ft
# Finetuning on original set
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset Pascal_voc --model encnet --aux --se-loss --backbone resnet101 --lr 0.0001 --syncbn --ngpus 4 --checkname res101 --resume runs/Pascal_aug/encnet/res101/checkpoint.params --ft
</code>
Pascal Context Dataset
~~~~~~~~~~~~~~~~~~~~~~
......@@ -124,18 +113,38 @@ EncNet_ResNet101_PContext
</code>
Pascal VOC Dataset
~~~~~~~~~~~~~~~~~~
============================================================================== ================= ============== =============================================================================================
Model pixAcc mIoU Command
============================================================================== ================= ============== =============================================================================================
EncNet_ResNet101s_VOC N/A 85.9% :raw-html:`<a href="javascript:toggleblock('cmd_enc101_voc')" class="toggleblock">cmd</a>`
============================================================================== ================= ============== =============================================================================================
.. raw:: html
<code xml:space="preserve" id="cmd_enc101_voc" style="display: none; text-align: left; white-space: pre-wrap">
# First finetuning COCO dataset pretrained model on augmented set
# You can also train from scratch on COCO by yourself
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset Pascal_aug --model-zoo EncNet_Resnet101_COCO --aux --se-loss --lr 0.001 --syncbn --ngpus 4 --checkname res101 --ft
# Finetuning on original set
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset Pascal_voc --model encnet --aux --se-loss --backbone resnet101 --lr 0.0001 --syncbn --ngpus 4 --checkname res101 --resume runs/Pascal_aug/encnet/res101/checkpoint.params --ft
</code>
Test Pretrained
~~~~~~~~~~~~~~~
- Prepare the datasets by running the scripts in the ``scripts/`` folder, for example preparing ``PASCAL Context`` dataset::
python scripts/prepare_pcontext.py
python scripts/prepare_ade20k.py
- The test script is in the ``experiments/segmentation/`` folder. For evaluating the model (using multi-scale testing),
for example ``Encnet_ResNet50_PContext``::
for example ``EncNet_ResNet50s_ADE``::
python test.py --dataset PContext --model-zoo Encnet_ResNet50_PContext --eval
# pixAcc: 0.792, mIoU: 0.510: 100%|████████████████████████| 1276/1276 [46:31<00:00, 2.19s/it]
python test.py --dataset ADE20K --model-zoo EncNet_ResNet50s_ADE --eval
# pixAcc: 0.801, mIoU: 0.415: 100%|████████████████████████| 250/250
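The same evaluation can also be reproduced programmatically. A minimal sketch, assuming only the documented ``get_model`` API; the dummy input size, the use of the first returned score map, and the ``argmax`` decoding are illustrative rather than documented behaviour, and a GPU may be needed for the package's custom ops::

    import torch
    import encoding

    model = encoding.models.get_model('EncNet_ResNet50s_ADE', pretrained=True)
    model.eval()

    x = torch.rand(1, 3, 480, 480)              # dummy RGB batch (B, C, H, W)
    with torch.no_grad():
        outputs = model(x)                      # segmentation heads return score maps
    mask = torch.argmax(outputs[0], dim=1)      # assumed: first output is the main head
    print(mask.shape)                           # expected: torch.Size([1, 480, 480])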
Quick Demo
~~~~~~~~~~
......
......@@ -34,7 +34,7 @@ Assuming CUDA and cudnn are already successfully installed, otherwise please refe
* Install PyTorch::
conda install pytorch torchvision cudatoolkit=100 -c pytorch
conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
* Install this package::
......
......@@ -16,18 +16,20 @@ from .. import lib
__all__ = ['moments', 'syncbatchnorm', 'inp_syncbatchnorm']
class moments(Function):
class moments_(Function):
@staticmethod
def forward(ctx, x):
if x.is_cuda:
ex, ex2 = lib.gpu.expectation_forward(x)
else:
raise NotImplemented
ctx.save_for_backward(x)
return ex, ex2
@staticmethod
def backward(ctx, dex, dex2):
if x.is_cuda:
x, = ctx.saved_tensors
if dex.is_cuda:
dx = lib.gpu.expectation_backward(x, dex, dex2)
else:
raise NotImplemented
......@@ -295,5 +297,6 @@ class inp_syncbatchnorm_(Function):
ctx.master_queue = extra["master_queue"]
ctx.worker_queue = extra["worker_queue"]
moments = moments_.apply
syncbatchnorm = syncbatchnorm_.apply
inp_syncbatchnorm = inp_syncbatchnorm_.apply
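For reference, here is a small hypothetical check of what the renamed ``moments`` op is expected to produce. The per-channel E[x]/E[x^2] interpretation is inferred from the ``expectation_forward`` kernel name and is stated as an assumption; the real op only accepts CUDA tensors.

# Pure-PyTorch reference for the assumed semantics of moments(x).
import torch

x = torch.rand(8, 4, 16, 16)                  # (B, C, H, W)
flat = x.transpose(0, 1).reshape(4, -1)       # one row per channel
ex_ref = flat.mean(dim=1)                     # E[x]
ex2_ref = flat.pow(2).mean(dim=1)             # E[x^2]
var_ref = ex2_ref - ex_ref ** 2               # biased variance, as BatchNorm uses
# On a CUDA build, moments(x.cuda()) should return values matching ex_ref / ex2_ref.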
from .model_zoo import get_model
from .model_store import get_model_file
from .model_zoo import model_list
from .model_store import get_model_file, pretrained_model_list
from .sseg import get_segmentation_model, MultiEvalModule
......@@ -43,7 +43,7 @@ def resnest200(pretrained=False, root='~/.encoding/models', **kwargs):
avd=True, avd_first=False, **kwargs)
if pretrained:
model.load_state_dict(torch.load(
get_model_file('resnest152', root=root)), strict=False)
get_model_file('resnest200', root=root)), strict=False)
return model
def resnest269(pretrained=False, root='~/.encoding/models', **kwargs):
......
......@@ -18,7 +18,7 @@ def resnet50s(pretrained=False, root='~/.encoding/models', **kwargs):
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(torch.load(
get_model_file('resnet50', root=root)), strict=False)
get_model_file('resnet50s', root=root)), strict=False)
return model
def resnet101s(pretrained=False, root='~/.encoding/models', **kwargs):
......@@ -31,7 +31,7 @@ def resnet101s(pretrained=False, root='~/.encoding/models', **kwargs):
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
model.load_state_dict(torch.load(
get_model_file('resnet101', root=root)), strict=False)
get_model_file('resnet101s', root=root)), strict=False)
return model
def resnet152s(pretrained=False, root='~/.encoding/models', **kwargs):
......@@ -44,7 +44,7 @@ def resnet152s(pretrained=False, root='~/.encoding/models', **kwargs):
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained:
model.load_state_dict(torch.load(
get_model_file('resnet152', root=root)), strict=False)
get_model_file('resnet152s', root=root)), strict=False)
return model
# ResNet-D
......
......@@ -12,7 +12,7 @@ import torch
import torch.nn as nn
from ..nn import Encoding, View, Normalize
from .backbone import resnet
from .backbone import resnet50s, resnet101s, resnet152s
__all__ = ['DeepTen', 'get_deepten', 'get_deepten_resnet50_minc']
......@@ -22,11 +22,11 @@ class DeepTen(nn.Module):
self.backbone = backbone
# copying modules from pretrained models
if self.backbone == 'resnet50':
self.pretrained = resnet.resnet50(pretrained=True, dilated=False)
self.pretrained = resnet50s(pretrained=True, dilated=False)
elif self.backbone == 'resnet101':
self.pretrained = resnet.resnet101(pretrained=True, dilated=False)
self.pretrained = resnet101s(pretrained=True, dilated=False)
elif self.backbone == 'resnet152':
self.pretrained = resnet.resnet152(pretrained=True, dilated=False)
self.pretrained = resnet152s(pretrained=True, dilated=False)
else:
raise RuntimeError('unknown backbone: {}'.format(self.backbone))
n_codes = 32
......
......@@ -8,17 +8,11 @@ import portalocker
from ..utils import download, check_sha1
_model_sha1 = {name: checksum for checksum, name in [
# resnet
('25c4b50959ef024fcc050213a06b614899f94b3d', 'resnet50'),
('2a57e44de9c853fa015b172309a1ee7e2d0e4e2a', 'resnet101'),
('0d43d698c66aceaa2bc0309f55efdd7ff4b143af', 'resnet152'),
# rectified
('9b5dc32b3b36ca1a6b41ecd4906830fc84dae8ed', 'resnet101_rt'),
# resnest
('fb9de5b360976e3e8bd3679d3e93c5409a5eff3c', 'resnest50'),
('966fb78c22323b0c68097c5c1242bd16d3e07fd5', 'resnest101'),
('d7fd712f5a1fcee5b3ce176026fbb6d0d278454a', 'resnest200'),
('b743074c6fc40f88d7f53e8affb350de38f4f49d', 'resnest269'),
('51ae5f19032e22af4ec08e695496547acdba5ce5', 'resnest269'),
# resnet other variants
('a75c83cfc89a56a4e8ba71b14f1ec67e923787b3', 'resnet50s'),
('03a0f310d6447880f1b22a83bd7d1aa7fc702c6e', 'resnet101s'),
......@@ -29,15 +23,11 @@ _model_sha1 = {name: checksum for checksum, name in [
# deepten paper
('1225f149519c7a0113c43a056153c1bb15468ac0', 'deepten_resnet50_minc'),
# segmentation models
('662e979de25a389f11c65e9f1df7e06c2c356381', 'fcn_resnet50_ade'),
('eeed8e582f0fdccdba8579e7490570adc6d85c7c', 'fcn_resnet50_pcontext'),
('54f70c772505064e30efd1ddd3a14e1759faa363', 'psp_resnet50_ade'),
('075195c5237b778c718fd73ceddfa1376c18dfd0', 'deeplab_resnet50_ade'),
('5ee47ee28b480cc781a195d13b5806d5bbc616bf', 'encnet_resnet101_coco'),
('4de91d5922d4d3264f678b663f874da72e82db00', 'encnet_resnet50_pcontext'),
('9f27ea13d514d7010e59988341bcbd4140fcc33d', 'encnet_resnet101_pcontext'),
('07ac287cd77e53ea583f37454e17d30ce1509a4a', 'encnet_resnet50_ade'),
('3f54fa3b67bac7619cd9b3673f5c8227cf8f4718', 'encnet_resnet101_ade'),
('662e979de25a389f11c65e9f1df7e06c2c356381', 'fcn_resnet50s_ade'),
('4de91d5922d4d3264f678b663f874da72e82db00', 'encnet_resnet50s_pcontext'),
('9f27ea13d514d7010e59988341bcbd4140fcc33d', 'encnet_resnet101s_pcontext'),
('07ac287cd77e53ea583f37454e17d30ce1509a4a', 'encnet_resnet50s_ade'),
('3f54fa3b67bac7619cd9b3673f5c8227cf8f4718', 'encnet_resnet101s_ade'),
# resnest segmentation models
('2225f09d0f40b9a168d9091652194bc35ec2a5a9', 'deeplab_resnest50_ade'),
('06ca799c8cc148fe0fafb5b6d052052935aa3cc8', 'deeplab_resnest101_ade'),
......
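The renamed keys above must stay in sync with the names requested through ``get_model_file``. As a rough illustration of how the table is consumed (a sketch only: ``verify_pretrained`` is hypothetical and the ``check_sha1`` signature is assumed from the import at the top of this file):

import os

def verify_pretrained(name, file_path):
    """Sketch: check a downloaded checkpoint against the recorded SHA-1."""
    if name not in _model_sha1:
        raise ValueError('Pretrained model for {} is not available.'.format(name))
    return os.path.exists(file_path) and check_sha1(file_path, _model_sha1[name])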
......@@ -4,7 +4,45 @@ from .backbone import *
from .sseg import *
from .deepten import *
__all__ = ['get_model']
__all__ = ['model_list', 'get_model']
models = {
    # resnet
    'resnet50': resnet50,
    'resnet101': resnet101,
    'resnet152': resnet152,
    # resnest
    'resnest50': resnest50,
    'resnest101': resnest101,
    'resnest200': resnest200,
    'resnest269': resnest269,
    # resnet other variants
    'resnet50s': resnet50s,
    'resnet101s': resnet101s,
    'resnet152s': resnet152s,
    'resnet50d': resnet50d,
    'resnext50_32x4d': resnext50_32x4d,
    'resnext101_32x8d': resnext101_32x8d,
    # other segmentation backbones
    'xception65': xception65,
    'wideresnet38': wideresnet38,
    'wideresnet50': wideresnet50,
    # deepten paper
    'deepten_resnet50_minc': get_deepten_resnet50_minc,
    # segmentation models
    'encnet_resnet101s_coco': get_encnet_resnet101_coco,
    'fcn_resnet50s_pcontext': get_fcn_resnet50_pcontext,
    'encnet_resnet50s_pcontext': get_encnet_resnet50_pcontext,
    'encnet_resnet101s_pcontext': get_encnet_resnet101_pcontext,
    'encnet_resnet50s_ade': get_encnet_resnet50_ade,
    'encnet_resnet101s_ade': get_encnet_resnet101_ade,
    'fcn_resnet50s_ade': get_fcn_resnet50_ade,
    'psp_resnet50s_ade': get_psp_resnet50_ade,
    'deeplab_resnest50_ade': get_deeplab_resnest50_ade,
    'deeplab_resnest101_ade': get_deeplab_resnest101_ade,
}

model_list = list(models.keys())
def get_model(name, **kwargs):
"""Returns a pre-defined model by name
......@@ -23,40 +61,7 @@ def get_model(name, **kwargs):
Module:
The model.
"""
models = {
# resnet
'resnet50': resnet50,
'resnet101': resnet101,
'resnet152': resnet152,
# resnest
'resnest50': resnest50,
'resnest101': resnest101,
'resnest200': resnest200,
'resnest269': resnest269,
# resnet other variants
'resnet50s': resnet50s,
'resnet101s': resnet101s,
'resnet152s': resnet152s,
'resnet50d': resnet50d,
'resnext50_32x4d': resnext50_32x4d,
'resnext101_32x8d': resnext101_32x8d,
# other segmentation backbones
'xception65': xception65,
'wideresnet38': wideresnet38,
'wideresnet50': wideresnet50,
# deepten paper
'deepten_resnet50_minc': get_deepten_resnet50_minc,
# segmentation models
'fcn_resnet50_pcontext': get_fcn_resnet50_pcontext,
'encnet_resnet50_pcontext': get_encnet_resnet50_pcontext,
'encnet_resnet101_pcontext': get_encnet_resnet101_pcontext,
'encnet_resnet50_ade': get_encnet_resnet50_ade,
'encnet_resnet101_ade': get_encnet_resnet101_ade,
'fcn_resnet50_ade': get_fcn_resnet50_ade,
'psp_resnet50_ade': get_psp_resnet50_ade,
'deeplab_resnest50_ade': get_deeplab_resnest50_ade,
'deeplab_resnest101_ade': get_deeplab_resnest101_ade,
}
name = name.lower()
if name not in models:
raise ValueError('%s\n\t%s' % (str(name), '\n\t'.join(sorted(models.keys()))))
......
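Hoisting the ``models`` dict to module level makes the registry inspectable before any network is built, and ``get_model`` lower-cases the requested name. A brief usage sketch (illustrative only; the calls mirror the documented API rather than any repository test):

import encoding

print(len(encoding.models.model_list))                       # every registered name
print('encnet_resnet50s_ade' in encoding.models.model_list)  # True

# Lookup is case-insensitive, so the doc-style spelling resolves as well:
net = encoding.models.get_model('EncNet_ResNet50s_ADE', pretrained=True)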
......@@ -178,7 +178,7 @@ def get_deeplab_resnest50_ade(pretrained=False, root='~/.encoding/models', **kwa
>>> model = get_deeplab_resnest50_ade(pretrained=True)
>>> print(model)
"""
return get_deeplab('ade20k', 'resnest50', pretrained, root=root, **kwargs)
return get_deeplab('ade20k', 'resnest50', pretrained, aux=True, root=root, **kwargs)
def get_deeplab_resnest101_ade(pretrained=False, root='~/.encoding/models', **kwargs):
r"""DeepLabV3 model from the paper `"Context Encoding for Semantic Segmentation"
......@@ -197,4 +197,4 @@ def get_deeplab_resnest101_ade(pretrained=False, root='~/.encoding/models', **kw
>>> model = get_deeplab_resnest101_ade(pretrained=True)
>>> print(model)
"""
return get_deeplab('ade20k', 'resnest101', pretrained, root=root, **kwargs)
return get_deeplab('ade20k', 'resnest101', pretrained, aux=True, root=root, **kwargs)
......@@ -15,7 +15,7 @@ from ...nn import SyncBatchNorm, Encoding, Mean
__all__ = ['EncNet', 'EncModule', 'get_encnet', 'get_encnet_resnet50_pcontext',
'get_encnet_resnet101_pcontext', 'get_encnet_resnet50_ade',
'get_encnet_resnet101_ade']
'get_encnet_resnet101_ade', 'get_encnet_resnet101_coco']
class EncNet(BaseNet):
def __init__(self, nclass, backbone, aux=True, se_loss=True, lateral=False,
......@@ -139,13 +139,13 @@ def get_encnet(dataset='pascal_voc', backbone='resnet50s', pretrained=False,
from ...datasets import datasets, acronyms
model = EncNet(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
if pretrained:
from .model_store import get_model_file
from ..model_store import get_model_file
model.load_state_dict(torch.load(
get_model_file('encnet_%s_%s'%(backbone, acronyms[dataset]), root=root)))
return model
def get_encnet_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
r"""EncNet model from the paper `"Context Encoding for Semantic Segmentation"
<https://arxiv.org/pdf/1803.08904.pdf>`_
Parameters
......@@ -164,8 +164,28 @@ def get_encnet_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **
return get_encnet('pcontext', 'resnet50s', pretrained, root=root, aux=True,
base_size=520, crop_size=480, **kwargs)
def get_encnet_resnet101_coco(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet model from the paper `"Context Encoding for Semantic Segmentation"
<https://arxiv.org/pdf/1803.08904.pdf>`_
Parameters
----------
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.encoding/models'
Location for keeping the model parameters.
Examples
--------
>>> model = get_encnet_resnet101_coco(pretrained=True)
>>> print(model)
"""
return get_encnet('coco', 'resnet101s', pretrained, root=root, aux=True,
base_size=520, crop_size=480, lateral=True, **kwargs)
def get_encnet_resnet101_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
r"""EncNet model from the paper `"Context Encoding for Semantic Segmentation"
<https://arxiv.org/pdf/1803.08904.pdf>`_
Parameters
......@@ -185,7 +205,7 @@ def get_encnet_resnet101_pcontext(pretrained=False, root='~/.encoding/models', *
base_size=520, crop_size=480, **kwargs)
def get_encnet_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
r"""EncNet model from the paper `"Context Encoding for Semantic Segmentation"
<https://arxiv.org/pdf/1803.08904.pdf>`_
Parameters
......@@ -201,11 +221,11 @@ def get_encnet_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwarg
>>> model = get_encnet_resnet50_ade(pretrained=True)
>>> print(model)
"""
return get_encnet('ade20k', 'resnet50', pretrained, root=root, aux=True,
return get_encnet('ade20k', 'resnet50s', pretrained, root=root, aux=True,
base_size=520, crop_size=480, **kwargs)
def get_encnet_resnet101_ade(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
r"""EncNet model from the paper `"Context Encoding for Semantic Segmentation"
<https://arxiv.org/pdf/1803.08904.pdf>`_
Parameters
......@@ -225,7 +245,7 @@ def get_encnet_resnet101_ade(pretrained=False, root='~/.encoding/models', **kwar
base_size=640, crop_size=576, **kwargs)
def get_encnet_resnet152_ade(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
r"""EncNet model from the paper `"Context Encoding for Semantic Segmentation"
<https://arxiv.org/pdf/1803.08904.pdf>`_
Parameters
......
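The new ``get_encnet_resnet101_coco`` entry is the network the documented PASCAL VOC recipe finetunes from (``--model-zoo EncNet_Resnet101_COCO``). A hedged sketch of building it through the registry; whether hosted weights exist for this key is not shown in this diff, so the example constructs the architecture without loading a checkpoint:

import encoding

# aux and lateral branches are enabled by the helper's defaults shown above
net = encoding.models.get_model('encnet_resnet101s_coco', pretrained=False)
print(type(net).__name__)        # EncNet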
......@@ -140,7 +140,7 @@ def get_fcfpn(dataset='pascal_voc', backbone='resnet50', pretrained=False,
from ...datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
model = FCFPN(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs)
if pretrained:
from .model_store import get_model_file
from ..model_store import get_model_file
model.load_state_dict(torch.load(
get_model_file('fcfpn_%s_%s'%(backbone, acronyms[dataset]), root=root)))
return model
......
......@@ -131,7 +131,7 @@ def get_fcn(dataset='pascal_voc', backbone='resnet50s', pretrained=False,
from ...datasets import datasets, acronyms
model = FCN(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
if pretrained:
from .model_store import get_model_file
from ..model_store import get_model_file
model.load_state_dict(torch.load(
get_model_file('fcn_%s_%s'%(backbone, acronyms[dataset]), root=root)))
return model
......
......@@ -56,7 +56,7 @@ def get_psp(dataset='pascal_voc', backbone='resnet50s', pretrained=False,
from ...datasets import datasets, acronyms
model = PSP(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
if pretrained:
from .model_store import get_model_file
from ..model_store import get_model_file
model.load_state_dict(torch.load(
get_model_file('psp_%s_%s'%(backbone, acronyms[dataset]), root=root)))
return model
......
......@@ -91,7 +91,7 @@ def get_upernet(dataset='pascal_voc', backbone='resnet50s', pretrained=False,
from ...datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
model = UperNet(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs)
if pretrained:
from .model_store import get_model_file
from ..model_store import get_model_file
model.load_state_dict(torch.load(
get_model_file('upernet_%s_%s'%(backbone, acronyms[dataset]), root=root)))
return model
......
......@@ -168,7 +168,12 @@ class SyncBatchNorm(_BatchNorm):
        # running_exs
        #self.register_buffer('running_exs', torch.ones(num_features))

    def _check_input_dim(self, x):
        pass

    def forward(self, x):
        if not self.training:
            return super().forward(x)
        # Resize the input to (B, C, -1).
        input_shape = x.size()
        x = x.view(input_shape[0], self.num_features, -1)
......
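The early return lets inference run without a distributed process group: in ``eval()`` the layer now follows the stock running-statistics path of its ``_BatchNorm`` parent, while ``train()`` keeps the synchronized (B, C, -1) computation. A minimal sketch (assumes the package's CUDA extension is built; the standalone-layer usage is illustrative):

import torch
import encoding

bn = encoding.nn.SyncBatchNorm(64)
bn.eval()                               # running-statistics path, no cross-GPU sync
with torch.no_grad():
    y = bn(torch.rand(2, 64, 8, 8))
print(y.shape)                          # torch.Size([2, 64, 8, 8])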
......@@ -11,7 +11,6 @@ import time
import argparse
import numpy as np
from tqdm import tqdm
from mpi4py import MPI
import torch
import torch.nn as nn
......@@ -114,6 +113,21 @@ best_pred = 0.0
acclist_train = []
acclist_val = []
def torch_dist_sum(gpu, *args):
    process_group = torch.distributed.group.WORLD
    tensor_args = []
    pending_res = []
    for arg in args:
        if isinstance(arg, torch.Tensor):
            tensor_arg = arg.clone().reshape(1).detach().cuda(gpu)
        else:
            tensor_arg = torch.tensor(arg).reshape(1).cuda(gpu)
        tensor_args.append(tensor_arg)
        pending_res.append(torch.distributed.all_reduce(tensor_arg, group=process_group, async_op=True))
    for res in pending_res:
        res.wait()
    return tensor_args
def main_worker(gpu, ngpus_per_node, args):
args.gpu = gpu
args.rank = args.rank * ngpus_per_node + gpu
......@@ -280,20 +294,11 @@ def main_worker(gpu, ngpus_per_node, args):
top1.update(acc1[0], data.size(0))
top5.update(acc5[0], data.size(0))
comm = MPI.COMM_WORLD
# send to master
sum1 = comm.gather(top1.sum, root=0)
cnt1 = comm.gather(top1.count, root=0)
sum5 = comm.gather(top5.sum, root=0)
cnt5 = comm.gather(top5.count, root=0)
# get back from master
sum1 = comm.bcast(sum1, root=0)
cnt1 = comm.bcast(cnt1, root=0)
sum5 = comm.bcast(sum5, root=0)
cnt5 = comm.bcast(cnt5, root=0)
# sum all
sum1, cnt1, sum5, cnt5 = torch_dist_sum(args.gpu, top1.sum, top1.count, top5.sum, top5.count)
if args.gpu == 0:
top1_acc = sum(sum1) / sum(cnt1)
top5_acc = sum(sum5) / len(cnt5)
top5_acc = sum(sum5) / sum(cnt5)
print('Validation: Top1: %.3f | Top5: %.3f'%(top1_acc, top5_acc))
# save checkpoint
......
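Replacing the MPI gather/bcast round trip with ``torch.distributed`` keeps the metric aggregation inside the process group the script already creates. For clarity, the same idea in synchronous form (a sketch: ``dist_sum_scalars`` is hypothetical and assumes an initialized default process group and a valid device):

import torch
import torch.distributed as dist

def dist_sum_scalars(device, *values):
    """Sum per-rank Python scalars across all ranks and return them as floats."""
    t = torch.tensor(values, dtype=torch.float64, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)      # in-place sum over ranks
    return t.tolist()

# e.g. global top-1 accuracy from per-rank partial sums:
# total_correct, total_count = dist_sum_scalars(args.gpu, top1.sum, top1.count)
# top1_acc = total_correct / total_count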
......@@ -7,7 +7,6 @@
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
from __future__ import print_function
import os
import argparse
from tqdm import tqdm
......