Unverified Commit 16363650 authored by Hang Zhang's avatar Hang Zhang Committed by GitHub
Browse files

fix path (#73)

parent 2dd88e58
...@@ -24,7 +24,7 @@ __all__ = ['BaseNet', 'MultiEvalModule'] ...@@ -24,7 +24,7 @@ __all__ = ['BaseNet', 'MultiEvalModule']
class BaseNet(nn.Module): class BaseNet(nn.Module):
def __init__(self, nclass, backbone, aux, se_loss, dilated=True, norm_layer=None, def __init__(self, nclass, backbone, aux, se_loss, dilated=True, norm_layer=None,
mean=[.485, .456, .406], std=[.229, .224, .225]): mean=[.485, .456, .406], std=[.229, .224, .225], root='~/.encoding/models'):
super(BaseNet, self).__init__() super(BaseNet, self).__init__()
self.nclass = nclass self.nclass = nclass
self.aux = aux self.aux = aux
...@@ -33,11 +33,14 @@ class BaseNet(nn.Module): ...@@ -33,11 +33,14 @@ class BaseNet(nn.Module):
self.std = std self.std = std
# copying modules from pretrained models # copying modules from pretrained models
if backbone == 'resnet50': if backbone == 'resnet50':
self.pretrained = resnet.resnet50(pretrained=True, dilated=dilated, norm_layer=norm_layer) self.pretrained = resnet.resnet50(pretrained=True, dilated=dilated,
norm_layer=norm_layer, root=root)
elif backbone == 'resnet101': elif backbone == 'resnet101':
self.pretrained = resnet.resnet101(pretrained=True, dilated=dilated, norm_layer=norm_layer) self.pretrained = resnet.resnet101(pretrained=True, dilated=dilated,
norm_layer=norm_layer, root=root)
elif backbone == 'resnet152': elif backbone == 'resnet152':
self.pretrained = resnet.resnet152(pretrained=True, dilated=dilated, norm_layer=norm_layer) self.pretrained = resnet.resnet152(pretrained=True, dilated=dilated,
norm_layer=norm_layer, root=root)
else: else:
raise RuntimeError('unknown backbone: {}'.format(backbone)) raise RuntimeError('unknown backbone: {}'.format(backbone))
# bilinear upsample options # bilinear upsample options
......
...@@ -19,7 +19,8 @@ __all__ = ['EncNet', 'EncModule', 'get_encnet', 'get_encnet_resnet50_pcontext', ...@@ -19,7 +19,8 @@ __all__ = ['EncNet', 'EncModule', 'get_encnet', 'get_encnet_resnet50_pcontext',
class EncNet(BaseNet): class EncNet(BaseNet):
def __init__(self, nclass, backbone, aux=True, se_loss=True, lateral=False, def __init__(self, nclass, backbone, aux=True, se_loss=True, lateral=False,
norm_layer=nn.BatchNorm2d, **kwargs): norm_layer=nn.BatchNorm2d, **kwargs):
super(EncNet, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer) super(EncNet, self).__init__(nclass, backbone, aux, se_loss,
norm_layer=norm_layer, **kwargs)
self.head = EncHead(self.nclass, in_channels=2048, se_loss=se_loss, self.head = EncHead(self.nclass, in_channels=2048, se_loss=se_loss,
lateral=lateral, norm_layer=norm_layer, lateral=lateral, norm_layer=norm_layer,
up_kwargs=self._up_kwargs) up_kwargs=self._up_kwargs)
...@@ -142,7 +143,7 @@ def get_encnet(dataset='pascal_voc', backbone='resnet50', pretrained=False, ...@@ -142,7 +143,7 @@ def get_encnet(dataset='pascal_voc', backbone='resnet50', pretrained=False,
kwargs['lateral'] = True if dataset.lower() == 'pcontext' else False kwargs['lateral'] = True if dataset.lower() == 'pcontext' else False
# infer number of classes # infer number of classes
from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
model = EncNet(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs) model = EncNet(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
if pretrained: if pretrained:
from .model_store import get_model_file from .model_store import get_model_file
model.load_state_dict(torch.load( model.load_state_dict(torch.load(
......
...@@ -39,7 +39,7 @@ class FCN(BaseNet): ...@@ -39,7 +39,7 @@ class FCN(BaseNet):
>>> print(model) >>> print(model)
""" """
def __init__(self, nclass, backbone, aux=True, se_loss=False, norm_layer=nn.BatchNorm2d, **kwargs): def __init__(self, nclass, backbone, aux=True, se_loss=False, norm_layer=nn.BatchNorm2d, **kwargs):
super(FCN, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer) super(FCN, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer, **kwargs)
self.head = FCNHead(2048, nclass, norm_layer) self.head = FCNHead(2048, nclass, norm_layer)
if aux: if aux:
self.auxlayer = FCNHead(1024, nclass, norm_layer) self.auxlayer = FCNHead(1024, nclass, norm_layer)
...@@ -97,7 +97,7 @@ def get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False, ...@@ -97,7 +97,7 @@ def get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False,
} }
# infer number of classes # infer number of classes
from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
model = FCN(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs) model = FCN(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
if pretrained: if pretrained:
from .model_store import get_model_file from .model_store import get_model_file
model.load_state_dict(torch.load( model.load_state_dict(torch.load(
...@@ -122,7 +122,7 @@ def get_fcn_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **kwa ...@@ -122,7 +122,7 @@ def get_fcn_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **kwa
>>> model = get_fcn_resnet50_pcontext(pretrained=True) >>> model = get_fcn_resnet50_pcontext(pretrained=True)
>>> print(model) >>> print(model)
""" """
return get_fcn('pcontext', 'resnet50', pretrained, aux=False, **kwargs) return get_fcn('pcontext', 'resnet50', pretrained, root=root, aux=False, **kwargs)
def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs): def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation" r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
...@@ -141,4 +141,4 @@ def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs): ...@@ -141,4 +141,4 @@ def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
>>> model = get_fcn_resnet50_ade(pretrained=True) >>> model = get_fcn_resnet50_ade(pretrained=True)
>>> print(model) >>> print(model)
""" """
return get_fcn('ade20k', 'resnet50', pretrained, **kwargs) return get_fcn('ade20k', 'resnet50', pretrained, root=root, **kwargs)
...@@ -16,7 +16,7 @@ from ..nn import PyramidPooling ...@@ -16,7 +16,7 @@ from ..nn import PyramidPooling
class PSP(BaseNet): class PSP(BaseNet):
def __init__(self, nclass, backbone, aux=True, se_loss=False, norm_layer=nn.BatchNorm2d, **kwargs): def __init__(self, nclass, backbone, aux=True, se_loss=False, norm_layer=nn.BatchNorm2d, **kwargs):
super(PSP, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer) super(PSP, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer, **kwargs)
self.head = PSPHead(2048, nclass, norm_layer, self._up_kwargs) self.head = PSPHead(2048, nclass, norm_layer, self._up_kwargs)
if aux: if aux:
self.auxlayer = FCNHead(1024, nclass, norm_layer) self.auxlayer = FCNHead(1024, nclass, norm_layer)
...@@ -59,7 +59,7 @@ def get_psp(dataset='pascal_voc', backbone='resnet50', pretrained=False, ...@@ -59,7 +59,7 @@ def get_psp(dataset='pascal_voc', backbone='resnet50', pretrained=False,
} }
# infer number of classes # infer number of classes
from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
model = PSP(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs) model = PSP(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
if pretrained: if pretrained:
from .model_store import get_model_file from .model_store import get_model_file
model.load_state_dict(torch.load( model.load_state_dict(torch.load(
...@@ -83,4 +83,4 @@ def get_psp_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs): ...@@ -83,4 +83,4 @@ def get_psp_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
>>> model = get_psp_resnet50_ade(pretrained=True) >>> model = get_psp_resnet50_ade(pretrained=True)
>>> print(model) >>> print(model)
""" """
return get_psp('ade20k', 'resnet50', pretrained) return get_psp('ade20k', 'resnet50', pretrained, root=root, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment