Unverified Commit 9bc70531 authored by Hang Zhang, committed by GitHub

add detail API and other fixes (#63)

parent 3ba8d2f7
......@@ -6,7 +6,7 @@ created by [Hang Zhang](http://hangzh.com/)
- Please visit the [**Docs**](http://hangzh.com/PyTorch-Encoding/) for detailed installation and usage instructions.
- How to use Synchronized Batch Normalization (SyncBN)? See the [examples](https://github.com/zhanghang1989/PyTorch-SyncBatchNorm).
- Please visit the [link](http://hangzh.com/PyTorch-Encoding/experiments/segmentation.html) for examples of semantic segmentation.
## Citations
......
......@@ -106,7 +106,4 @@ def test_batchify_fn(data):
elif isinstance(data[0], (tuple, list)):
data = zip(*data)
return [test_batchify_fn(i) for i in data]
    elif isinstance(data[0], np.ndarray):  # assumed type; the original token was lost
data = np.asarray(data)
return mx.nd.array(data, dtype=data.dtype)
    raise TypeError(error_msg.format(type(data[0])))
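For context, a minimal sketch of how a recursive collate function like this behaves (the `np.ndarray` token in the branch above is a reconstruction, and the sample data is purely illustrative):

```python
import numpy as np

# Two (image, target) pairs; zip(*data) regroups them by field,
# and each field is then batched by the recursive call.
batch = [(np.zeros((3, 4)), np.ones((3, 4))),
         (np.zeros((3, 4)), np.ones((3, 4)))]
images, targets = test_batchify_fn(batch)
print(images.shape)  # (2, 3, 4): stacked along a new batch axis
```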
......@@ -20,7 +20,7 @@ from ..utils import batch_pix_accuracy, batch_intersection_union
up_kwargs = {'mode': 'bilinear', 'align_corners': True}
__all__ = ['BaseNet', 'EvalModule', 'MultiEvalModule']
__all__ = ['BaseNet', 'MultiEvalModule']
class BaseNet(nn.Module):
def __init__(self, nclass, backbone, aux, se_loss, dilated=True, norm_layer=None,
......@@ -65,16 +65,6 @@ class BaseNet(nn.Module):
return correct, labeled, inter, union
class EvalModule(nn.Module):
"""Segmentation Eval Module"""
def __init__(self, module):
super(EvalModule, self).__init__()
self.module = module
def forward(self, *inputs, **kwargs):
return self.module.evaluate(*inputs, **kwargs)
class MultiEvalModule(DataParallel):
"""Multi-size Segmentation Eavluator"""
def __init__(self, module, nclass, device_ids=None,
......@@ -125,11 +115,11 @@ class MultiEvalModule(DataParallel):
height = int(1.0 * h * long_size / w + 0.5)
short_size = height
# resize image to current size
cur_img = resize_image(image, height, width)
if scale <= 1.25 or long_size <= crop_size:
cur_img = resize_image(image, height, width, **self.module._up_kwargs)
if long_size <= crop_size:
pad_img = pad_image(cur_img, self.module.mean,
self.module.std, crop_size)
outputs = self.module_inference(pad_img)
outputs = module_inference(self.module, pad_img, self.flip)
outputs = crop_image(outputs, 0, height, 0, width)
else:
if short_size < crop_size:
......@@ -157,7 +147,7 @@ class MultiEvalModule(DataParallel):
# pad if needed
pad_crop_img = pad_image(crop_img, self.module.mean,
self.module.std, crop_size)
output = self.module_inference(pad_crop_img)
output = module_inference(self.module, pad_crop_img, self.flip)
outputs[:,:,h0:h1,w0:w1] += crop_image(output,
0, h1-h0, 0, w1-w0)
count_norm[:,:,h0:h1,w0:w1] += 1
......@@ -165,21 +155,21 @@ class MultiEvalModule(DataParallel):
outputs = outputs / count_norm
outputs = outputs[:,:,:height,:width]
score = resize_image(outputs, h, w)
score = resize_image(outputs, h, w, **self.module._up_kwargs)
scores += score
return scores
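For orientation, the loop above slides a crop window over the image and averages overlapping predictions via `count_norm`. A small bookkeeping check, assuming a 2/3 stride rate (the stride itself is not shown in this hunk):

```python
import math

crop_size = 480
stride = int(crop_size * 2.0 / 3.0)   # 320, assuming a 2/3 stride rate
long_size = 720
# Number of crops along one axis, mirroring the grid computation in the loop:
h_grids = int(math.ceil(1.0 * (long_size - crop_size) / stride)) + 1
print(h_grids)  # 2
```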
def module_inference(self, image):
output = self.module.evaluate(image)
if self.flip:
fimg = flip_image(image)
foutput = self.module.evaluate(fimg)
output += flip_image(foutput)
return output.exp()
def module_inference(module, image, flip=True):
output = module.evaluate(image)
if flip:
fimg = flip_image(image)
foutput = module.evaluate(fimg)
output += flip_image(foutput)
return output.exp()
def resize_image(img, h, w, mode='bilinear'):
def resize_image(img, h, w, **up_kwargs):
return F.upsample(img, (h, w), **up_kwargs)
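A hedged sketch of the refactored helper in use; `model` is a placeholder for any `BaseNet`-style network with an `evaluate` method, and the shapes are illustrative:

```python
import torch

image = torch.randn(1, 3, 480, 480).cuda()        # one normalized crop
# Scores from the original and horizontally flipped views are summed
# inside module_inference, then exponentiated.
probs = module_inference(model, image, flip=True)
print(probs.shape)  # (1, nclass, 480, 480)
```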
def pad_image(img, mean, std, crop_size):
......@@ -189,11 +179,9 @@ def pad_image(img, mean, std, crop_size):
padw = crop_size - w if w < crop_size else 0
pad_values = -np.array(mean) / np.array(std)
img_pad = img.new().resize_(b,c,h+padh,w+padw)
#img_pad = F.pad(img, (0,padw,0,padh))
for i in range(c):
# note that pytorch pad params is in reversed orders
img_pad[:,i,:,:] = F.pad(img[:,i,:,:], (0, padw, 0, padh),
value=pad_values[i])
img_pad[:,i,:,:] = F.pad(img[:,i,:,:], (0, padw, 0, padh), value=pad_values[i])
assert(img_pad.size(2)>=crop_size and img_pad.size(3)>=crop_size)
return img_pad
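The pad value `-mean/std` makes padded pixels equal to what a zero-intensity pixel becomes after the usual `(x - mean) / std` normalization. A small check, assuming the ImageNet statistics this codebase typically uses:

```python
import torch

mean = [.485, .456, .406]
std = [.229, .224, .225]
img = torch.randn(1, 3, 300, 400)      # normalized input smaller than the crop
padded = pad_image(img, mean, std, 480)
print(padded.shape)                    # torch.Size([1, 3, 480, 480])
print(padded[0, 0, -1, -1].item())     # ~ -0.485 / 0.229 = -2.118
```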
......
......@@ -122,7 +122,7 @@ def get_fcn_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **kwa
>>> model = get_fcn_resnet50_pcontext(pretrained=True)
>>> print(model)
"""
return get_fcn('pcontext', 'resnet50', pretrained)
return get_fcn('pcontext', 'resnet50', pretrained, aux=False)
def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
......
......@@ -21,7 +21,6 @@ torch_ver = torch.__version__[:3]
__all__ = ['GramMatrix', 'SegmentationLosses', 'View', 'Sum', 'Mean',
'Normalize']
class GramMatrix(Module):
r""" Gram Matrix for a 4D convolutional featuremaps as a mini-batch
......
......@@ -45,8 +45,8 @@ def main():
torch.cuda.manual_seed(args.seed)
# init dataloader
dataset = importlib.import_module('dataset.'+args.dataset)
Dataloder = dataset.Dataloder
train_loader, test_loader = Dataloder(args).getloader()
Dataloader = dataset.Dataloader
train_loader, test_loader = Dataloader(args).getloader()
# init the model
models = importlib.import_module('model.'+args.model)
model = models.Net(args)
......
import torch
import encoding
# Get the model
model = encoding.models.get_model('fcn_resnet50_ade', pretrained=True).cuda()
model.eval()
# Prepare the image
url = 'https://github.com/zhanghang1989/image-data/blob/master/' + \
'encoding/segmentation/ade20k/ADE_val_00001142.jpg?raw=true'
filename = 'example.jpg'
img = encoding.utils.load_image(
encoding.utils.download(url, filename)).cuda().unsqueeze(0)
# Make prediction
output = model.evaluate(img)
predict = torch.max(output, 1)[1].cpu().numpy() + 1
# Get color palette for visualization
mask = encoding.utils.get_mask_pallete(predict, 'ade20k')
mask.save('output.png')
......@@ -44,7 +44,7 @@ def test(args):
# dataloader
kwargs = {'num_workers': args.workers, 'pin_memory': True} \
if args.cuda else {}
test_data = data.DataLoader(testset, batch_size=args.batch_size,
test_data = data.DataLoader(testset, batch_size=args.test_batch_size,
drop_last=False, shuffle=False,
collate_fn=test_batchify_fn, **kwargs)
# model
......@@ -105,8 +105,8 @@ def test(args):
with torch.no_grad():
correct, labeled, inter, union = eval_batch(image, dst, evaluator, args.eval)
if args.eval:
total_correct += correct
total_label += labeled
total_correct += correct.astype('int64')
total_label += labeled.astype('int64')
total_inter += inter.astype('int64')
total_union += union.astype('int64')
pixAcc = np.float64(1.0) * total_correct / (np.spacing(1, dtype=np.float64) + total_label)
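The `.astype('int64')` casts keep the running totals in 64-bit integers; on platforms where NumPy reductions default to int32, a 32-bit accumulator would wrap quickly at this pixel count. A back-of-the-envelope check:

```python
import numpy as np

pixels_per_crop = 480 * 480                        # 230,400 labeled pixels
print(np.iinfo(np.int32).max // pixels_per_crop)   # ~9320 images until int32 overflow
```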
......
import importlib
import torch
import encoding
from option import Options
if __name__ == "__main__":
args = Options().parse()
model = encoding.models.get_segmentation_model(args.model, dataset=args.dataset, aux=args.aux,
se_loss=args.se_loss, norm_layer=torch.nn.BatchNorm2d)
print('Creating the model:')
print(model)
model.cuda()
x = torch.randn(4, 3, 480, 480).cuda()  # random input at the training crop size
with torch.no_grad():
out = model(x)
for y in out:
print(y.size())
......@@ -60,18 +60,6 @@ class Trainer():
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)
# resuming checkpoint
if args.resume is not None:
if not os.path.isfile(args.resume):
raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
checkpoint = torch.load(args.resume)
args.start_epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['state_dict'])
if not args.ft:
optimizer.load_state_dict(checkpoint['optimizer'])
best_pred = checkpoint['best_pred']
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
# clear start epoch if fine-tuning
if args.ft:
args.start_epoch = 0
......@@ -82,6 +70,21 @@ class Trainer():
if args.cuda:
self.model = DataParallelModel(self.model).cuda()
self.criterion = DataParallelCriterion(self.criterion).cuda()
# resuming checkpoint
if args.resume is not None:
if not os.path.isfile(args.resume):
raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
checkpoint = torch.load(args.resume)
args.start_epoch = checkpoint['epoch']
if args.cuda:
self.model.module.load_state_dict(checkpoint['state_dict'])
else:
self.model.load_state_dict(checkpoint['state_dict'])
if not args.ft:
self.optimizer.load_state_dict(checkpoint['optimizer'])
self.best_pred = checkpoint['best_pred']
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
# lr scheduler
self.scheduler = utils.LR_Scheduler(args, len(self.trainloader))
self.best_pred = 0.0
......
......@@ -32,6 +32,17 @@ def download_ade(path, overwrite=False):
else:
shutil.move(filename, os.path.join(path, 'VOCdevkit/VOC2010/'+os.path.basename(filename)))
def install_pcontext_api():
repo_url = "https://github.com/zhanghang1989/detail-api"
os.system("git clone " + repo_url)
os.system("cd detail-api/PythonAPI/ && python setup.py install")
shutil.rmtree('detail-api')
try:
import detail
except Exception:
print("Installing PASCAL Context API failed, please install it manually %s"%(repo_url))
if __name__ == '__main__':
args = parse_args()
mkdir(os.path.expanduser('~/.encoding/data'))
......@@ -42,3 +53,4 @@ if __name__ == '__main__':
os.symlink(args.download_dir, _TARGET_DIR)
else:
download_ade(_TARGET_DIR, overwrite=False)
install_pcontext_api()