Commit 50dd7d3e authored by dengjb's avatar dengjb
Browse files

update

parents
Pipeline #3040 canceled with stages
from .utils import IntermediateLayerGetter
from ._deeplab import DeepLabHead, DeepLabHeadV3Plus, DeepLabV3
from .backbone import (
resnet,
mobilenetv2,
hrnetv2,
xception
)
def _segm_hrnet(name, backbone_name, num_classes, pretrained_backbone):
    """Build a DeepLabV3 or DeepLabV3+ model on top of an HRNetV2 backbone.

    Args:
        name (str): head type, either 'deeplabv3' or 'deeplabv3plus'.
        backbone_name (str): key of the constructor looked up in ``hrnetv2``.
            Its last '_'-separated field is parsed as the channel width of
            the highest-resolution stream (e.g. 'hrnetv2_48' -> 48).
        num_classes (int): number of classes passed to the classifier head.
        pretrained_backbone (bool): passed positionally to the backbone constructor.

    Returns:
        DeepLabV3: the wrapped backbone combined with the classifier head.

    Raises:
        NotImplementedError: if ``name`` is not a supported head type.
    """
    # Validate before building anything: previously an unknown head only failed
    # later with an UnboundLocalError, after the backbone had been constructed.
    if name not in ('deeplabv3', 'deeplabv3plus'):
        raise NotImplementedError(
            "unsupported segmentation head %r; expected 'deeplabv3' or 'deeplabv3plus'" % (name,))

    backbone = hrnetv2.__dict__[backbone_name](pretrained_backbone)

    # HRNetV2 config:
    # the number of output channels depends on the highest-resolution channel
    # width (c). The four streams carry c, 2c, 4c and 8c channels and are
    # concatenated, so the ASPP input has 15 * c channels.
    hrnet_channels = int(backbone_name.split('_')[-1])
    inplanes = sum(hrnet_channels * 2 ** i for i in range(4))
    low_level_planes = 256  # bottleneck output width is the same for every HRNet variant
    aspp_dilate = [12, 24, 36]  # If following the paper trend, [24, 48, 72] could be used.

    if name == 'deeplabv3plus':
        return_layers = {'stage4': 'out', 'layer1': 'low_level'}
        classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
    else:  # name == 'deeplabv3'
        return_layers = {'stage4': 'out'}
        classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)

    # hrnet_flag enables the multi-stream handling in IntermediateLayerGetter.
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers, hrnet_flag=True)
    return DeepLabV3(backbone, classifier)
def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone):
    """Build a DeepLabV3 or DeepLabV3+ model on top of a ResNet backbone.

    Args:
        name (str): head type, either 'deeplabv3' or 'deeplabv3plus'.
        backbone_name (str): key of the constructor looked up in ``resnet``.
        num_classes (int): number of classes passed to the classifier head.
        output_stride (int): 8 selects the [12, 24, 36] ASPP rates; any other
            value selects the [6, 12, 18] configuration.
        pretrained_backbone (bool): forwarded as ``pretrained`` to the backbone.

    Returns:
        DeepLabV3: the wrapped backbone combined with the classifier head.

    Raises:
        NotImplementedError: if ``name`` is not a supported head type.
    """
    # Validate before building anything: previously an unknown head only failed
    # later with an UnboundLocalError, after the backbone had been constructed.
    if name not in ('deeplabv3', 'deeplabv3plus'):
        raise NotImplementedError(
            "unsupported segmentation head %r; expected 'deeplabv3' or 'deeplabv3plus'" % (name,))

    if output_stride == 8:
        replace_stride_with_dilation = [False, True, True]
        aspp_dilate = [12, 24, 36]
    else:
        replace_stride_with_dilation = [False, False, True]
        aspp_dilate = [6, 12, 18]

    backbone = resnet.__dict__[backbone_name](
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=replace_stride_with_dilation)

    # NOTE(review): these widths are hard-coded; presumably they match the
    # layer4 / layer1 outputs of the bottleneck ResNets used here - confirm
    # before adding other ResNet variants.
    inplanes = 2048
    low_level_planes = 256

    if name == 'deeplabv3plus':
        return_layers = {'layer4': 'out', 'layer1': 'low_level'}
        classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
    else:  # name == 'deeplabv3'
        return_layers = {'layer4': 'out'}
        classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)

    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
    return DeepLabV3(backbone, classifier)
def _segm_xception(name, backbone_name, num_classes, output_stride, pretrained_backbone):
    """Build a DeepLabV3 or DeepLabV3+ model on top of an Xception backbone.

    Args:
        name (str): head type, either 'deeplabv3' or 'deeplabv3plus'.
        backbone_name (str): accepted for signature parity with the other
            builders; not used, since ``xception.xception`` is always called.
        num_classes (int): number of classes passed to the classifier head.
        output_stride (int): 8 selects the [12, 24, 36] ASPP rates; any other
            value selects the [6, 12, 18] configuration.
        pretrained_backbone (bool): when truthy the backbone is requested with
            ``pretrained='imagenet'``, otherwise with ``pretrained=False``.

    Returns:
        DeepLabV3: the wrapped backbone combined with the classifier head.

    Raises:
        NotImplementedError: if ``name`` is not a supported head type.
    """
    # Validate before building anything: previously an unknown head only failed
    # later with an UnboundLocalError, after the backbone had been constructed.
    if name not in ('deeplabv3', 'deeplabv3plus'):
        raise NotImplementedError(
            "unsupported segmentation head %r; expected 'deeplabv3' or 'deeplabv3plus'" % (name,))

    if output_stride == 8:
        replace_stride_with_dilation = [False, False, True, True]
        aspp_dilate = [12, 24, 36]
    else:
        replace_stride_with_dilation = [False, False, False, True]
        aspp_dilate = [6, 12, 18]

    backbone = xception.xception(
        pretrained='imagenet' if pretrained_backbone else False,
        replace_stride_with_dilation=replace_stride_with_dilation)

    inplanes = 2048
    low_level_planes = 128

    if name == 'deeplabv3plus':
        return_layers = {'conv4': 'out', 'block1': 'low_level'}
        classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
    else:  # name == 'deeplabv3'
        return_layers = {'conv4': 'out'}
        classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)

    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
    return DeepLabV3(backbone, classifier)
def _segm_mobilenet(name, backbone_name, num_classes, output_stride, pretrained_backbone):
    """Build a DeepLabV3 or DeepLabV3+ model on top of a MobileNetV2 backbone.

    Args:
        name (str): head type, either 'deeplabv3' or 'deeplabv3plus'.
        backbone_name (str): accepted for signature parity with the other
            builders; not used, since ``mobilenetv2.mobilenet_v2`` is always called.
        num_classes (int): number of classes passed to the classifier head.
        output_stride (int): forwarded to the backbone; 8 selects the
            [12, 24, 36] ASPP rates, any other value selects [6, 12, 18].
        pretrained_backbone (bool): forwarded as ``pretrained`` to the backbone.

    Returns:
        DeepLabV3: the wrapped backbone combined with the classifier head.

    Raises:
        NotImplementedError: if ``name`` is not a supported head type.
    """
    # Validate before building anything: previously an unknown head only failed
    # later with an UnboundLocalError, after the backbone had been constructed.
    if name not in ('deeplabv3', 'deeplabv3plus'):
        raise NotImplementedError(
            "unsupported segmentation head %r; expected 'deeplabv3' or 'deeplabv3plus'" % (name,))

    aspp_dilate = [12, 24, 36] if output_stride == 8 else [6, 12, 18]

    backbone = mobilenetv2.mobilenet_v2(pretrained=pretrained_backbone, output_stride=output_stride)

    # Rename layers: split `features` into two directly-assigned children so
    # IntermediateLayerGetter can address them by name, then drop the
    # original container and the classification head.
    backbone.low_level_features = backbone.features[0:4]
    backbone.high_level_features = backbone.features[4:-1]
    backbone.features = None
    backbone.classifier = None

    inplanes = 320
    low_level_planes = 24

    if name == 'deeplabv3plus':
        return_layers = {'high_level_features': 'out', 'low_level_features': 'low_level'}
        classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
    else:  # name == 'deeplabv3'
        return_layers = {'high_level_features': 'out'}
        classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)

    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
    return DeepLabV3(backbone, classifier)
def _load_model(arch_type, backbone, num_classes, output_stride, pretrained_backbone):
    """Dispatch to the builder that matches ``backbone``.

    Args:
        arch_type (str): head type forwarded to the builder as its ``name``.
        backbone (str): 'mobilenetv2', 'xception', or a name starting with
            'resnet' or 'hrnetv2'.
        num_classes (int): number of classes for the classifier head.
        output_stride (int): forwarded to every builder except the HRNet one,
            which does not take it.
        pretrained_backbone (bool): forwarded to the builder.

    Returns:
        The model returned by the selected builder.

    Raises:
        NotImplementedError: if no builder exists for ``backbone``.
    """
    if backbone == 'mobilenetv2':
        return _segm_mobilenet(arch_type, backbone, num_classes,
                               output_stride=output_stride,
                               pretrained_backbone=pretrained_backbone)
    if backbone.startswith('resnet'):
        return _segm_resnet(arch_type, backbone, num_classes,
                            output_stride=output_stride,
                            pretrained_backbone=pretrained_backbone)
    if backbone.startswith('hrnetv2'):
        return _segm_hrnet(arch_type, backbone, num_classes,
                           pretrained_backbone=pretrained_backbone)
    if backbone == 'xception':
        return _segm_xception(arch_type, backbone, num_classes,
                              output_stride=output_stride,
                              pretrained_backbone=pretrained_backbone)
    # Name the offending value so a typo in the backbone is easy to spot.
    raise NotImplementedError("unsupported backbone: %r" % (backbone,))
# Deeplab v3
def deeplabv3_hrnetv2_48(num_classes=21, output_stride=4, pretrained_backbone=False):  # no pretrained backbone yet
    """Constructs a DeepLabV3 model with an HRNetV2-48 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): forwarded to ``_load_model``, whose HRNet branch
            does not pass it on to the builder.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    # num_classes must come before output_stride: the two were previously
    # swapped, so the head was built with `output_stride` classes.
    return _load_model('deeplabv3', 'hrnetv2_48', num_classes,
                       output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone)
def deeplabv3_hrnetv2_32(num_classes=21, output_stride=4, pretrained_backbone=True):
    """Constructs a DeepLabV3 model with an HRNetV2-32 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): forwarded to ``_load_model``, whose HRNet branch
            does not pass it on to the builder.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    # num_classes must come before output_stride: the two were previously
    # swapped, so the head was built with `output_stride` classes.
    return _load_model('deeplabv3', 'hrnetv2_32', num_classes,
                       output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone)
def deeplabv3_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Create a DeepLabV3 segmentation model backed by ResNet-50.

    Args:
        num_classes (int): how many classes the head predicts.
        output_stride (int): output stride requested for the model.
        pretrained_backbone (bool): whether to request a pretrained backbone.
    """
    model = _load_model(
        'deeplabv3',
        'resnet50',
        num_classes,
        output_stride=output_stride,
        pretrained_backbone=pretrained_backbone,
    )
    return model
def deeplabv3_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Create a DeepLabV3 segmentation model backed by ResNet-101.

    Args:
        num_classes (int): how many classes the head predicts.
        output_stride (int): output stride requested for the model.
        pretrained_backbone (bool): whether to request a pretrained backbone.
    """
    model = _load_model(
        'deeplabv3',
        'resnet101',
        num_classes,
        output_stride=output_stride,
        pretrained_backbone=pretrained_backbone,
    )
    return model
def deeplabv3_mobilenet(num_classes=21, output_stride=8, pretrained_backbone=True, **kwargs):
    """Create a DeepLabV3 segmentation model backed by MobileNetV2.

    Args:
        num_classes (int): how many classes the head predicts.
        output_stride (int): output stride requested for the model.
        pretrained_backbone (bool): whether to request a pretrained backbone.
        **kwargs: accepted but not used.
    """
    model = _load_model(
        'deeplabv3',
        'mobilenetv2',
        num_classes,
        output_stride=output_stride,
        pretrained_backbone=pretrained_backbone,
    )
    return model
def deeplabv3_xception(num_classes=21, output_stride=8, pretrained_backbone=True, **kwargs):
    """Create a DeepLabV3 segmentation model backed by Xception.

    Args:
        num_classes (int): how many classes the head predicts.
        output_stride (int): output stride requested for the model.
        pretrained_backbone (bool): whether to request a pretrained backbone.
        **kwargs: accepted but not used.
    """
    model = _load_model(
        'deeplabv3',
        'xception',
        num_classes,
        output_stride=output_stride,
        pretrained_backbone=pretrained_backbone,
    )
    return model
# Deeplab v3+
def deeplabv3plus_hrnetv2_48(num_classes=21, output_stride=4, pretrained_backbone=False):  # no pretrained backbone yet
    """Create a DeepLabV3+ segmentation model backed by HRNetV2-48.

    Args:
        num_classes (int): how many classes the head predicts.
        output_stride (int): forwarded to ``_load_model``, whose HRNet branch
            does not pass it on to the builder.
        pretrained_backbone (bool): whether to request a pretrained backbone.
    """
    model = _load_model(
        'deeplabv3plus',
        'hrnetv2_48',
        num_classes,
        output_stride=output_stride,
        pretrained_backbone=pretrained_backbone,
    )
    return model
def deeplabv3plus_hrnetv2_32(num_classes=21, output_stride=4, pretrained_backbone=True):
    """Create a DeepLabV3+ segmentation model backed by HRNetV2-32.

    Args:
        num_classes (int): how many classes the head predicts.
        output_stride (int): forwarded to ``_load_model``, whose HRNet branch
            does not pass it on to the builder.
        pretrained_backbone (bool): whether to request a pretrained backbone.
    """
    model = _load_model(
        'deeplabv3plus',
        'hrnetv2_32',
        num_classes,
        output_stride=output_stride,
        pretrained_backbone=pretrained_backbone,
    )
    return model
def deeplabv3plus_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Create a DeepLabV3+ segmentation model backed by ResNet-50.

    Args:
        num_classes (int): how many classes the head predicts.
        output_stride (int): output stride requested for the model.
        pretrained_backbone (bool): whether to request a pretrained backbone.
    """
    model = _load_model(
        'deeplabv3plus',
        'resnet50',
        num_classes,
        output_stride=output_stride,
        pretrained_backbone=pretrained_backbone,
    )
    return model
def deeplabv3plus_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Create a DeepLabV3+ segmentation model backed by ResNet-101.

    Args:
        num_classes (int): how many classes the head predicts.
        output_stride (int): output stride requested for the model.
        pretrained_backbone (bool): whether to request a pretrained backbone.
    """
    model = _load_model(
        'deeplabv3plus',
        'resnet101',
        num_classes,
        output_stride=output_stride,
        pretrained_backbone=pretrained_backbone,
    )
    return model
def deeplabv3plus_mobilenet(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Create a DeepLabV3+ segmentation model backed by MobileNetV2.

    Args:
        num_classes (int): how many classes the head predicts.
        output_stride (int): output stride requested for the model.
        pretrained_backbone (bool): whether to request a pretrained backbone.
    """
    model = _load_model(
        'deeplabv3plus',
        'mobilenetv2',
        num_classes,
        output_stride=output_stride,
        pretrained_backbone=pretrained_backbone,
    )
    return model
def deeplabv3plus_xception(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Create a DeepLabV3+ segmentation model backed by Xception.

    Args:
        num_classes (int): how many classes the head predicts.
        output_stride (int): output stride requested for the model.
        pretrained_backbone (bool): whether to request a pretrained backbone.
    """
    model = _load_model(
        'deeplabv3plus',
        'xception',
        num_classes,
        output_stride=output_stride,
        pretrained_backbone=pretrained_backbone,
    )
    return model
\ No newline at end of file
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from collections import OrderedDict
class _SimpleSegmentationModel(nn.Module):
    """Chain a feature backbone and a classifier, then resize to the input size.

    Arguments:
        backbone (nn.Module): called on the input to produce features.
        classifier (nn.Module): called on the backbone output to produce
            the prediction map.
    """

    def __init__(self, backbone, classifier):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier

    def forward(self, x):
        """Return the classifier output bilinearly resized to the last two dims of ``x``."""
        spatial_size = x.shape[-2:]
        prediction = self.classifier(self.backbone(x))
        return F.interpolate(
            prediction,
            size=spatial_size,
            mode='bilinear',
            align_corners=False,
        )
class IntermediateLayerGetter(nn.ModuleDict):
    """
    Wrap a model so that calling it yields selected intermediate activations.

    The wrapper relies on the model's direct children having been registered
    in the same order in which they are applied, so an nn.Module must **not**
    be reused twice in the forward pass. Only direct children can be
    selected: `model.feature1` works, `model.feature1.layer2` does not.
    Children registered after the last requested layer are not kept.

    Arguments:
        model (nn.Module): model from which the features are extracted
        return_layers (Dict[name, new_name]): maps the name of each child
            whose activation should be returned to the key under which that
            activation appears in the output.
        hrnet_flag (bool): enable the multi-stream handling for children
            named `transition*` and for the `stage4` output.

    Examples::
        >>> m = torchvision.models.resnet18(pretrained=True)
        >>> # pick layer1 and layer3 and expose them as `feat1` and `feat2`
        >>> new_m = IntermediateLayerGetter(m,
        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = new_m(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        >>>     [('feat1', torch.Size([1, 64, 56, 56])),
        >>>      ('feat2', torch.Size([1, 256, 14, 14]))]
    """

    def __init__(self, model, return_layers, hrnet_flag=False):
        child_names = [child_name for child_name, _ in model.named_children()]
        if not set(return_layers).issubset(child_names):
            raise ValueError("return_layers are not present in model")

        # Keep children in registration order up to the last requested one.
        pending = {key for key, _ in return_layers.items()}
        kept = OrderedDict()
        for child_name, child in model.named_children():
            kept[child_name] = child
            pending.discard(child_name)
            if not pending:
                break

        super().__init__(kept)
        self.hrnet_flag = hrnet_flag
        self.return_layers = return_layers

    def forward(self, x):
        collected = OrderedDict()
        for layer_name, layer in self.named_children():
            is_transition = self.hrnet_flag and layer_name.startswith('transition')
            if not is_transition:
                # plain case (e.g. resnet, mobilenet): layers are applied in series
                x = layer(x)
            elif layer_name == 'transition1':
                # first transition: fan the single tensor out into one stream per branch
                x = [branch(x) for branch in layer]
            else:
                # later transitions: derive one extra stream from the last one
                x.append(layer(x[-1]))

            if layer_name not in self.return_layers:
                continue

            if self.hrnet_flag and layer_name == 'stage4':
                # HRNetV2: upsample streams 1-3 to the size of stream 0 and
                # concatenate all four along the channel dimension
                target_size = (x[0].size(2), x[0].size(3))
                upsampled = [
                    F.interpolate(x[idx], size=target_size, mode='bilinear', align_corners=False)
                    for idx in (1, 2, 3)
                ]
                x = torch.cat([x[0], *upsampled], dim=1)
            collected[self.return_layers[layer_name]] = x
        return collected
from torch.utils.data import dataset
from tqdm import tqdm
import network
import utils
import os
import random
import argparse
import numpy as np
from torch.utils import data
from datasets import VOCSegmentation, Cityscapes, cityscapes
from torchvision import transforms as T
from metrics import StreamSegMetrics
import torch
import torch.nn as nn
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
from glob import glob
def get_argparser():
    """Build the command-line parser for the prediction script.

    Returns:
        argparse.ArgumentParser: parser with the dataset, model and inference
        options registered; nothing is parsed here.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Dataset options
    add("--input", type=str, required=True,
        help="path to a single image or image directory")
    add("--dataset", type=str, default='voc',
        choices=['voc', 'cityscapes'], help='Name of training set')

    # Deeplab options: every public, lower-case callable exposed by
    # network.modeling is offered as a model choice.
    registry = network.modeling.__dict__
    available_models = sorted(
        key for key, value in registry.items()
        if key.islower() and not key.startswith('_') and callable(value)
    )
    add("--model", type=str, default='deeplabv3plus_mobilenet',
        choices=available_models, help='model name')
    add("--separable_conv", action='store_true', default=False,
        help="apply separable conv to decoder and aspp")
    add("--output_stride", type=int, default=16, choices=[8, 16])

    # Inference options
    add("--save_val_results_to", default=None,
        help="save segmentation results to the specified dir")
    add("--crop_val", action='store_true', default=False,
        help='crop validation (default: False)')
    add("--val_batch_size", type=int, default=4,
        help='batch size for validation (default: 4)')
    add("--crop_size", type=int, default=513)
    add("--ckpt", default=None, type=str,
        help="resume from checkpoint")
    add("--gpu_id", type=str, default='0',
        help="GPU ID")
    return parser
def main():
    """Run segmentation inference on one image or a directory of images.

    Parses the command line, builds the selected model, optionally loads a
    checkpoint, then predicts a mask for every input image. When
    ``--save_val_results_to`` is given, each colorized mask is written there
    as ``<input name without extension>.png``.
    """
    opts = get_argparser().parse_args()

    # The parser restricts --dataset to these two values.
    dataset_name = opts.dataset.lower()
    if dataset_name == 'voc':
        opts.num_classes = 21
        decode_fn = VOCSegmentation.decode_target
    elif dataset_name == 'cityscapes':
        opts.num_classes = 19
        decode_fn = Cityscapes.decode_target

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Collect the input images: recurse into a directory, or take a single file.
    image_files = []
    if os.path.isdir(opts.input):
        for ext in ['png', 'jpeg', 'jpg', 'JPEG']:
            image_files.extend(
                glob(os.path.join(opts.input, '**/*.%s' % (ext)), recursive=True))
    elif os.path.isfile(opts.input):
        image_files.append(opts.input)

    # Set up model (all models are constructed in network.modeling)
    model = network.modeling.__dict__[opts.model](
        num_classes=opts.num_classes, output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        # NOTE: torch.load unpickles the file; only load checkpoints from a
        # trusted source.
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        print("Resume model from %s" % opts.ckpt)
        del checkpoint
    else:
        # No usable checkpoint: continue with the weights the constructor produced.
        print("[!] Retrain")
    model = nn.DataParallel(model)
    model.to(device)

    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    if opts.crop_val:
        transform = T.Compose([
            T.Resize(opts.crop_size),
            T.CenterCrop(opts.crop_size),
            T.ToTensor(),
            normalize,
        ])
    else:
        transform = T.Compose([
            T.ToTensor(),
            normalize,
        ])

    if opts.save_val_results_to is not None:
        os.makedirs(opts.save_val_results_to, exist_ok=True)

    with torch.no_grad():
        model = model.eval()
        for img_path in tqdm(image_files):
            # splitext keeps the whole name when the file has no extension;
            # the previous split('.')-based slicing produced an empty name.
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            img = Image.open(img_path).convert('RGB')
            img = transform(img).unsqueeze(0)  # To tensor of NCHW
            img = img.to(device)

            pred = model(img).max(1)[1].cpu().numpy()[0]  # HW
            colorized_preds = decode_fn(pred).astype('uint8')
            colorized_preds = Image.fromarray(colorized_preds)
            if opts.save_val_results_to:
                colorized_preds.save(
                    os.path.join(opts.save_val_results_to, img_name + '.png'))
# Run inference only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment