Unverified Commit 2dd88e58 authored by Hang Zhang's avatar Hang Zhang Committed by GitHub
Browse files

fix path (#72)

parent 32e382bc
...@@ -166,7 +166,7 @@ def get_encnet_resnet50_pcontext(pretrained=False, root='~/.encoding/models', ** ...@@ -166,7 +166,7 @@ def get_encnet_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **
>>> model = get_encnet_resnet50_pcontext(pretrained=True) >>> model = get_encnet_resnet50_pcontext(pretrained=True)
>>> print(model) >>> print(model)
""" """
return get_encnet('pcontext', 'resnet50', pretrained, aux=False, **kwargs) return get_encnet('pcontext', 'resnet50', pretrained, root=root, aux=False, **kwargs)
def get_encnet_resnet101_pcontext(pretrained=False, root='~/.encoding/models', **kwargs): def get_encnet_resnet101_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation" r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
...@@ -185,7 +185,7 @@ def get_encnet_resnet101_pcontext(pretrained=False, root='~/.encoding/models', * ...@@ -185,7 +185,7 @@ def get_encnet_resnet101_pcontext(pretrained=False, root='~/.encoding/models', *
>>> model = get_encnet_resnet101_pcontext(pretrained=True) >>> model = get_encnet_resnet101_pcontext(pretrained=True)
>>> print(model) >>> print(model)
""" """
return get_encnet('pcontext', 'resnet101', pretrained, aux=False, **kwargs) return get_encnet('pcontext', 'resnet101', pretrained, root=root, aux=False, **kwargs)
def get_encnet_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs): def get_encnet_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation" r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
...@@ -204,4 +204,4 @@ def get_encnet_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwarg ...@@ -204,4 +204,4 @@ def get_encnet_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwarg
>>> model = get_encnet_resnet50_ade(pretrained=True) >>> model = get_encnet_resnet50_ade(pretrained=True)
>>> print(model) >>> print(model)
""" """
return get_encnet('ade20k', 'resnet50', pretrained, aux=True, **kwargs) return get_encnet('ade20k', 'resnet50', pretrained, root=root, aux=True, **kwargs)
...@@ -177,13 +177,6 @@ class PyramidPooling(Module): ...@@ -177,13 +177,6 @@ class PyramidPooling(Module):
# bilinear upsample options # bilinear upsample options
self._up_kwargs = up_kwargs self._up_kwargs = up_kwargs
def _cat_each(self, x, feat1, feat2, feat3, feat4):
assert(len(x) == len(feat1))
z = []
for i in range(len(x)):
z.append(torch.cat((x[i], feat1[i], feat2[i], feat3[i], feat4[i]), 1))
return z
def forward(self, x): def forward(self, x):
_, _, h, w = x.size() _, _, h, w = x.size()
feat1 = F.upsample(self.conv1(self.pool1(x)), (h, w), **self._up_kwargs) feat1 = F.upsample(self.conv1(self.pool1(x)), (h, w), **self._up_kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment