Unverified Commit 16363650 authored by Hang Zhang's avatar Hang Zhang Committed by GitHub
Browse files

fix path (#73)

parent 2dd88e58
......@@ -24,7 +24,7 @@ __all__ = ['BaseNet', 'MultiEvalModule']
class BaseNet(nn.Module):
def __init__(self, nclass, backbone, aux, se_loss, dilated=True, norm_layer=None,
mean=[.485, .456, .406], std=[.229, .224, .225]):
mean=[.485, .456, .406], std=[.229, .224, .225], root='~/.encoding/models'):
super(BaseNet, self).__init__()
self.nclass = nclass
self.aux = aux
......@@ -33,11 +33,14 @@ class BaseNet(nn.Module):
self.std = std
# copying modules from pretrained models
if backbone == 'resnet50':
self.pretrained = resnet.resnet50(pretrained=True, dilated=dilated, norm_layer=norm_layer)
self.pretrained = resnet.resnet50(pretrained=True, dilated=dilated,
norm_layer=norm_layer, root=root)
elif backbone == 'resnet101':
self.pretrained = resnet.resnet101(pretrained=True, dilated=dilated, norm_layer=norm_layer)
self.pretrained = resnet.resnet101(pretrained=True, dilated=dilated,
norm_layer=norm_layer, root=root)
elif backbone == 'resnet152':
self.pretrained = resnet.resnet152(pretrained=True, dilated=dilated, norm_layer=norm_layer)
self.pretrained = resnet.resnet152(pretrained=True, dilated=dilated,
norm_layer=norm_layer, root=root)
else:
raise RuntimeError('unknown backbone: {}'.format(backbone))
# bilinear upsample options
......
......@@ -19,7 +19,8 @@ __all__ = ['EncNet', 'EncModule', 'get_encnet', 'get_encnet_resnet50_pcontext',
class EncNet(BaseNet):
def __init__(self, nclass, backbone, aux=True, se_loss=True, lateral=False,
norm_layer=nn.BatchNorm2d, **kwargs):
super(EncNet, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer)
super(EncNet, self).__init__(nclass, backbone, aux, se_loss,
norm_layer=norm_layer, **kwargs)
self.head = EncHead(self.nclass, in_channels=2048, se_loss=se_loss,
lateral=lateral, norm_layer=norm_layer,
up_kwargs=self._up_kwargs)
......@@ -142,7 +143,7 @@ def get_encnet(dataset='pascal_voc', backbone='resnet50', pretrained=False,
kwargs['lateral'] = True if dataset.lower() == 'pcontext' else False
# infer number of classes
from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
model = EncNet(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs)
model = EncNet(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(
......
......@@ -39,7 +39,7 @@ class FCN(BaseNet):
>>> print(model)
"""
def __init__(self, nclass, backbone, aux=True, se_loss=False, norm_layer=nn.BatchNorm2d, **kwargs):
super(FCN, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer)
super(FCN, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer, **kwargs)
self.head = FCNHead(2048, nclass, norm_layer)
if aux:
self.auxlayer = FCNHead(1024, nclass, norm_layer)
......@@ -97,7 +97,7 @@ def get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False,
}
# infer number of classes
from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
model = FCN(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs)
model = FCN(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(
......@@ -122,7 +122,7 @@ def get_fcn_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **kwa
>>> model = get_fcn_resnet50_pcontext(pretrained=True)
>>> print(model)
"""
return get_fcn('pcontext', 'resnet50', pretrained, aux=False, **kwargs)
return get_fcn('pcontext', 'resnet50', pretrained, root=root, aux=False, **kwargs)
def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
......@@ -141,4 +141,4 @@ def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
>>> model = get_fcn_resnet50_ade(pretrained=True)
>>> print(model)
"""
return get_fcn('ade20k', 'resnet50', pretrained, **kwargs)
return get_fcn('ade20k', 'resnet50', pretrained, root=root, **kwargs)
......@@ -16,7 +16,7 @@ from ..nn import PyramidPooling
class PSP(BaseNet):
def __init__(self, nclass, backbone, aux=True, se_loss=False, norm_layer=nn.BatchNorm2d, **kwargs):
super(PSP, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer)
super(PSP, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer, **kwargs)
self.head = PSPHead(2048, nclass, norm_layer, self._up_kwargs)
if aux:
self.auxlayer = FCNHead(1024, nclass, norm_layer)
......@@ -59,7 +59,7 @@ def get_psp(dataset='pascal_voc', backbone='resnet50', pretrained=False,
}
# infer number of classes
from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
model = PSP(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs)
model = PSP(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(
......@@ -83,4 +83,4 @@ def get_psp_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
>>> model = get_psp_resnet50_ade(pretrained=True)
>>> print(model)
"""
return get_psp('ade20k', 'resnet50', pretrained)
return get_psp('ade20k', 'resnet50', pretrained, root=root, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment