Unverified commit 32e382bc, authored by Hang Zhang, committed by GitHub

v0.4.3 (#71)

- ADE20K training model
- Amazon legal approval

fixes https://github.com/zhanghang1989/PyTorch-Encoding/issues/69
parent 9bc70531
MIT License
-Copyright (c) 2017 Hang Zhang
+Copyright (c) 2017 Hang Zhang. All rights reserved.
+Copyright (c) 2018 Amazon.com, Inc. or its affiliates. All rights reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
@@ -9,8 +10,16 @@ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
-The above copyright notice and this permission notice shall be included in
-all copies or substantial portions of the Software.
+1. Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+3. Neither the name of Amazon Inc nor the names of the contributors may be
+   used to endorse or promote products derived from this software without
+   specific prior written permission.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
......
@@ -27,22 +27,27 @@ Test Pre-trained Model
for example ``Encnet_ResNet50_PContext``::

    python test.py --dataset PContext --model-zoo Encnet_ResNet50_PContext --eval
-    # pixAcc: 0.7862, mIoU: 0.4946: 100%|████████████████████████| 319/319 [09:44<00:00, 1.83s/it]
+    # pixAcc: 0.7838, mIoU: 0.4958: 100%|████████████████████████| 1276/1276 [46:31<00:00, 2.19s/it]

The command for training the model can be found by clicking ``cmd`` in the table.

.. role:: raw-html(raw)
   :format: html

-+----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+
-| Model                            | pixAcc    | mIoU      | Command                                                                                      |
-+==================================+===========+===========+==============================================================================================+
-| FCN_ResNet50_PContext            | 76.0%     | 45.7      | :raw-html:`<a href="javascript:toggleblock('cmd_fcn50_pcont')" class="toggleblock">cmd</a>`  |
-+----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+
-| Encnet_ResNet50_PContext         | 78.6%     | 49.5      | :raw-html:`<a href="javascript:toggleblock('cmd_enc50_pcont')" class="toggleblock">cmd</a>`  |
-+----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+
-| Encnet_ResNet101_PContext        | 80.0%     | 52.1      | :raw-html:`<a href="javascript:toggleblock('cmd_enc101_pcont')" class="toggleblock">cmd</a>` |
-+----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+
++----------------------------------+-----------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
+| Model                            | pixAcc    | mIoU      | Note      | Command                                                                                      | Logs       |
++==================================+===========+===========+===========+==============================================================================================+============+
+| Encnet_ResNet50_PContext         | 78.4%     | 49.6%     |           | :raw-html:`<a href="javascript:toggleblock('cmd_enc50_pcont')" class="toggleblock">cmd</a>`  | ENC50PC_   |
++----------------------------------+-----------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
+| EncNet_ResNet101_PContext        | 79.9%     | 51.8%     |           | :raw-html:`<a href="javascript:toggleblock('cmd_enc101_pcont')" class="toggleblock">cmd</a>` | ENC101PC_  |
++----------------------------------+-----------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
+| EncNet_ResNet50_ADE              | 79.8%     | 41.3%     |           | :raw-html:`<a href="javascript:toggleblock('cmd_enc50_ade')" class="toggleblock">cmd</a>`    | ENC50ADE_  |
++----------------------------------+-----------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
+
+.. _ENC50PC: https://github.com/zhanghang1989/image-data/blob/master/encoding/segmentation/logs/encnet_resnet50_pcontext.log?raw=true
+.. _ENC101PC: https://github.com/zhanghang1989/image-data/blob/master/encoding/segmentation/logs/encnet_resnet101_pcontext.log?raw=true
+.. _ENC50ADE: https://github.com/zhanghang1989/image-data/blob/master/encoding/segmentation/logs/encnet_resnet50_ade.log?raw=true
.. raw:: html
@@ -58,6 +63,14 @@ Test Pre-trained Model
    CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset PContext --model EncNet --aux --se-loss --backbone resnet101
    </code>
+    <code xml:space="preserve" id="cmd_psp50_ade" style="display: none; text-align: left; white-space: pre-wrap">
+    CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset ADE20K --model PSP --aux
+    </code>
+
+    <code xml:space="preserve" id="cmd_enc50_ade" style="display: none; text-align: left; white-space: pre-wrap">
+    CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset ADE20K --model EncNet --aux --se-loss
+    </code>
Quick Demo
~~~~~~~~~~
@@ -105,6 +118,12 @@ Train Your Own Model
- For detailed training options, please run ``python train.py -h``.
+- The validation metrics during training use only center crops and are just for
+  monitoring training correctness. To evaluate a pretrained model on the validation
+  set with multi-scale testing, please use the command::
+
+      CUDA_VISIBLE_DEVICES=0,1,2,3 python test.py --dataset pcontext --model encnet --aux --se-loss --resume mycheckpoint --eval
Citation
--------
......
@@ -75,9 +75,12 @@ def _get_ade20k_pairs(folder, split='train'):
    if split == 'train':
        img_folder = os.path.join(folder, 'images/training')
        mask_folder = os.path.join(folder, 'annotations/training')
-    else:
+    elif split == 'val':
        img_folder = os.path.join(folder, 'images/validation')
        mask_folder = os.path.join(folder, 'annotations/validation')
+    else:
+        img_folder = os.path.join(folder, 'images/trainval')
+        mask_folder = os.path.join(folder, 'annotations/trainval')
    for filename in os.listdir(img_folder):
        basename, _ = os.path.splitext(filename)
        if filename.endswith(".jpg"):
......
import os
from tqdm import trange
from PIL import Image, ImageOps, ImageFilter
import numpy as np
import torch
from .base import BaseDataset
"""
NUM_CHANNEL = 91
[] background
[5] airplane
[2] bicycle
[16] bird
[9] boat
[44] bottle
[6] bus
[3] car
[17] cat
[62] chair
[21] cow
[67] dining table
[18] dog
[19] horse
[4] motorcycle
[1] person
[64] potted plant
[20] sheep
[63] couch
[7] train
[72] tv
"""
CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
1, 64, 20, 63, 7, 72]
class COCOSegmentation(BaseDataset):
    def __init__(self, root=os.path.expanduser('~/.encoding/data'), split='train',
                 mode=None, transform=None, target_transform=None):
        super(COCOSegmentation, self).__init__(
            root, split, mode, transform, target_transform)
        from pycocotools.coco import COCO
        from pycocotools import mask
        if mode == 'train':
            print('train set')
            ann_file = os.path.join(root, 'coco/annotations/instances_train2014.json')
            ids_file = os.path.join(root, 'coco/annotations/train_ids.pth')
            root = os.path.join(root, 'coco/train2014')
        else:
            print('val set')
            ann_file = os.path.join(root, 'coco/annotations/instances_val2014.json')
            ids_file = os.path.join(root, 'coco/annotations/val_ids.pth')
            root = os.path.join(root, 'coco/val2014')
        self.coco = COCO(ann_file)
        self.coco_mask = mask
        if os.path.exists(ids_file):
            self.ids = torch.load(ids_file)
        else:
            ids = list(self.coco.imgs.keys())
            self.ids = self._preprocess(ids, ids_file)
        self.transform = transform
        self.target_transform = target_transform
    def __getitem__(self, index):
        coco = self.coco
        img_id = self.ids[index]
        img_metadata = coco.loadImgs(img_id)[0]
        path = img_metadata['file_name']
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        mask = Image.fromarray(self._gen_seg_mask(cocotarget, img_metadata['height'],
                                                  img_metadata['width']))
        # synchronized transform
        if self.mode == 'train':
            img, mask = self._sync_transform(img, mask)
        elif self.mode == 'val':
            img, mask = self._val_sync_transform(img, mask)
        else:
            assert self.mode == 'testval'
            mask = self._mask_transform(mask)
        # general resize, normalize and toTensor
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return img, mask

    def __len__(self):
        return len(self.ids)
    def _gen_seg_mask(self, target, h, w):
        mask = np.zeros((h, w), dtype=np.uint8)
        coco_mask = self.coco_mask
        for instance in target:
            rle = coco_mask.frPyObjects(instance['segmentation'], h, w)
            m = coco_mask.decode(rle)
            cat = instance['category_id']
            if cat in CAT_LIST:
                c = CAT_LIST.index(cat)
            else:
                continue
            # write the class index only into pixels that are still background,
            # so earlier instances take precedence over later ones
            if len(m.shape) < 3:
                mask[:, :] += (mask == 0) * (m * c)
            else:
                mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8)
        return mask
    def _preprocess(self, ids, ids_file):
        print("Preprocessing masks, this will take a while. " +
              "But don't worry, it only runs once for each split.")
        tbar = trange(len(ids))
        new_ids = []
        for i in tbar:
            img_id = ids[i]
            cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
            img_metadata = self.coco.loadImgs(img_id)[0]
            mask = self._gen_seg_mask(cocotarget, img_metadata['height'],
                                      img_metadata['width'])
            # keep only images whose mask covers more than 1k pixels
            if (mask > 0).sum() > 1000:
                new_ids.append(img_id)
            tbar.set_description('Doing: {}/{}, got {} qualified images'.
                                 format(i, len(ids), len(new_ids)))
        print('Found number of qualified images: ', len(new_ids))
        torch.save(new_ids, ids_file)
        return new_ids
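
As a usage sketch (an editor's example, not part of the commit): the dataset can be constructed directly once the COCO 2014 images and annotation JSONs sit under ``~/.encoding/data/coco``; the normalization values below are the usual ImageNet statistics used by this repo's training scripts.

    import torchvision.transforms as transforms

    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225])])

    # first construction triggers _preprocess(), which filters out images whose
    # masks cover fewer than 1k pixels and caches the ids to train_ids.pth
    trainset = COCOSegmentation(split='train', mode='train',
                                transform=input_transform)
    img, mask = trainset[0]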
@@ -6,10 +6,10 @@
from PIL import Image, ImageOps, ImageFilter
import os
-import os.path
import math
import random
import numpy as np
+from tqdm import trange
import torch
from .base import BaseDataset
@@ -26,20 +26,13 @@ class ContextSegmentation(BaseDataset):
        root = os.path.join(root, self.BASE_DIR)
        annFile = os.path.join(root, 'trainval_merged.json')
        imgDir = os.path.join(root, 'JPEGImages')
+        mask_file = os.path.join(root, self.split+'.pth')
        # training mode
-        if split == 'train':
-            phase = 'train'
-        elif split == 'val':
-            phase = 'val'
-        elif split == 'test':
-            phase = 'val'
-            #phase = 'test'
-        print('annFile', annFile)
-        print('imgDir', imgDir)
-        self.detail = Detail(annFile, imgDir, phase)
+        self.detail = Detail(annFile, imgDir, split)
        self.transform = transform
        self.target_transform = target_transform
        self.ids = self.detail.getImgs()
+        # generate masks
        self._mapping = np.sort(np.array([
            0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22,
            23, 397, 25, 284, 158, 159, 416, 33, 162, 420, 454, 295, 296,
@@ -47,6 +40,10 @@ class ContextSegmentation(BaseDataset):
            68, 326, 72, 458, 34, 207, 80, 355, 85, 347, 220, 349, 360,
            98, 187, 104, 105, 366, 189, 368, 113, 115]))
        self._key = np.array(range(len(self._mapping))).astype('uint8')
+        if os.path.exists(mask_file):
+            self.masks = torch.load(mask_file)
+        else:
+            self.masks = self._preprocess(mask_file)
    def _class_to_index(self, mask):
        # assert the values
@@ -57,19 +54,33 @@ class ContextSegmentation(BaseDataset):
        index = np.digitize(mask.ravel(), self._mapping, right=True)
        return self._key[index].reshape(mask.shape)
+    def _preprocess(self, mask_file):
+        masks = {}
+        tbar = trange(len(self.ids))
+        print("Preprocessing masks, this will take a while. " +
+              "But don't worry, it only runs once for each split.")
+        for i in tbar:
+            img_id = self.ids[i]
+            mask = Image.fromarray(self._class_to_index(
+                self.detail.getMask(img_id)))
+            masks[img_id['image_id']] = mask
+            tbar.set_description("Preprocessing masks {}".format(img_id['image_id']))
+        torch.save(masks, mask_file)
+        return masks
    def __getitem__(self, index):
-        detail = self.detail
        img_id = self.ids[index]
        path = img_id['file_name']
        iid = img_id['image_id']
-        img = Image.open(os.path.join(detail.img_folder, path)).convert('RGB')
+        img = Image.open(os.path.join(self.detail.img_folder, path)).convert('RGB')
        if self.mode == 'test':
            if self.transform is not None:
                img = self.transform(img)
            return img, os.path.basename(path)
        # convert mask to 60 categories
-        mask = Image.fromarray(self._class_to_index(
-            detail.getMask(img_id)))
+        mask = self.masks[iid]
        # synchronized transform
        if self.mode == 'train':
            img, mask = self._sync_transform(img, mask)
......
"""Dilated ResNet""" """Dilated ResNet"""
import math import math
import torch
import torch.utils.model_zoo as model_zoo import torch.utils.model_zoo as model_zoo
#from .. import nn
import torch.nn as nn import torch.nn as nn
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
@@ -234,7 +234,7 @@ def resnet34(pretrained=False, **kwargs):
    return model

-def resnet50(pretrained=False, **kwargs):
+def resnet50(pretrained=False, root='~/.encoding/models', **kwargs):
    """Constructs a ResNet-50 model.

    Args:
@@ -242,11 +242,13 @@ def resnet50(pretrained=False, **kwargs):
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
+        from ..models.model_store import get_model_file
+        model.load_state_dict(torch.load(
+            get_model_file('resnet50', root=root)), strict=False)
    return model
-def resnet101(pretrained=False, **kwargs):
+def resnet101(pretrained=False, root='~/.encoding/models', **kwargs):
    """Constructs a ResNet-101 model.

    Args:
@@ -254,11 +256,13 @@ def resnet101(pretrained=False, **kwargs):
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
+        from ..models.model_store import get_model_file
+        model.load_state_dict(torch.load(
+            get_model_file('resnet101', root=root)), strict=False)
    return model
-def resnet152(pretrained=False, **kwargs):
+def resnet152(pretrained=False, root='~/.encoding/models', **kwargs):
    """Constructs a ResNet-152 model.

    Args:
......
@@ -10,9 +10,10 @@
"""Functions for Encoding Layer"""
import torch
from torch.autograd import Function, Variable
+import torch.nn.functional as F
from .. import lib

-__all__ = ['aggregate', 'scaledL2']
+__all__ = ['aggregate', 'scaledL2', 'pairwise_cosine']

class _aggregate(Function):
    @staticmethod
@@ -93,3 +94,18 @@ def scaledL2(X, C, S):
    - Output: :math:`E\in\mathcal{R}^{B\times N\times K}`
    """
    return _scaledL2.apply(X, C, S)
+
+# Experimental
+def pairwise_cosine(X, C, normalize=False):
+    r"""Pairwise Cosine Similarity or Dot-product Similarity
+
+    Shape:
+        - Input: :math:`X\in\mathcal{R}^{B\times N\times D}`,
+          :math:`C\in\mathcal{R}^{K\times D}`
+          (where :math:`B` is batch, :math:`N` is total number of features,
+          :math:`K` is number of codewords, :math:`D` is feature dimensions.)
+        - Output: :math:`E\in\mathcal{R}^{B\times N\times K}`
+    """
+    if normalize:
+        X = F.normalize(X, dim=2, eps=1e-8)
+        C = F.normalize(C, dim=1, eps=1e-8)
+    return torch.matmul(X, C.t())
from .model_zoo import get_model
from .base import *
from .fcn import *
+from .psp import *
from .encnet import *

def get_segmentation_model(name, **kwargs):
    from .fcn import get_fcn
    models = {
        'fcn': get_fcn,
+        'psp': get_psp,
        'encnet': get_encnet,
    }
    return models[name.lower()](**kwargs)
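
For reference, a minimal sketch of the factory in use (an editor's example; argument names follow ``get_psp``/``get_encnet`` as defined later in this diff):

    # builds an untrained PSP with a ResNet-50 backbone for ADE20K;
    # 'psp' is routed to get_psp through the table above
    model = get_segmentation_model('psp', dataset='ade20k',
                                   backbone='resnet50', aux=True)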
@@ -7,34 +7,34 @@
import torch
from torch.autograd import Variable
import torch.nn as nn
-from torch.nn.functional import upsample
+import torch.nn.functional as F

import encoding
from .base import BaseNet
from .fcn import FCNHead

__all__ = ['EncNet', 'EncModule', 'get_encnet', 'get_encnet_resnet50_pcontext',
-           'get_encnet_resnet101_pcontext']
+           'get_encnet_resnet101_pcontext', 'get_encnet_resnet50_ade']

class EncNet(BaseNet):
-    def __init__(self, nclass, backbone, aux=True, se_loss=True,
+    def __init__(self, nclass, backbone, aux=True, se_loss=True, lateral=False,
                 norm_layer=nn.BatchNorm2d, **kwargs):
        super(EncNet, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer)
        self.head = EncHead(self.nclass, in_channels=2048, se_loss=se_loss,
-                           norm_layer=norm_layer, up_kwargs=self._up_kwargs)
+                           lateral=lateral, norm_layer=norm_layer,
+                           up_kwargs=self._up_kwargs)
        if aux:
            self.auxlayer = FCNHead(1024, nclass, norm_layer=norm_layer)

    def forward(self, x):
        imsize = x.size()[2:]
-        #features = self.base_forward(x)
-        _, _, c3, c4 = self.base_forward(x)
-        x = list(self.head(c4))
-        x[0] = upsample(x[0], imsize, **self._up_kwargs)
+        features = self.base_forward(x)
+        x = list(self.head(*features))
+        x[0] = F.upsample(x[0], imsize, **self._up_kwargs)
        if self.aux:
-            auxout = self.auxlayer(c3)
-            auxout = upsample(auxout, imsize, **self._up_kwargs)
+            auxout = self.auxlayer(features[2])
+            auxout = F.upsample(auxout, imsize, **self._up_kwargs)
            x.append(auxout)
        return tuple(x)
@@ -42,16 +42,17 @@ class EncNet(BaseNet):

class EncModule(nn.Module):
    def __init__(self, in_channels, nclass, ncodes=32, se_loss=True, norm_layer=None):
        super(EncModule, self).__init__()
-        if isinstance(norm_layer, encoding.nn.BatchNorm2d):
-            norm_layer = encoding.nn.BatchNorm1d
-        else:
-            norm_layer = nn.BatchNorm1d
+        norm_layer = nn.BatchNorm1d if isinstance(norm_layer, nn.BatchNorm2d) else \
+            encoding.nn.BatchNorm1d
        self.se_loss = se_loss
        self.encoding = nn.Sequential(
+            nn.Conv2d(in_channels, in_channels, 1, bias=False),
+            nn.BatchNorm2d(in_channels),
+            nn.ReLU(inplace=True),
            encoding.nn.Encoding(D=in_channels, K=ncodes),
            norm_layer(ncodes),
            nn.ReLU(inplace=True),
-            encoding.nn.Sum(dim=1))
+            encoding.nn.Mean(dim=1))
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels),
            nn.Sigmoid())
@@ -63,31 +64,51 @@ class EncModule(nn.Module):
        b, c, _, _ = x.size()
        gamma = self.fc(en)
        y = gamma.view(b, c, 1, 1)
-        # residual ?
-        outputs = [x + x * y]
+        outputs = [F.relu_(x + x * y)]
        if self.se_loss:
            outputs.append(self.selayer(en))
        return tuple(outputs)
class EncHead(nn.Module):
-    def __init__(self, out_channels, in_channels, se_loss=True,
+    def __init__(self, out_channels, in_channels, se_loss=True, lateral=True,
                 norm_layer=None, up_kwargs=None):
        super(EncHead, self).__init__()
+        self.se_loss = se_loss
+        self.lateral = lateral
+        self.up_kwargs = up_kwargs
        self.conv5 = nn.Sequential(
            nn.Conv2d(in_channels, 512, 3, padding=1, bias=False),
            norm_layer(512),
-            nn.ReLU(True))
+            nn.ReLU(inplace=True))
+        if lateral:
+            self.connect = nn.ModuleList([
+                nn.Sequential(
+                    nn.Conv2d(512, 512, kernel_size=1, bias=False),
+                    norm_layer(512),
+                    nn.ReLU(inplace=True)),
+                nn.Sequential(
+                    nn.Conv2d(1024, 512, kernel_size=1, bias=False),
+                    norm_layer(512),
+                    nn.ReLU(inplace=True)),
+            ])
+            self.fusion = nn.Sequential(
+                nn.Conv2d(3*512, 512, kernel_size=3, padding=1, bias=False),
+                norm_layer(512),
+                nn.ReLU(inplace=True))
        self.encmodule = EncModule(512, out_channels, ncodes=32,
                                   se_loss=se_loss, norm_layer=norm_layer)
-        self.dropout = nn.Dropout2d(0.1, False)
-        self.conv6 = nn.Conv2d(512, out_channels, 1)
-        self.se_loss = se_loss
+        self.conv6 = nn.Sequential(nn.Dropout2d(0.1, False),
+                                   nn.Conv2d(512, out_channels, 1))

-    def forward(self, x):
-        x = self.conv5(x)
-        outs = list(self.encmodule(x))
-        outs[0] = self.conv6(self.dropout(outs[0]))
+    def forward(self, *inputs):
+        feat = self.conv5(inputs[-1])
+        if self.lateral:
+            c2 = self.connect[0](inputs[1])
+            c3 = self.connect[1](inputs[2])
+            feat = self.fusion(torch.cat([feat, c2, c3], 1))
+        outs = list(self.encmodule(feat))
+        outs[0] = self.conv6(outs[0])
        return tuple(outs)
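
A shape sketch for the lateral path above (an editor's reading aid; channel counts assume the dilated ResNet backbones used in this repo, where ``base_forward`` returns four feature maps c1-c4 with 256/512/1024/2048 channels and c2-c4 share the same stride-8 resolution):

    import torch

    c1 = torch.randn(1, 256, 120, 120)
    c2 = torch.randn(1, 512, 60, 60)    # connect[0]: 512 -> 512 channels
    c3 = torch.randn(1, 1024, 60, 60)   # connect[1]: 1024 -> 512 channels
    c4 = torch.randn(1, 2048, 60, 60)   # conv5: 2048 -> 512 channels
    # head(*features) unpacks (c1, c2, c3, c4), so inputs[-1] is c4; with
    # lateral=True, fusion sees torch.cat([feat, c2', c3'], 1) = 3*512 channels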
@@ -118,6 +139,7 @@ def get_encnet(dataset='pascal_voc', backbone='resnet50', pretrained=False,
        'ade20k': 'ade',
        'pcontext': 'pcontext',
    }
+    kwargs['lateral'] = True if dataset.lower() == 'pcontext' else False
    # infer number of classes
    from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
    model = EncNet(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs)
@@ -144,7 +166,7 @@ def get_encnet_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
    >>> model = get_encnet_resnet50_pcontext(pretrained=True)
    >>> print(model)
    """
-    return get_encnet('pcontext', 'resnet50', pretrained)
+    return get_encnet('pcontext', 'resnet50', pretrained, aux=False, **kwargs)
def get_encnet_resnet101_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
    r"""EncNet model from the paper `"Context Encoding for Semantic Segmentation"
@@ -163,4 +185,23 @@ def get_encnet_resnet101_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
    >>> model = get_encnet_resnet101_pcontext(pretrained=True)
    >>> print(model)
    """
-    return get_encnet('pcontext', 'resnet101', pretrained)
+    return get_encnet('pcontext', 'resnet101', pretrained, aux=False, **kwargs)
+def get_encnet_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
+    r"""EncNet model from the paper `"Context Encoding for Semantic Segmentation"
+    <https://arxiv.org/pdf/1803.08904.pdf>`_
+
+    Parameters
+    ----------
+    pretrained : bool, default False
+        Whether to load the pretrained weights for model.
+    root : str, default '~/.encoding/models'
+        Location for keeping the model parameters.
+
+    Examples
+    --------
+    >>> model = get_encnet_resnet50_ade(pretrained=True)
+    >>> print(model)
+    """
+    return get_encnet('ade20k', 'resnet50', pretrained, aux=True, **kwargs)
@@ -62,7 +62,7 @@ class FCNHead(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer):
        super(FCNHead, self).__init__()
        inter_channels = in_channels // 4
-        self.conv5 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1),
+        self.conv5 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
                                   norm_layer(inter_channels),
                                   nn.ReLU(),
                                   nn.Dropout2d(0.1, False),
@@ -122,7 +122,7 @@ def get_fcn_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
    >>> model = get_fcn_resnet50_pcontext(pretrained=True)
    >>> print(model)
    """
-    return get_fcn('pcontext', 'resnet50', pretrained, aux=False)
+    return get_fcn('pcontext', 'resnet50', pretrained, aux=False, **kwargs)
def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
    r"""FCN model from the paper `"Context Encoding for Semantic Segmentation"
@@ -141,4 +141,4 @@ def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
    >>> model = get_fcn_resnet50_ade(pretrained=True)
    >>> print(model)
    """
-    return get_fcn('ade20k', 'resnet50', pretrained)
+    return get_fcn('ade20k', 'resnet50', pretrained, **kwargs)
@@ -7,10 +7,15 @@ import zipfile
from ..utils import download, check_sha1

_model_sha1 = {name: checksum for checksum, name in [
+    ('853f2fb07aeb2927f7696e166b215609a987fd44', 'resnet50'),
+    ('5be5422ad7cb6a2e5f5a54070d0aa9affe69a9a4', 'resnet101'),
    ('eeed8e582f0fdccdba8579e7490570adc6d85c7c', 'fcn_resnet50_pcontext'),
-    ('969062a5aad2d1d983bae2f9e412578b62610114', 'encnet_resnet50_pcontext'),
-    ('3062cec955670690d3481d75e7e6368c721a46ce', 'encnet_resnet101_pcontext'),
+    ('425a7b15176105be0c0ae522aefde02bdcb3b9f5', 'encnet_resnet50_pcontext'),
+    ('abf1472fde53b7b41d7801a1f715765e1ef6f86e', 'encnet_resnet101_pcontext'),
+    ('167f05f69df94d4066dad155d1a71dc6493747eb', 'encnet_resnet50_ade'),
    ('fc8c0b795abf0133700c2d4265d2f9edab7eb6cc', 'fcn_resnet50_ade'),
+    ('54f70c772505064e30efd1ddd3a14e1759faa363', 'psp_resnet50_ade'),
]}
encoding_repo_url = 'https://hangzh.s3.amazonaws.com/'
......
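
A hedged sketch of how these entries are consumed (``get_model_file`` is the same helper the updated ``resnet50``/``resnet101`` loaders import; the exact download-and-verify logic lives in this file):

    import torch
    from encoding.models.model_store import get_model_file

    # downloads (if needed) from the repo URL above and checks the file
    # against its SHA1 entry in _model_sha1 before returning the local path
    path = get_model_file('encnet_resnet50_ade', root='~/.encoding/models')
    state_dict = torch.load(path)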
# pylint: disable=wildcard-import, unused-wildcard-import
from .fcn import *
+from .psp import *
from .encnet import *

__all__ = ['get_model']
@@ -27,7 +28,9 @@ def get_model(name, **kwargs):
        'fcn_resnet50_pcontext': get_fcn_resnet50_pcontext,
        'encnet_resnet50_pcontext': get_encnet_resnet50_pcontext,
        'encnet_resnet101_pcontext': get_encnet_resnet101_pcontext,
+        'encnet_resnet50_ade': get_encnet_resnet50_ade,
        'fcn_resnet50_ade': get_fcn_resnet50_ade,
+        'psp_resnet50_ade': get_psp_resnet50_ade,
    }
    name = name.lower()
    if name not in models:
......
###########################################################################
# Created by: Hang Zhang
# Email: zhang.hang@rutgers.edu
# Copyright (c) 2017
###########################################################################
from __future__ import division
import os
import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import upsample

from .base import BaseNet
from .fcn import FCNHead
from ..nn import PyramidPooling

class PSP(BaseNet):
    def __init__(self, nclass, backbone, aux=True, se_loss=False, norm_layer=nn.BatchNorm2d, **kwargs):
        super(PSP, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer)
        self.head = PSPHead(2048, nclass, norm_layer, self._up_kwargs)
        if aux:
            self.auxlayer = FCNHead(1024, nclass, norm_layer)

    def forward(self, x):
        _, _, h, w = x.size()
        _, _, c3, c4 = self.base_forward(x)

        outputs = []
        x = self.head(c4)
        x = upsample(x, (h, w), **self._up_kwargs)
        outputs.append(x)
        if self.aux:
            auxout = self.auxlayer(c3)
            auxout = upsample(auxout, (h, w), **self._up_kwargs)
            outputs.append(auxout)
        return tuple(outputs)
class PSPHead(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer, up_kwargs):
        super(PSPHead, self).__init__()
        inter_channels = in_channels // 4
        # PyramidPooling returns 2*in_channels channels: the input plus four
        # pooled branches of in_channels//4 channels each
        self.conv5 = nn.Sequential(PyramidPooling(in_channels, norm_layer, up_kwargs),
                                   nn.Conv2d(in_channels * 2, inter_channels, 3, padding=1, bias=False),
                                   norm_layer(inter_channels),
                                   nn.ReLU(True),
                                   nn.Dropout2d(0.1, False),
                                   nn.Conv2d(inter_channels, out_channels, 1))

    def forward(self, x):
        return self.conv5(x)
def get_psp(dataset='pascal_voc', backbone='resnet50', pretrained=False,
            root='~/.encoding/models', **kwargs):
    acronyms = {
        'pascal_voc': 'voc',
        'pascal_aug': 'voc',
        'ade20k': 'ade',
    }
    # infer number of classes
    from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
    model = PSP(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        model.load_state_dict(torch.load(
            get_model_file('psp_%s_%s' % (backbone, acronyms[dataset]), root=root)))
    return model
def get_psp_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
    r"""PSP model from the paper `"Context Encoding for Semantic Segmentation"
    <https://arxiv.org/pdf/1803.08904.pdf>`_

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.encoding/models'
        Location for keeping the model parameters.

    Examples
    --------
    >>> model = get_psp_resnet50_ade(pretrained=True)
    >>> print(model)
    """
    return get_psp('ade20k', 'resnet50', pretrained, root=root, **kwargs)
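
Since ``PSP.forward`` returns a ``(main, aux)`` tuple, a training step usually combines both predictions with a weighted loss; a minimal sketch (the 0.2 auxiliary weight and the ``images``/``target`` tensors are assumptions, not fixed by this commit):

    import torch.nn as nn

    model = get_psp('ade20k', backbone='resnet50', aux=True)
    criterion = nn.CrossEntropyLoss(ignore_index=-1)
    outputs = model(images)   # (main, aux), both upsampled to the input size
    loss = criterion(outputs[0], target) + 0.2 * criterion(outputs[1], target)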
@@ -14,12 +14,11 @@ from torch.nn import Module, Sequential, Conv2d, ReLU, AdaptiveAvgPool2d, \
    NLLLoss, BCELoss, CrossEntropyLoss, AvgPool2d, MaxPool2d, Parameter
from torch.nn import functional as F
from torch.autograd import Variable
-from .syncbn import BatchNorm2d

torch_ver = torch.__version__[:3]

__all__ = ['GramMatrix', 'SegmentationLosses', 'View', 'Sum', 'Mean',
-           'Normalize']
+           'Normalize', 'PyramidPooling']

class GramMatrix(Module):
    r""" Gram Matrix for a 4D convolutional featuremaps as a mini-batch
@@ -147,4 +146,48 @@ class Normalize(Module):
        self.dim = dim

    def forward(self, x):
-        return F.normalize(x, self.p, self.dim, eps=1e-10)
+        return F.normalize(x, self.p, self.dim, eps=1e-8)
+class PyramidPooling(Module):
+    """
+    Reference:
+        Zhao, Hengshuang, et al. *"Pyramid scene parsing network."*
+    """
+    def __init__(self, in_channels, norm_layer, up_kwargs):
+        super(PyramidPooling, self).__init__()
+        self.pool1 = AdaptiveAvgPool2d(1)
+        self.pool2 = AdaptiveAvgPool2d(2)
+        self.pool3 = AdaptiveAvgPool2d(3)
+        self.pool4 = AdaptiveAvgPool2d(6)
+
+        out_channels = int(in_channels/4)
+        self.conv1 = Sequential(Conv2d(in_channels, out_channels, 1, bias=False),
+                                norm_layer(out_channels),
+                                ReLU(True))
+        self.conv2 = Sequential(Conv2d(in_channels, out_channels, 1, bias=False),
+                                norm_layer(out_channels),
+                                ReLU(True))
+        self.conv3 = Sequential(Conv2d(in_channels, out_channels, 1, bias=False),
+                                norm_layer(out_channels),
+                                ReLU(True))
+        self.conv4 = Sequential(Conv2d(in_channels, out_channels, 1, bias=False),
+                                norm_layer(out_channels),
+                                ReLU(True))
+        # bilinear upsample options
+        self._up_kwargs = up_kwargs
+
+    def _cat_each(self, x, feat1, feat2, feat3, feat4):
+        assert(len(x) == len(feat1))
+        z = []
+        for i in range(len(x)):
+            z.append(torch.cat((x[i], feat1[i], feat2[i], feat3[i], feat4[i]), 1))
+        return z
+
+    def forward(self, x):
+        _, _, h, w = x.size()
+        feat1 = F.upsample(self.conv1(self.pool1(x)), (h, w), **self._up_kwargs)
+        feat2 = F.upsample(self.conv2(self.pool2(x)), (h, w), **self._up_kwargs)
+        feat3 = F.upsample(self.conv3(self.pool3(x)), (h, w), **self._up_kwargs)
+        feat4 = F.upsample(self.conv4(self.pool4(x)), (h, w), **self._up_kwargs)
+        return torch.cat((x, feat1, feat2, feat3, feat4), 1)
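
A quick shape check of the module above (an editor's sketch; the bilinear ``up_kwargs`` mirror the settings used by the segmentation models in this repo, which is an assumption):

    import torch
    from torch.nn import BatchNorm2d

    pp = PyramidPooling(2048, BatchNorm2d, {'mode': 'bilinear', 'align_corners': True})
    x = torch.randn(1, 2048, 60, 60)
    y = pp(x)
    print(y.shape)  # torch.Size([1, 4096, 60, 60]): input plus four 512-channel branches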
@@ -15,7 +15,7 @@ import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.modules.utils import _pair

-from ..functions import scaledL2, aggregate
+from ..functions import scaledL2, aggregate, pairwise_cosine

__all__ = ['Encoding', 'EncodingDrop', 'Inspiration', 'UpsampleConv2d']
@@ -111,6 +111,7 @@ class Encoding(Module):
            + 'N x ' + str(self.D) + '=>' + str(self.K) + 'x' \
            + str(self.D) + ')'

class EncodingDrop(Module):
    r"""Dropout regularized Encoding Layer.
    """
......
@@ -12,10 +12,10 @@
from .lr_scheduler import LR_Scheduler
from .metrics import batch_intersection_union, batch_pix_accuracy
from .pallete import get_mask_pallete
-from .train_helper import get_selabel_vector
+from .train_helper import get_selabel_vector, EMA
from .presets import load_image
from .files import *

__all__ = ['LR_Scheduler', 'batch_pix_accuracy', 'batch_intersection_union',
           'save_checkpoint', 'download', 'mkdir', 'check_sha1', 'load_image',
-           'get_mask_pallete']
+           'get_mask_pallete', 'get_selabel_vector', 'EMA']
@@ -24,34 +24,39 @@ class LR_Scheduler(object):
    :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs,
    :attr:`args.lr_step`

-    niters: number of iterations per epoch
+    iters_per_epoch: number of iterations per epoch
    """
-    def __init__(self, args, niters=0):
-        self.mode = args.lr_scheduler
+    def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0,
+                 lr_step=0, warmup_epochs=0):
+        self.mode = mode
        print('Using {} LR Scheduler!'.format(self.mode))
-        self.lr = args.lr
-        if self.mode == 'step':
-            self.lr_step = args.lr_step
-        else:
-            self.niters = niters
-            self.N = args.epochs * niters
+        self.lr = base_lr
+        if mode == 'step':
+            assert lr_step
+            self.lr_step = lr_step
+        self.iters_per_epoch = iters_per_epoch
+        self.N = num_epochs * iters_per_epoch
        self.epoch = -1
+        self.warmup_iters = warmup_epochs * iters_per_epoch

    def __call__(self, optimizer, i, epoch, best_pred):
+        T = epoch * self.iters_per_epoch + i
        if self.mode == 'cos':
-            T = (epoch - 1) * self.niters + i
            lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi))
        elif self.mode == 'poly':
-            T = (epoch - 1) * self.niters + i
            lr = self.lr * pow((1 - 1.0 * T / self.N), 0.9)
        elif self.mode == 'step':
-            lr = self.lr * (0.1 ** ((epoch - 1) // self.lr_step))
+            lr = self.lr * (0.1 ** (epoch // self.lr_step))
        else:
-            raise RuntimeError('Unknown LR scheduler!')
+            raise NotImplementedError
+        # warm up lr schedule
+        if self.warmup_iters > 0 and T < self.warmup_iters:
+            lr = lr * 1.0 * T / self.warmup_iters
        if epoch > self.epoch:
            print('\n=>Epoch %i, learning rate = %.4f, \
                previous best = %.4f' % (epoch, lr, best_pred))
            self.epoch = epoch
+        assert lr >= 0
        self._adjust_learning_rate(optimizer, lr)

    def _adjust_learning_rate(self, optimizer, lr):
......
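
A minimal sketch of the new constructor in use, mirroring the updated call in ``train.py`` later in this diff (``train_loader``, ``optimizer`` and ``best_pred`` are assumed to exist):

    scheduler = LR_Scheduler('poly', base_lr=0.001, num_epochs=120,
                             iters_per_epoch=len(train_loader))
    for epoch in range(120):
        for i, (image, target) in enumerate(train_loader):
            scheduler(optimizer, i, epoch, best_pred)  # sets the lr in-place
            # ... forward, backward, optimizer.step() as usual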
@@ -11,7 +11,7 @@
import torch

def get_selabel_vector(target, nclass):
-    """Get SE-Loss Label in a batch
+    r"""Get SE-Loss Label in a batch
    Args:
        predict: input 4D tensor
        target: label 3D tensor (BxHxW)
@@ -29,3 +29,44 @@ def get_selabel_vector(target, nclass):
        tvect[i] = vect
    return tvect
+class EMA():
+    r"""Exponential moving average of model parameters.
+
+    Examples:
+        >>> ema = EMA(0.999)
+        >>> for name, param in model.named_parameters():
+        >>>     if param.requires_grad:
+        >>>         ema.register(name, param.data)
+        >>>
+        >>> # during training, after optimizer.step():
+        >>> for name, param in model.named_parameters():
+        >>>     # sometimes the moving average of non-trainable parameters is
+        >>>     # also kept, depending on the model structure
+        >>>     if param.requires_grad:
+        >>>         ema(name, param.data)
+        >>>
+        >>> # during eval or test
+        >>> import copy
+        >>> model_test = copy.deepcopy(model)
+        >>> for name, param in model_test.named_parameters():
+        >>>     if param.requires_grad:
+        >>>         param.data = ema.get(name)
+        >>> # then use model_test for eval
+    """
+    def __init__(self, momentum):
+        self.momentum = momentum
+        self.shadow = {}
+
+    def register(self, name, val):
+        self.shadow[name] = val.clone()
+
+    def __call__(self, name, x):
+        assert name in self.shadow
+        new_average = (1.0 - self.momentum) * x + self.momentum * self.shadow[name]
+        self.shadow[name] = new_average.clone()
+        return new_average
+
+    def get(self, name):
+        assert name in self.shadow
+        return self.shadow[name]
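
To make the update rule concrete, a tiny numeric sketch: with momentum 0.9, a registered value moves 10% of the way toward each new observation.

    import torch

    ema = EMA(momentum=0.9)
    ema.register('w', torch.tensor([1.0]))
    print(ema('w', torch.tensor([0.0])))  # 0.1*0.0 + 0.9*1.00 -> tensor([0.9000])
    print(ema('w', torch.tensor([0.0])))  # 0.1*0.0 + 0.9*0.90 -> tensor([0.8100])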
@@ -76,7 +76,8 @@ def main():
    else:
        raise RuntimeError("=> no resume checkpoint found at '{}'".
                           format(args.resume))
-    scheduler = LR_Scheduler(args, len(train_loader))
+    scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
+                             len(train_loader), args.lr_step)

    def train(epoch):
        model.train()
        global best_pred, errlist_train
......
@@ -23,7 +23,7 @@ class Options():
                            default=os.path.join(os.environ['HOME'], 'data'),
                            help='training dataset folder (default: \
                            $(HOME)/data)')
-        parser.add_argument('--workers', type=int, default=4,
+        parser.add_argument('--workers', type=int, default=16,
                            metavar='N', help='dataloader threads')
        # training hyper params
        parser.add_argument('--aux', action='store_true', default=False,
@@ -37,12 +37,12 @@ class Options():
        parser.add_argument('--batch-size', type=int, default=None,
                            metavar='N', help='input batch size for \
                            training (default: auto)')
-        parser.add_argument('--test-batch-size', type=int, default=16,
+        parser.add_argument('--test-batch-size', type=int, default=None,
                            metavar='N', help='input batch size for \
-                           testing (default: 32)')
+                           testing (default: same as batch size)')
        parser.add_argument('--lr', type=float, default=None, metavar='LR',
                            help='learning rate (default: auto)')
        parser.add_argument('--lr-scheduler', type=str, default='poly',
                            help='learning rate scheduler (default: poly)')
        parser.add_argument('--momentum', type=float, default=0.9,
                            metavar='M', help='momentum (default: 0.9)')
@@ -67,6 +67,8 @@ class Options():
                            help='num of pre-trained classes \
                            (default: None)')
        # evaluation option
+        parser.add_argument('--ema', action='store_true', default=False,
+                            help='use EMA weights for evaluation')
        parser.add_argument('--eval', action='store_true', default=False,
                            help='evaluating mIoU')
        parser.add_argument('--no-val', action='store_true', default=False,
@@ -85,12 +87,14 @@ class Options():
        epoches = {
            'pascal_voc': 50,
            'pascal_aug': 50,
-            'pcontext': 50,
+            'pcontext': 80,
            'ade20k': 120,
        }
        args.epochs = epoches[args.dataset.lower()]
        if args.batch_size is None:
            args.batch_size = 4 * torch.cuda.device_count()
+        if args.test_batch_size is None:
+            args.test_batch_size = args.batch_size
        if args.lr is None:
            lrs = {
                'pascal_voc': 0.0001,
......