"docs/source/api/vscode:/vscode.git/clone" did not exist on "308bd6f5b245929211f365396ca2007ac151b8e7"
Unverified Commit 07f25381 authored by Hang Zhang's avatar Hang Zhang Committed by GitHub
Browse files
parents cebf1341 70fdeb79
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/THCDeviceTensorUtils.cu"
#else
/// Constructs a THCDeviceTensor initialized from a THCudaTensor. Will
/// error if the dimensionality does not match exactly.
template <typename T, int Dim,
typename IndexT, template <typename U> class PtrTraits>
THCDeviceTensor<T, Dim, IndexT, PtrTraits>
toDeviceTensor(THCState* state, THCTensor* t);
template <typename T, int Dim, typename IndexT>
THCDeviceTensor<T, Dim, IndexT, DefaultPtrTraits>
toDeviceTensor(THCState* state, THCTensor* t) {
return toDeviceTensor<T, Dim, IndexT, DefaultPtrTraits>(state, t);
}
template <typename T, int Dim>
THCDeviceTensor<T, Dim, int, DefaultPtrTraits>
toDeviceTensor(THCState* state, THCTensor* t) {
return toDeviceTensor<T, Dim, int, DefaultPtrTraits>(state, t);
}
template <typename T, int Dim,
typename IndexT, template <typename U> class PtrTraits>
THCDeviceTensor<T, Dim, IndexT, PtrTraits>
toDeviceTensor(THCState* state, THCTensor* t) {
if (Dim != THCTensor_(nDimension)(state, t)) {
THError("THCudaTensor dimension mismatch");
}
// Determine the maximum offset into the tensor achievable; `IndexT`
// must be smaller than this type in order to use it.
ptrdiff_t maxOffset = 0;
IndexT sizes[Dim];
IndexT strides[Dim];
for (int i = 0; i < Dim; ++i) {
int64_t size = THCTensor_(size)(state, t, i);
int64_t stride = THCTensor_(stride)(state, t, i);
maxOffset += (size - 1) * stride;
sizes[i] = (IndexT) size;
strides[i] = (IndexT) stride;
}
if (maxOffset > std::numeric_limits<IndexT>::max()) {
THError("THCudaTensor sizes too large for THCDeviceTensor conversion");
}
return THCDeviceTensor<T, Dim, IndexT, PtrTraits>(
THCTensor_(data)(state, t), sizes, strides);
}
#endif
#ifndef THC_DEVICE_TENSOR_UTILS_INC
#define THC_DEVICE_TENSOR_UTILS_INC
#include "THCDeviceTensor.cuh"
#include "THCTensor.h"
#include <limits>
/// Constructs a DeviceTensor initialized from a THCudaTensor by
/// upcasting or downcasting the tensor to that of a different
/// dimension.
template <typename T, int Dim,
typename IndexT, template <typename U> class PtrTraits>
THCDeviceTensor<T, Dim, IndexT, PtrTraits>
toDeviceTensorCast(THCState* state, THCudaTensor* t);
template <typename T, int Dim, typename IndexT>
THCDeviceTensor<T, Dim, IndexT, DefaultPtrTraits>
toDeviceTensorCast(THCState* state, THCudaTensor* t) {
return toDeviceTensorCast<T, Dim, IndexT, DefaultPtrTraits>(state, t);
}
template <typename T, int Dim>
THCDeviceTensor<T, Dim, int, DefaultPtrTraits>
toDeviceTensorCast(THCState* state, THCudaTensor* t) {
return toDeviceTensorCast<T, Dim, int, DefaultPtrTraits>(state, t);
}
#include "generic/THCDeviceTensorUtils.cu"
#include "THCGenerateAllTypes.h"
#include "THCDeviceTensorUtils-inl.cuh"
#endif // THC_DEVICE_TENSOR_UTILS_INC
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/THCDeviceTensorUtils.cu"
#else
/// Constructs a THCDeviceTensor initialized from a THCudaTensor. Will
/// error if the dimensionality does not match exactly.
template <typename T, int Dim,
typename IndexT, template <typename U> class PtrTraits>
THCDeviceTensor<T, Dim, IndexT, PtrTraits>
toDeviceTensor(THCState* state, THCTensor* t);
template <typename T, int Dim, typename IndexT>
THCDeviceTensor<T, Dim, IndexT, DefaultPtrTraits>
toDeviceTensor(THCState* state, THCTensor* t) {
return toDeviceTensor<T, Dim, IndexT, DefaultPtrTraits>(state, t);
}
template <typename T, int Dim>
THCDeviceTensor<T, Dim, int, DefaultPtrTraits>
toDeviceTensor(THCState* state, THCTensor* t) {
return toDeviceTensor<T, Dim, int, DefaultPtrTraits>(state, t);
}
template <typename T, int Dim,
typename IndexT, template <typename U> class PtrTraits>
THCDeviceTensor<T, Dim, IndexT, PtrTraits>
toDeviceTensor(THCState* state, THCTensor* t) {
if (Dim != THCTensor_(nDimension)(state, t)) {
THError("THCudaTensor dimension mismatch");
}
// Determine the maximum offset into the tensor achievable; `IndexT`
// must be smaller than this type in order to use it.
ptrdiff_t maxOffset = 0;
IndexT sizes[Dim];
IndexT strides[Dim];
for (int i = 0; i < Dim; ++i) {
int64_t size = THCTensor_(size)(state, t, i);
int64_t stride = THCTensor_(stride)(state, t, i);
maxOffset += (size - 1) * stride;
sizes[i] = (IndexT) size;
strides[i] = (IndexT) stride;
}
if (maxOffset > std::numeric_limits<IndexT>::max()) {
THError("THCudaTensor sizes too large for THCDeviceTensor conversion");
}
return THCDeviceTensor<T, Dim, IndexT, PtrTraits>(
THCTensor_(data)(state, t), sizes, strides);
}
#endif
/*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
* Created by: Hang Zhang
* ECE Department, Rutgers University
* Email: zhang.hang@rutgers.edu
* Copyright (c) 2017
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree
*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
*/
#include "thc_encoding.h"
#include "common.h"
#include "generic/device_tensor.h"
#include "THC/THCGenerateFloatType.h"
#include "generic/device_tensor.h"
#include "THC/THCGenerateDoubleType.h"
#ifdef __cplusplus
extern "C" {
#endif
// float
#include "generic/encoding_utils.c"
#include "THC/THCGenerateFloatType.h"
#include "generic/encoding_kernel.c"
#include "THC/THCGenerateFloatType.h"
#include "generic/syncbn_kernel.c"
#include "THC/THCGenerateFloatType.h"
#include "generic/pooling_kernel.c"
#include "THC/THCGenerateFloatType.h"
// double
#include "generic/encoding_utils.c"
#include "THC/THCGenerateDoubleType.h"
#include "generic/encoding_kernel.c"
#include "THC/THCGenerateDoubleType.h"
#include "generic/syncbn_kernel.c"
#include "THC/THCGenerateDoubleType.h"
#include "generic/pooling_kernel.c"
#include "THC/THCGenerateDoubleType.h"
#ifdef __cplusplus
}
#endif
/*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
* Created by: Hang Zhang
* ECE Department, Rutgers University
* Email: zhang.hang@rutgers.edu
* Copyright (c) 2017
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree
*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
*/
#include <THC.h>
#include "THCDeviceTensor.cuh"
#include "THCDeviceTensorUtils.cuh"
// this symbol will be resolved automatically from PyTorch libs
extern THCState *state;
#define Encoding_(NAME) TH_CONCAT_4(Encoding_, Real, _, NAME)
#define THCTensor TH_CONCAT_3(TH,CReal,Tensor)
#define THCTensor_(NAME) TH_CONCAT_4(TH,CReal,Tensor_,NAME)
#ifdef __cplusplus
extern "C" {
#endif
// float
#include "generic/encoding_kernel.h"
#include "THC/THCGenerateFloatType.h"
#include "generic/syncbn_kernel.h"
#include "THC/THCGenerateFloatType.h"
#include "generic/pooling_kernel.h"
#include "THC/THCGenerateFloatType.h"
// double
#include "generic/encoding_kernel.h"
#include "THC/THCGenerateDoubleType.h"
#include "generic/syncbn_kernel.h"
#include "THC/THCGenerateDoubleType.h"
#include "generic/pooling_kernel.h"
#include "THC/THCGenerateDoubleType.h"
#ifdef __cplusplus
}
#endif
#!/usr/bin/env bash
mkdir -p encoding/lib && cd encoding/lib
# compile and install
cmake ..
make
from .model_zoo import get_model
from .base import *
from .fcn import *
from .encnet import *
def get_segmentation_model(name, **kwargs):
from .fcn import get_fcn
models = {
'fcn': get_fcn,
'encnet': get_encnet,
}
return models[name.lower()](**kwargs)
###########################################################################
# Created by: Hang Zhang
# Email: zhang.hang@rutgers.edu
# Copyright (c) 2017
###########################################################################
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import upsample
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import parallel_apply
from torch.nn.parallel.scatter_gather import scatter
from .. import dilated as resnet
from ..utils import batch_pix_accuracy, batch_intersection_union
up_kwargs = {'mode': 'bilinear', 'align_corners': True}
__all__ = ['BaseNet', 'EvalModule', 'MultiEvalModule']
class BaseNet(nn.Module):
def __init__(self, nclass, backbone, aux, se_loss, dilated=True, norm_layer=None,
mean=[.485, .456, .406], std=[.229, .224, .225]):
super(BaseNet, self).__init__()
self.nclass = nclass
self.aux = aux
self.se_loss = se_loss
self.mean = mean
self.std = std
# copying modules from pretrained models
if backbone == 'resnet50':
self.pretrained = resnet.resnet50(pretrained=True, dilated=dilated, norm_layer=norm_layer)
elif backbone == 'resnet101':
self.pretrained = resnet.resnet101(pretrained=True, dilated=dilated, norm_layer=norm_layer)
elif backbone == 'resnet152':
self.pretrained = resnet.resnet152(pretrained=True, dilated=dilated, norm_layer=norm_layer)
else:
raise RuntimeError('unknown backbone: {}'.format(backbone))
# bilinear upsample options
self._up_kwargs = up_kwargs
def base_forward(self, x):
x = self.pretrained.conv1(x)
x = self.pretrained.bn1(x)
x = self.pretrained.relu(x)
x = self.pretrained.maxpool(x)
c1 = self.pretrained.layer1(x)
c2 = self.pretrained.layer2(c1)
c3 = self.pretrained.layer3(c2)
c4 = self.pretrained.layer4(c3)
return c1, c2, c3, c4
def evaluate(self, x, target=None):
pred = self.forward(x)
if isinstance(pred, (tuple, list)):
pred = pred[0]
if target is None:
return pred
correct, labeled = batch_pix_accuracy(pred.data, target.data)
inter, union = batch_intersection_union(pred.data, target.data, self.nclass)
return correct, labeled, inter, union
class EvalModule(nn.Module):
"""Segmentation Eval Module"""
def __init__(self, module):
super(EvalModule, self).__init__()
self.module = module
def forward(self, *inputs, **kwargs):
return self.module.evaluate(*inputs, **kwargs)
class MultiEvalModule(DataParallel):
"""Multi-size Segmentation Eavluator"""
def __init__(self, module, nclass, device_ids=None,
base_size=520, crop_size=480, flip=True,
scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75]):
super(MultiEvalModule, self).__init__(module, device_ids)
self.nclass = nclass
self.base_size = base_size
self.crop_size = crop_size
self.scales = scales
self.flip = flip
def parallel_forward(self, inputs, **kwargs):
"""Multi-GPU Mult-size Evaluation
Args:
inputs: list of Tensors
"""
inputs = [(input.unsqueeze(0).cuda(device),) for input, device in zip(inputs, self.device_ids)]
replicas = self.replicate(self, self.device_ids[:len(inputs)])
kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
if len(inputs) < len(kwargs):
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
elif len(kwargs) < len(inputs):
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
outputs = self.parallel_apply(replicas, inputs, kwargs)
return outputs
def forward(self, image):
"""Mult-size Evaluation"""
# only single image is supported for evaluation
batch, _, h, w = image.size()
assert(batch == 1)
stride_rate = 2.0/3.0
crop_size = self.crop_size
stride = int(crop_size * stride_rate)
with torch.cuda.device_of(image):
scores = image.new().resize_(batch,self.nclass,h,w).zero_().cuda()
for scale in self.scales:
long_size = int(math.ceil(self.base_size * scale))
if h > w:
height = long_size
width = int(1.0 * w * long_size / h + 0.5)
short_size = width
else:
width = long_size
height = int(1.0 * h * long_size / w + 0.5)
short_size = height
# resize image to current size
cur_img = resize_image(image, height, width)
if scale <= 1.25 or long_size <= crop_size:# #
pad_img = pad_image(cur_img, self.module.mean,
self.module.std, crop_size)
outputs = self.module_inference(pad_img)
outputs = crop_image(outputs, 0, height, 0, width)
else:
if short_size < crop_size:
# pad if needed
pad_img = pad_image(cur_img, self.module.mean,
self.module.std, crop_size)
else:
pad_img = cur_img
_,_,ph,pw = pad_img.size()
assert(ph >= height and pw >= width)
# grid forward and normalize
h_grids = int(math.ceil(1.0*(ph-crop_size)/stride)) + 1
w_grids = int(math.ceil(1.0*(pw-crop_size)/stride)) + 1
with torch.cuda.device_of(image):
outputs = image.new().resize_(batch,self.nclass,ph,pw).zero_().cuda()
count_norm = image.new().resize_(batch,1,ph,pw).zero_().cuda()
# grid evaluation
for idh in range(h_grids):
for idw in range(w_grids):
h0 = idh * stride
w0 = idw * stride
h1 = min(h0 + crop_size, ph)
w1 = min(w0 + crop_size, pw)
crop_img = crop_image(pad_img, h0, h1, w0, w1)
# pad if needed
pad_crop_img = pad_image(crop_img, self.module.mean,
self.module.std, crop_size)
output = self.module_inference(pad_crop_img)
outputs[:,:,h0:h1,w0:w1] += crop_image(output,
0, h1-h0, 0, w1-w0)
count_norm[:,:,h0:h1,w0:w1] += 1
assert((count_norm==0).sum()==0)
outputs = outputs / count_norm
outputs = outputs[:,:,:height,:width]
score = resize_image(outputs, h, w)
scores += score
return scores
def module_inference(self, image):
output = self.module.evaluate(image)
if self.flip:
fimg = flip_image(image)
foutput = self.module.evaluate(fimg)
output += flip_image(foutput)
return output.exp()
def resize_image(img, h, w, mode='bilinear'):
return F.upsample(img, (h, w), **up_kwargs)
def pad_image(img, mean, std, crop_size):
b,c,h,w = img.size()
assert(c==3)
padh = crop_size - h if h < crop_size else 0
padw = crop_size - w if w < crop_size else 0
pad_values = -np.array(mean) / np.array(std)
img_pad = img.new().resize_(b,c,h+padh,w+padw)
#img_pad = F.pad(img, (0,padw,0,padh))
for i in range(c):
# note that pytorch pad params is in reversed orders
img_pad[:,i,:,:] = F.pad(img[:,i,:,:], (0, padw, 0, padh),
value=pad_values[i])
assert(img_pad.size(2)>=crop_size and img_pad.size(3)>=crop_size)
return img_pad
def crop_image(img, h0, h1, w0, w1):
return img[:,:,h0:h1,w0:w1]
def flip_image(img):
assert(img.dim()==4)
with torch.cuda.device_of(img):
idx = torch.arange(img.size(3)-1, -1, -1).type_as(img).long()
return img.index_select(3, idx)
###########################################################################
# Created by: Hang Zhang
# Email: zhang.hang@rutgers.edu
# Copyright (c) 2017
###########################################################################
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.nn.functional import upsample
import encoding
from .base import BaseNet
from .fcn import FCNHead
__all__ = ['EncNet', 'EncModule', 'get_encnet', 'get_encnet_resnet50_pcontext']
class EncNet(BaseNet):
def __init__(self, nclass, backbone, aux=True, se_loss=True,
norm_layer=nn.BatchNorm2d, **kwargs):
super(EncNet, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer)
self.head = EncHead(self.nclass, in_channels=2048, se_loss=se_loss,
norm_layer=norm_layer, up_kwargs=self._up_kwargs)
if aux:
self.auxlayer = FCNHead(1024, nclass, norm_layer=norm_layer)
def forward(self, x):
imsize = x.size()[2:]
#features = self.base_forward(x)
_, _, c3, c4 = self.base_forward(x)
x = list(self.head(c4))
x[0] = upsample(x[0], imsize, **self._up_kwargs)
if self.aux:
auxout = self.auxlayer(c3)
auxout = upsample(auxout, imsize, **self._up_kwargs)
x.append(auxout)
return tuple(x)
class EncModule(nn.Module):
def __init__(self, in_channels, nclass, ncodes=32, se_loss=True, norm_layer=None):
super(EncModule, self).__init__()
if isinstance(norm_layer, encoding.nn.BatchNorm2d):
norm_layer = encoding.nn.BatchNorm1d
else:
norm_layer = nn.BatchNorm1d
self.se_loss = se_loss
self.encoding = nn.Sequential(
encoding.nn.Encoding(D=in_channels, K=ncodes),
norm_layer(ncodes),
nn.ReLU(inplace=True),
encoding.nn.Sum(dim=1))
self.fc = nn.Sequential(
nn.Linear(in_channels, in_channels),
nn.Sigmoid())
if self.se_loss:
self.selayer = nn.Linear(in_channels, nclass)
def forward(self, x):
en = self.encoding(x)
b, c, _, _ = x.size()
gamma = self.fc(en)
y = gamma.view(b, c, 1, 1)
# residual ?
outputs = [x + x * y]
if self.se_loss:
outputs.append(self.selayer(en))
return tuple(outputs)
class EncHead(nn.Module):
def __init__(self, out_channels, in_channels, se_loss=True,
norm_layer=None, up_kwargs=None):
super(EncHead, self).__init__()
self.conv5 = nn.Sequential(
nn.Conv2d(in_channels, 512, 3, padding=1, bias=False),
norm_layer(512),
nn.ReLU(True))
self.encmodule = EncModule(512, out_channels, ncodes=32,
se_loss=se_loss, norm_layer=norm_layer)
self.dropout = nn.Dropout2d(0.1, False)
self.conv6 = nn.Conv2d(512, out_channels, 1)
self.se_loss = se_loss
def forward(self, x):
x = self.conv5(x)
outs = list(self.encmodule(x))
outs[0] = self.conv6(self.dropout(outs[0]))
return tuple(outs)
def get_encnet(dataset='pascal_voc', backbone='resnet50', pretrained=False,
root='~/.encoding/models', **kwargs):
r"""EncNet model from the paper `"Context Encoding for Semantic Segmentation"
<https://arxiv.org/pdf/1803.08904.pdf>`_
Parameters
----------
dataset : str, default pascal_voc
The dataset that model pretrained on. (pascal_voc, ade20k)
backbone : str, default resnet50
The backbone network. (resnet50, 101, 152)
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.encoding/models'
Location for keeping the model parameters.
Examples
--------
>>> model = get_encnet(dataset='pascal_voc', backbone='resnet50', pretrained=False)
>>> print(model)
"""
acronyms = {
'pascal_voc': 'voc',
'ade20k': 'ade',
'pcontext': 'pcontext',
}
# infer number of classes
from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
model = EncNet(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(
get_model_file('encnet_%s_%s'%(backbone, acronyms[dataset]), root=root)))
return model
def get_encnet_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
<https://arxiv.org/pdf/1803.08904.pdf>`_
Parameters
----------
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.encoding/models'
Location for keeping the model parameters.
Examples
--------
>>> model = get_encnet_resnet50_pcontext(pretrained=True)
>>> print(model)
"""
return get_encnet('pcontext', 'resnet50', pretrained)
###########################################################################
# Created by: Hang Zhang
# Email: zhang.hang@rutgers.edu
# Copyright (c) 2017
###########################################################################
from __future__ import division
import os
import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import upsample
from .base import BaseNet
__all__ = ['FCN', 'get_fcn', 'get_fcn_resnet50_pcontext', 'get_fcn_resnet50_ade']
class FCN(BaseNet):
r"""Fully Convolutional Networks for Semantic Segmentation
Parameters
----------
nclass : int
Number of categories for the training dataset.
backbone : string
Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
'resnet101' or 'resnet152').
norm_layer : object
Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
Reference:
Long, Jonathan, Evan Shelhamer, and Trevor Darrell. "Fully convolutional networks
for semantic segmentation." *CVPR*, 2015
Examples
--------
>>> model = FCN(nclass=21, backbone='resnet50')
>>> print(model)
"""
def __init__(self, nclass, backbone, aux=True, se_loss=False, norm_layer=nn.BatchNorm2d, **kwargs):
super(FCN, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer)
self.head = FCNHead(2048, nclass, norm_layer)
if aux:
self.auxlayer = FCNHead(1024, nclass, norm_layer)
def forward(self, x):
imsize = x.size()[2:]
_, _, c3, c4 = self.base_forward(x)
x = self.head(c4)
x = upsample(x, imsize, **self._up_kwargs)
outputs = [x]
if self.aux:
auxout = self.auxlayer(c3)
auxout = upsample(auxout, imsize, **self._up_kwargs)
outputs.append(auxout)
return tuple(outputs)
class FCNHead(nn.Module):
def __init__(self, in_channels, out_channels, norm_layer):
super(FCNHead, self).__init__()
inter_channels = in_channels // 4
self.conv5 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1),
norm_layer(inter_channels),
nn.ReLU(),
nn.Dropout2d(0.1, False),
nn.Conv2d(inter_channels, out_channels, 1))
def forward(self, x):
return self.conv5(x)
def get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False,
root='~/.encoding/models', **kwargs):
r"""FCN model from the paper `"Fully Convolutional Network for semantic segmentation"
<https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf>`_
Parameters
----------
dataset : str, default pascal_voc
The dataset that model pretrained on. (pascal_voc, ade20k)
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.encoding/models'
Location for keeping the model parameters.
Examples
--------
>>> model = get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False)
>>> print(model)
"""
acronyms = {
'pascal_voc': 'voc',
'pascal_aug': 'voc',
'pcontext': 'pcontext',
'ade20k': 'ade',
}
# infer number of classes
from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
model = FCN(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(
get_model_file('fcn_%s_%s'%(backbone, acronyms[dataset]), root=root)),
strict= False)
return model
def get_fcn_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
<https://arxiv.org/pdf/1803.08904.pdf>`_
Parameters
----------
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.encoding/models'
Location for keeping the model parameters.
Examples
--------
>>> model = get_fcn_resnet50_pcontext(pretrained=True)
>>> print(model)
"""
return get_fcn('pcontext', 'resnet50', pretrained)
def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
<https://arxiv.org/pdf/1803.08904.pdf>`_
Parameters
----------
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.encoding/models'
Location for keeping the model parameters.
Examples
--------
>>> model = get_fcn_resnet50_ade(pretrained=True)
>>> print(model)
"""
return get_fcn('ade20k', 'resnet50', pretrained)
This diff is collapsed.
# pylint: disable=wildcard-import, unused-wildcard-import
from .fcn import *
from .encnet import *
__all__ = ['get_model']
def get_model(name, **kwargs):
"""Returns a pre-defined model by name
Parameters
----------
name : str
Name of the model.
pretrained : bool
Whether to load the pretrained weights for model.
root : str, default '~/.encoding/models'
Location for keeping the model parameters.
Returns
-------
Module:
The model.
"""
models = {
'fcn_resnet50_pcontext': get_fcn_resnet50_pcontext,
'encnet_resnet50_pcontext': get_encnet_resnet50_pcontext,
'fcn_resnet50_ade': get_fcn_resnet50_ade,
}
name = name.lower()
if name not in models:
raise ValueError('%s\n\t%s' % (str(e), '\n\t'.join(sorted(models.keys()))))
net = models[name](**kwargs)
return net
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment