Commit 97243508 authored by sunxx1's avatar sunxx1
Browse files

添加DBnet代码

parents
# -*- coding: utf-8 -*-
# @Time : 2019/12/4 14:39
# @Author : zhoujun
import torch
import torch.nn as nn
class BalanceCrossEntropyLoss(nn.Module):
'''
Balanced cross entropy loss.
Shape:
- Input: :math:`(N, 1, H, W)`
- GT: :math:`(N, 1, H, W)`, same shape as the input
- Mask: :math:`(N, H, W)`, same spatial shape as the input
- Output: scalar.
Examples::
>>> m = nn.Sigmoid()
>>> loss = nn.BCELoss()
>>> input = torch.randn(3, requires_grad=True)
>>> target = torch.empty(3).random_(2)
>>> output = loss(m(input), target)
>>> output.backward()
'''
def __init__(self, negative_ratio=3.0, eps=1e-6):
super(BalanceCrossEntropyLoss, self).__init__()
self.negative_ratio = negative_ratio
self.eps = eps
def forward(self,
pred: torch.Tensor,
gt: torch.Tensor,
mask: torch.Tensor,
return_origin=False):
'''
Args:
pred: shape :math:`(N, 1, H, W)`, the prediction of network
gt: shape :math:`(N, 1, H, W)`, the target
mask: shape :math:`(N, H, W)`, the mask indicates positive regions
'''
positive = (gt * mask).byte()
negative = ((1 - gt) * mask).byte()
positive_count = int(positive.float().sum())
negative_count = min(int(negative.float().sum()), int(positive_count * self.negative_ratio))
loss = nn.functional.binary_cross_entropy(pred, gt, reduction='none')
positive_loss = loss * positive.float()
negative_loss = loss * negative.float()
# negative_loss, _ = torch.topk(negative_loss.view(-1).contiguous(), negative_count)
negative_loss, _ = negative_loss.view(-1).topk(negative_count)
balance_loss = (positive_loss.sum() + negative_loss.sum()) / (positive_count + negative_count + self.eps)
if return_origin:
return balance_loss, loss
return balance_loss
class DiceLoss(nn.Module):
'''
Loss function from https://arxiv.org/abs/1707.03237,
where iou computation is introduced heatmap manner to measure the
diversity bwtween tow heatmaps.
'''
def __init__(self, eps=1e-6):
super(DiceLoss, self).__init__()
self.eps = eps
def forward(self, pred: torch.Tensor, gt, mask, weights=None):
'''
pred: one or two heatmaps of shape (N, 1, H, W),
the losses of tow heatmaps are added together.
gt: (N, 1, H, W)
mask: (N, H, W)
'''
return self._compute(pred, gt, mask, weights)
def _compute(self, pred, gt, mask, weights):
if pred.dim() == 4:
pred = pred[:, 0, :, :]
gt = gt[:, 0, :, :]
assert pred.shape == gt.shape
assert pred.shape == mask.shape
if weights is not None:
assert weights.shape == mask.shape
mask = weights * mask
intersection = (pred * gt * mask).sum()
union = (pred * mask).sum() + (gt * mask).sum() + self.eps
loss = 1 - 2.0 * intersection / union
assert loss <= 1
return loss
class MaskL1Loss(nn.Module):
def __init__(self, eps=1e-6):
super(MaskL1Loss, self).__init__()
self.eps = eps
def forward(self, pred: torch.Tensor, gt, mask):
loss = (torch.abs(pred - gt) * mask).sum() / (mask.sum() + self.eps)
return loss
# -*- coding: utf-8 -*-
# @Time : 2019/8/23 21:57
# @Author : zhoujun
from addict import Dict
from torch import nn
import torch.nn.functional as F
from models.backbone import build_backbone
from models.neck import build_neck
from models.head import build_head
class Model(nn.Module):
def __init__(self, model_config: dict):
"""
PANnet
:param model_config: 模型配置
"""
super().__init__()
model_config = Dict(model_config)
backbone_type = model_config.backbone.pop('type')
neck_type = model_config.neck.pop('type')
head_type = model_config.head.pop('type')
self.backbone = build_backbone(backbone_type, **model_config.backbone)
self.neck = build_neck(neck_type, in_channels=self.backbone.out_channels, **model_config.neck)
self.head = build_head(head_type, in_channels=self.neck.out_channels, **model_config.head)
self.name = f'{backbone_type}_{neck_type}_{head_type}'
def forward(self, x):
_, _, H, W = x.size()
backbone_out = self.backbone(x)
neck_out = self.neck(backbone_out)
y = self.head(neck_out)
y = F.interpolate(y, size=(H, W), mode='bilinear', align_corners=True)
return y
if __name__ == '__main__':
import torch
device = torch.device('cpu')
x = torch.zeros(2, 3, 640, 640).to(device)
model_config = {
'backbone': {'type': 'resnest50', 'pretrained': True, "in_channels": 3},
'neck': {'type': 'FPN', 'inner_channels': 256}, # 分割头,FPN or FPEM_FFM
'head': {'type': 'DBHead', 'out_channels': 2, 'k': 50},
}
model = Model(model_config=model_config).to(device)
import time
tic = time.time()
y = model(x)
print(time.time() - tic)
print(y.shape)
print(model.name)
print(model)
# torch.save(model.state_dict(), 'PAN.pth')
# -*- coding: utf-8 -*-
# @Time : 2019/9/13 10:29
# @Author : zhoujun
import torch
import torch.nn.functional as F
from torch import nn
from models.basic import ConvBnRelu
class FPEM_FFM(nn.Module):
def __init__(self, in_channels, inner_channels=128, fpem_repeat=2, **kwargs):
"""
PANnet
:param in_channels: 基础网络输出的维度
"""
super().__init__()
self.conv_out = inner_channels
inplace = True
# reduce layers
self.reduce_conv_c2 = ConvBnRelu(in_channels[0], inner_channels, kernel_size=1, inplace=inplace)
self.reduce_conv_c3 = ConvBnRelu(in_channels[1], inner_channels, kernel_size=1, inplace=inplace)
self.reduce_conv_c4 = ConvBnRelu(in_channels[2], inner_channels, kernel_size=1, inplace=inplace)
self.reduce_conv_c5 = ConvBnRelu(in_channels[3], inner_channels, kernel_size=1, inplace=inplace)
self.fpems = nn.ModuleList()
for i in range(fpem_repeat):
self.fpems.append(FPEM(self.conv_out))
self.out_channels = self.conv_out * 4
def forward(self, x):
c2, c3, c4, c5 = x
# reduce channel
c2 = self.reduce_conv_c2(c2)
c3 = self.reduce_conv_c3(c3)
c4 = self.reduce_conv_c4(c4)
c5 = self.reduce_conv_c5(c5)
# FPEM
for i, fpem in enumerate(self.fpems):
c2, c3, c4, c5 = fpem(c2, c3, c4, c5)
if i == 0:
c2_ffm = c2
c3_ffm = c3
c4_ffm = c4
c5_ffm = c5
else:
c2_ffm += c2
c3_ffm += c3
c4_ffm += c4
c5_ffm += c5
# FFM
c5 = F.interpolate(c5_ffm, c2_ffm.size()[-2:])
c4 = F.interpolate(c4_ffm, c2_ffm.size()[-2:])
c3 = F.interpolate(c3_ffm, c2_ffm.size()[-2:])
Fy = torch.cat([c2_ffm, c3, c4, c5], dim=1)
return Fy
class FPEM(nn.Module):
def __init__(self, in_channels=128):
super().__init__()
self.up_add1 = SeparableConv2d(in_channels, in_channels, 1)
self.up_add2 = SeparableConv2d(in_channels, in_channels, 1)
self.up_add3 = SeparableConv2d(in_channels, in_channels, 1)
self.down_add1 = SeparableConv2d(in_channels, in_channels, 2)
self.down_add2 = SeparableConv2d(in_channels, in_channels, 2)
self.down_add3 = SeparableConv2d(in_channels, in_channels, 2)
def forward(self, c2, c3, c4, c5):
# up阶段
c4 = self.up_add1(self._upsample_add(c5, c4))
c3 = self.up_add2(self._upsample_add(c4, c3))
c2 = self.up_add3(self._upsample_add(c3, c2))
# down 阶段
c3 = self.down_add1(self._upsample_add(c3, c2))
c4 = self.down_add2(self._upsample_add(c4, c3))
c5 = self.down_add3(self._upsample_add(c5, c4))
return c2, c3, c4, c5
def _upsample_add(self, x, y):
return F.interpolate(x, size=y.size()[2:]) + y
class SeparableConv2d(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super(SeparableConv2d, self).__init__()
self.depthwise_conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1,
stride=stride, groups=in_channels)
self.pointwise_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()
def forward(self, x):
x = self.depthwise_conv(x)
x = self.pointwise_conv(x)
x = self.bn(x)
x = self.relu(x)
return x
# -*- coding: utf-8 -*-
# @Time : 2019/9/13 10:29
# @Author : zhoujun
import torch
import torch.nn.functional as F
from torch import nn
from models.basic import ConvBnRelu
class FPN(nn.Module):
def __init__(self, in_channels, inner_channels=256, **kwargs):
"""
:param in_channels: 基础网络输出的维度
:param kwargs:
"""
super().__init__()
inplace = True
self.conv_out = inner_channels
inner_channels = inner_channels // 4
# reduce layers
self.reduce_conv_c2 = ConvBnRelu(in_channels[0], inner_channels, kernel_size=1, inplace=inplace)
self.reduce_conv_c3 = ConvBnRelu(in_channels[1], inner_channels, kernel_size=1, inplace=inplace)
self.reduce_conv_c4 = ConvBnRelu(in_channels[2], inner_channels, kernel_size=1, inplace=inplace)
self.reduce_conv_c5 = ConvBnRelu(in_channels[3], inner_channels, kernel_size=1, inplace=inplace)
# Smooth layers
self.smooth_p4 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace)
self.smooth_p3 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace)
self.smooth_p2 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace)
self.conv = nn.Sequential(
nn.Conv2d(self.conv_out, self.conv_out, kernel_size=3, padding=1, stride=1),
nn.BatchNorm2d(self.conv_out),
nn.ReLU(inplace=inplace)
)
self.out_channels = self.conv_out
def forward(self, x):
c2, c3, c4, c5 = x
# Top-down
p5 = self.reduce_conv_c5(c5)
p4 = self._upsample_add(p5, self.reduce_conv_c4(c4))
p4 = self.smooth_p4(p4)
p3 = self._upsample_add(p4, self.reduce_conv_c3(c3))
p3 = self.smooth_p3(p3)
p2 = self._upsample_add(p3, self.reduce_conv_c2(c2))
p2 = self.smooth_p2(p2)
x = self._upsample_cat(p2, p3, p4, p5)
x = self.conv(x)
return x
def _upsample_add(self, x, y):
return F.interpolate(x, size=y.size()[2:]) + y
def _upsample_cat(self, p2, p3, p4, p5):
h, w = p2.size()[2:]
p3 = F.interpolate(p3, size=(h, w))
p4 = F.interpolate(p4, size=(h, w))
p5 = F.interpolate(p5, size=(h, w))
return torch.cat([p2, p3, p4, p5], dim=1)
# -*- coding: utf-8 -*-
# @Time : 2020/6/5 11:34
# @Author : zhoujun
from .FPN import FPN
from .FPEM_FFM import FPEM_FFM
__all__ = ['build_neck']
support_neck = ['FPN', 'FPEM_FFM']
def build_neck(neck_name, **kwargs):
assert neck_name in support_neck, f'all support neck is {support_neck}'
neck = eval(neck_name)(**kwargs)
return neck
# -*- coding: utf-8 -*-
# @Time : 2019/12/5 15:17
# @Author : zhoujun
from .seg_detector_representer import SegDetectorRepresenter
def get_post_processing(config):
try:
cls = eval(config['type'])(**config['args'])
return cls
except:
return None
\ No newline at end of file
import cv2
import numpy as np
import pyclipper
from shapely.geometry import Polygon
class SegDetectorRepresenter():
def __init__(self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=1.5):
self.min_size = 3
self.thresh = thresh
self.box_thresh = box_thresh
self.max_candidates = max_candidates
self.unclip_ratio = unclip_ratio
def __call__(self, batch, pred, is_output_polygon=False):
'''
batch: (image, polygons, ignore_tags
batch: a dict produced by dataloaders.
image: tensor of shape (N, C, H, W).
polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions.
ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not.
shape: the original shape of images.
filename: the original filenames of images.
pred:
binary: text region segmentation map, with shape (N, H, W)
thresh: [if exists] thresh hold prediction with shape (N, H, W)
thresh_binary: [if exists] binarized with threshhold, (N, H, W)
'''
pred = pred[:, 0, :, :]
segmentation = self.binarize(pred)
boxes_batch = []
scores_batch = []
for batch_index in range(pred.size(0)):
height, width = batch['shape'][batch_index]
if is_output_polygon:
boxes, scores = self.polygons_from_bitmap(pred[batch_index], segmentation[batch_index], width, height)
else:
boxes, scores = self.boxes_from_bitmap(pred[batch_index], segmentation[batch_index], width, height)
boxes_batch.append(boxes)
scores_batch.append(scores)
return boxes_batch, scores_batch
def binarize(self, pred):
return pred > self.thresh
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (H, W),
whose values are binarized as {0, 1}
'''
assert len(_bitmap.shape) == 2
bitmap = _bitmap.cpu().numpy() # The first channel
pred = pred.cpu().detach().numpy()
height, width = bitmap.shape
boxes = []
scores = []
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours[:self.max_candidates]:
epsilon = 0.005 * cv2.arcLength(contour, True)
approx = cv2.approxPolyDP(contour, epsilon, True)
points = approx.reshape((-1, 2))
if points.shape[0] < 4:
continue
# _, sside = self.get_mini_boxes(contour)
# if sside < self.min_size:
# continue
score = self.box_score_fast(pred, contour.squeeze(1))
if self.box_thresh > score:
continue
if points.shape[0] > 2:
box = self.unclip(points, unclip_ratio=self.unclip_ratio)
if len(box) > 1:
continue
else:
continue
box = box.reshape(-1, 2)
_, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
if sside < self.min_size + 2:
continue
if not isinstance(dest_width, int):
dest_width = dest_width.item()
dest_height = dest_height.item()
box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box)
scores.append(score)
return boxes, scores
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (H, W),
whose values are binarized as {0, 1}
'''
assert len(_bitmap.shape) == 2
bitmap = _bitmap.cpu().numpy() # The first channel
pred = pred.cpu().detach().numpy()
height, width = bitmap.shape
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
num_contours = min(len(contours), self.max_candidates)
boxes = np.zeros((num_contours, 4, 2), dtype=np.int16)
scores = np.zeros((num_contours,), dtype=np.float32)
for index in range(num_contours):
contour = contours[index].squeeze(1)
points, sside = self.get_mini_boxes(contour)
if sside < self.min_size:
continue
points = np.array(points)
score = self.box_score_fast(pred, contour)
if self.box_thresh > score:
continue
box = self.unclip(points, unclip_ratio=self.unclip_ratio).reshape(-1, 1, 2)
box, sside = self.get_mini_boxes(box)
if sside < self.min_size + 2:
continue
box = np.array(box)
if not isinstance(dest_width, int):
dest_width = dest_width.item()
dest_height = dest_height.item()
box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes[index, :, :] = box.astype(np.int16)
scores[index] = score
return boxes, scores
def unclip(self, box, unclip_ratio=1.5):
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset()
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance))
return expanded
def get_mini_boxes(self, contour):
bounding_box = cv2.minAreaRect(contour)
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
if points[1][1] > points[0][1]:
index_1 = 0
index_4 = 1
else:
index_1 = 1
index_4 = 0
if points[3][1] > points[2][1]:
index_2 = 2
index_3 = 3
else:
index_2 = 3
index_3 = 2
box = [points[index_1], points[index_2], points[index_3], points[index_4]]
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment