Unverified Commit 6a31be8f authored by YeShenglong1, committed via GitHub

Add files via upload

parent 4fb17721
import copy
import torch
import torch.nn as nn
from mmcv.cnn import Linear, bias_init_with_prob, build_activation_layer
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmcv.runner import force_fp32
from mmdet.models import HEADS, build_head, build_loss
from mmdet.models.utils import build_transformer
from mmdet.models.utils.transformer import inverse_sigmoid
from .base_map_head import BaseMapHead
import numpy as np
from ..augmentation.sythesis_det import NoiseSythesis
@HEADS.register_module(force=True)
class DGHead(BaseMapHead):
def __init__(self,
det_net_cfg=dict(),
gen_net_cfg=dict(),
loss_vert=dict(),
loss_face=dict(),
max_num_vertices=90,
top_p_gen_model=0.9,
sync_cls_avg_factor=True,
augmentation=False,
augmentation_kwargs=None,
joint_training=False,
**kwargs):
super().__init__()
# Heads
self.det_net = build_head(det_net_cfg)
self.gen_net = build_head(gen_net_cfg)
self.coord_dim = self.gen_net.coord_dim
# Loss params
self.bg_cls_weight = 1.0
self.sync_cls_avg_factor = sync_cls_avg_factor
self.max_num_vertices = max_num_vertices
self.top_p_gen_model = top_p_gen_model
self.fp16_enabled = False
self.augmentation = None
if augmentation:
augmentation_kwargs.update({'canvas_size':gen_net_cfg.canvas_size})
self.augmentation = NoiseSythesis(**augmentation_kwargs)
self.joint_training = joint_training
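    # Overall flow (as implemented below): det_net first predicts map-element
    # keypoints / boxes from the BEV context; gen_net then autoregressively
    # decodes polyline token sequences conditioned on those boxes (teacher
    # forcing at training time, top-p sampling at inference time).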
def forward(self, batch, img_metas=None, **kwargs):
        '''
        Args:
            batch (dict): input batch; its contents depend on training / inference mode.
            img_metas (list[dict], optional): image meta information.
        Returns:
            In training mode, the ``(outs, losses_dict)`` tuple from ``forward_train``;
            otherwise the prediction dict from ``inference``.
        '''
if self.training:
return self.forward_train(batch, **kwargs)
else:
return self.inference(batch, **kwargs)
def forward_train(self, batch: dict, context: dict, only_det=False, **kwargs):
        ''' Training forward pass; the generator is trained with teacher forcing. '''
bbox_dict = self.det_net(context=context)
outs = dict(
bbox=bbox_dict,
)
losses_dict, det_match_idxs, det_match_gt_idxs = \
self.loss_det(batch, outs)
if only_det: return outs, losses_dict
if self.augmentation is not None:
polylines, bbox_flat =\
self.augmentation(batch['gen'],simple_aug=True)
if bbox_flat is None:
bbox_flat = batch['gen']['bbox_flat']
gen_input = dict(
lines_bs_idx=batch['gen']['lines_bs_idx'],
lines_cls=batch['gen']['lines_cls'],
bbox_flat=bbox_flat,
polylines=polylines,
polyline_masks=batch['gen']['polyline_masks']
)
else:
gen_input = batch['gen']
if self.joint_training:
            # boxes / keypoints for the downstream polyline generator
if 'lines' in bbox_dict[-1]:
# for fix anchor
pred_bbox = bbox_dict[-1]['lines'].detach()
elif 'bboxs' in bbox_dict[-1]:
# for rpv
pred_bbox = bbox_dict[-1]['bboxs'].detach()
else:
raise NotImplementedError
            # reorder matched predictions back to the original GT order.
det_match_idx = det_match_idxs[-1]
det_match_gt_idx = det_match_gt_idxs[-1]
_bboxs = []
for i, (match_idx, bbox) in enumerate(zip(det_match_idx,pred_bbox)):
_bboxs.append(bbox[match_idx])
_bboxs[-1] = _bboxs[-1][torch.argsort(det_match_gt_idx[i])]
_bboxs = torch.cat(_bboxs, dim=0)
# quantize the data
_bboxs = \
torch.round(_bboxs).type(torch.int32)
# gen_input['bbox_flat'] = _bboxs
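            # Joint training: feed the generator the detector's matched boxes
            # instead of the GT boxes, then re-append a random ~20% subset of
            # the original GT entries so clean GT conditions are still seen.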
remain_idx = torch.randperm(_bboxs.shape[0])[:int(_bboxs.shape[0]*0.2)]
            # for data efficiency
for k in gen_input.keys():
if k == 'bbox_flat':
gen_input[k] = torch.cat((_bboxs,gen_input[k][remain_idx]),dim=0)
else:
gen_input[k] = torch.cat((gen_input[k],gen_input[k][remain_idx]),dim=0)
if isinstance(context['bev_embeddings'],tuple):
context['bev_embeddings'] = context['bev_embeddings'][0]
poly_dict = self.gen_net(gen_input, context=context)
outs.update(dict(
polylines=poly_dict,
))
if self.joint_training:
for k in batch['gen'].keys():
batch['gen'][k] = \
torch.cat((batch['gen'][k],batch['gen'][k][remain_idx]),dim=0)
gen_losses_dict = \
self.loss_gen(batch, outs)
losses_dict.update(gen_losses_dict)
return outs, losses_dict
def loss_det(self, gt: dict, pred: dict):
loss_dict = {}
# det
det_loss_dict, det_match_idx, det_match_gt_idx = \
self.det_net.loss(gt['det'], pred['bbox'])
for k, v in det_loss_dict.items():
loss_dict['det_'+k] = v
return loss_dict, det_match_idx, det_match_gt_idx
def loss_gen(self, gt: dict, pred: dict):
loss_dict = {}
# gen
gen_loss_dict = self.gen_net.loss(gt['gen'], pred['polylines'])
for k, v in gen_loss_dict.items():
loss_dict['gen_'+k] = v
return loss_dict
def loss(self, gt: dict, pred: dict):
pass
@torch.no_grad()
def inference(self, batch: dict={}, context: dict={}, gt_condition=False, **kwargs):
        '''
        Args:
            batch (dict): unused at inference time; detections provide the conditions.
            context (dict): BEV features and related meta information.
            gt_condition (bool): if True, condition the generator on GT boxes.
        '''
outs = {}
bbox_dict = self.det_net(context=context)
bbox_dict = self.det_net.post_process(bbox_dict)
outs.update(bbox_dict)
if len(outs['lines_bs_idx']) == 0:
return None
if isinstance(context['bev_embeddings'],tuple):
context['bev_embeddings'] = context['bev_embeddings'][0]
poly_dict = self.gen_net(outs,
context=context,
# max_sample_length=self.max_num_vertices,
max_sample_length=64,
top_p=self.top_p_gen_model,
gt_condition=gt_condition)
outs.update(poly_dict)
return outs
def post_process(self, preds: dict, tokens, gts:dict=None, **kwargs):
        '''
        Args:
            preds (dict): raw predictions from ``inference``.
            tokens (list[str]): sample tokens for the batch.
            gts (dict, optional): ground truth used to pack reference results.
        Outs:
            ret_list (list[dict]): per-sample results with keys 'token',
                'scores', 'labels', 'nline' and 'vectors'.
        '''
range_size = self.gen_net.canvas_size.cpu().numpy()
coord_dim = self.gen_net.coord_dim
gen_net_name = self.gen_net.name if hasattr(self.gen_net,'name') else 'gen'
ret_list = []
for batch_idx in range(len(tokens)):
ret_dict_single = {}
# bbox
det_gt = None
if gts is not None:
det_gt, rec_groundtruth = pack_groundtruth(
batch_idx,gts,tokens,range_size,gen_net_name,coord_dim=coord_dim)
bbox_res = {
# 'bboxes': preds['bbox'][batch_idx].detach().cpu().numpy(),
# 'det_gt': det_gt,
'token': tokens[batch_idx],
'scores': preds['scores'][batch_idx].detach().cpu().numpy(),
'labels': preds['labels'][batch_idx].detach().cpu().numpy(),
}
ret_dict_single.update(bbox_res)
# for gen results.
batch2seq = np.nonzero(
preds['lines_bs_idx'].cpu().numpy() == batch_idx)[0]
ret_dict_single.update({
'nline': len(batch2seq),
'vectors': []
})
for i in batch2seq:
pre = preds['polylines'][i].detach().cpu().numpy()
pre_msk = preds['polyline_masks'][i].detach().cpu().numpy()
valid_idx = np.nonzero(pre_msk)[0][:-1]
            # shift tokens from [1, canvas] down to [0, canvas-1], then normalize to [0, 1]
line = (pre[valid_idx].reshape(-1, coord_dim) - 1) / (range_size-1)
ret_dict_single['vectors'].append(line)
# if gts is not None:
# ret_dict_single['groundTruth'] = rec_groundtruth
ret_list.append(ret_dict_single)
return ret_list
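# Dequantization used above (and in pack_groundtruth): coordinate tokens live
# in [1, canvas_size] (0 is the stop token), so (token - 1) / (range_size - 1)
# maps them to normalized coordinates in [0, 1]. For example, with a per-axis
# range_size of 200: token 1 -> 0.0, token 200 -> 1.0, token 100 -> 99/199.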
def pack_groundtruth(batch_idx,gts,tokens,range_size,gen_net_name='gen',coord_dim=2):
if 'keypoints' in gts['det']:
gt_bbox = \
gts['det']['keypoints'][batch_idx].detach().cpu().numpy()
else:
gt_bbox = \
gts['det']['bbox'][batch_idx].detach().cpu().numpy()
det_gt = {
'labels': gts['det']['class_label'][batch_idx].detach().cpu().numpy(),
'bboxes': gt_bbox,
}
batch2seq = np.nonzero(
gts['gen']['lines_bs_idx'].cpu().numpy() == batch_idx)[0]
ret_groundtruth = {
'token': tokens[batch_idx],
'nline': len(batch2seq),
'labels': gts['gen']['lines_cls'][batch2seq].detach().cpu().numpy(),
'lines': [],
}
for i in batch2seq:
gt_line =\
gts['gen']['polylines'].detach().cpu().numpy()[i]
gt_msk = gts['gen']['polyline_masks'].detach().cpu().numpy()[i]
if gen_net_name == 'gen_gmm':
valid_idx = np.nonzero(gt_msk)[0]
else:
valid_idx = np.nonzero(gt_msk)[0][:-1]
        # shift tokens from [1, canvas] down to [0, canvas-1], then normalize to [0, 1]
line = (gt_line[valid_idx].reshape(-1, coord_dim) - 1) / (range_size-1)
ret_groundtruth['lines'].append(line)
return det_gt, ret_groundtruth
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, Linear
from mmcv.runner import force_fp32
from torch.distributions.categorical import Categorical
from mmdet.core import (multi_apply, build_assigner, build_sampler,
reduce_mean)
from mmdet.models import HEADS
from .detr_bbox import DETRBboxHead
from mmdet.models.utils.transformer import inverse_sigmoid
from mmdet.models import build_loss
from mmcv.cnn import Linear, build_activation_layer, bias_init_with_prob
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmdet.models.utils import build_transformer
@HEADS.register_module(force=True)
class MapElementDetector(nn.Module):
def __init__(self,
canvas_size=(400, 200),
discrete_output=False,
separate_detect=False,
mode='xyxy',
bbox_size=None,
coord_dim=2,
kp_coord_dim=2,
num_classes=3,
in_channels=128,
num_query=100,
max_lines=50,
score_thre=0.2,
num_reg_fcs=2,
num_points=100,
iterative=False,
patch_size=None,
sync_cls_avg_factor=True,
transformer: dict = None,
positional_encoding: dict = None,
loss_cls: dict = None,
loss_reg: dict = None,
train_cfg: dict = None,):
super().__init__()
assigner = train_cfg['assigner']
self.assigner = build_assigner(assigner)
# DETR sampling=False, so use PseudoSampler
sampler_cfg = dict(type='PseudoSampler')
self.sampler = build_sampler(sampler_cfg, context=self)
self.train_cfg = train_cfg
self.max_lines = max_lines
self.score_thre = score_thre
self.num_query = num_query
self.in_channels = in_channels
self.num_classes = num_classes
self.num_points = num_points
# branch
# if loss_cls.use_sigmoid:
if loss_cls['use_sigmoid']:
self.cls_out_channels = num_classes
else:
self.cls_out_channels = num_classes+1
self.iterative = iterative
self.num_reg_fcs = num_reg_fcs
self._build_transformer(transformer, positional_encoding)
# loss params
self.loss_cls = build_loss(loss_cls)
self.bg_cls_weight = 0.1
if self.loss_cls.use_sigmoid:
self.bg_cls_weight = 0.0
self.sync_cls_avg_factor = sync_cls_avg_factor
self.reg_loss = build_loss(loss_reg)
self.separate_detect = separate_detect
self.discrete_output = discrete_output
self.bbox_size = 3 if mode=='sce' else 2
if bbox_size is not None:
self.bbox_size = bbox_size
self.coord_dim = coord_dim # for xyz
self.kp_coord_dim = kp_coord_dim
self.register_buffer('canvas_size', torch.tensor(canvas_size))
# add reg, cls head for each decoder layer
self._init_layers()
self._init_branch()
self.init_weights()
self._init_embedding()
def _init_layers(self):
"""Initialize some layer."""
self.input_proj = Conv2d(
self.in_channels, self.embed_dims, kernel_size=1)
# query_pos_embed & query_embed
self.query_embedding = nn.Embedding(self.num_query,
self.embed_dims)
def _init_embedding(self):
self.label_embed = nn.Embedding(
self.num_classes, self.embed_dims)
self.img_coord_embed = nn.Linear(2, self.embed_dims)
# query_pos_embed & query_embed
self.query_embedding = nn.Embedding(self.num_query,
self.embed_dims*2)
# for bbox parameter xstart, ystart, xend, yend
self.bbox_embedding = nn.Embedding( self.bbox_size,
self.embed_dims*2)
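    # Note: each of the num_query object queries is expanded into bbox_size
    # keypoint sub-queries in forward() by adding bbox_embedding, so the
    # transformer effectively processes num_query * bbox_size queries. The
    # query_embedding created here (embed_dims*2) overrides the one created
    # in _init_layers.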
def _init_branch(self,):
"""Initialize classification branch and regression branch of head."""
fc_cls = Linear(self.embed_dims*self.bbox_size, self.cls_out_channels)
# fc_cls = Linear(self.embed_dims, self.cls_out_channels)
reg_branch = []
for _ in range(self.num_reg_fcs):
reg_branch.append(Linear(self.embed_dims, self.embed_dims))
reg_branch.append(nn.LayerNorm(self.embed_dims))
reg_branch.append(nn.ReLU())
if self.discrete_output:
reg_branch.append(nn.Linear(
self.embed_dims, max(self.canvas_size), bias=True,))
else:
reg_branch.append(nn.Linear(
self.embed_dims, self.coord_dim, bias=True,))
reg_branch = nn.Sequential(*reg_branch)
        # no sigmoid here; it is applied later in get_prediction
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
num_pred = self.transformer.decoder.num_layers
if self.iterative:
fc_cls = _get_clones(fc_cls, num_pred)
reg_branch = _get_clones(reg_branch, num_pred)
else:
reg_branch = nn.ModuleList(
[reg_branch for _ in range(num_pred)])
fc_cls = nn.ModuleList(
[fc_cls for _ in range(num_pred)])
self.pre_branches = nn.ModuleDict([
('cls', fc_cls),
('reg', reg_branch), ])
def init_weights(self):
"""Initialize weights of the DeformDETR head."""
for p in self.input_proj.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
self.transformer.init_weights()
# init prediction branch
for k, v in self.pre_branches.items():
for param in v.parameters():
if param.dim() > 1:
nn.init.xavier_uniform_(param)
# focal loss init
if self.loss_cls.use_sigmoid:
bias_init = bias_init_with_prob(0.01)
# for last layer
if isinstance(self.pre_branches['cls'], nn.ModuleList):
for m in self.pre_branches['cls']:
nn.init.constant_(m.bias, bias_init)
else:
m = self.pre_branches['cls']
nn.init.constant_(m.bias, bias_init)
def _build_transformer(self, transformer, positional_encoding):
# transformer
self.act_cfg = transformer.get('act_cfg',
dict(type='ReLU', inplace=True))
self.activate = build_activation_layer(self.act_cfg)
self.positional_encoding = build_positional_encoding(
positional_encoding)
self.transformer = build_transformer(transformer)
self.embed_dims = self.transformer.embed_dims
def _prepare_context(self, context):
"""Prepare class label and vertex context."""
global_context_embedding = None
image_embeddings = context['bev_embeddings']
        image_embeddings = self.input_proj(
            image_embeddings)  # 1x1 conv: only changes the channel dimension
# Pass images through encoder
device = image_embeddings.device
# Add 2D coordinate grid embedding
B, C, H, W = image_embeddings.shape
Ws = torch.linspace(-1., 1., W)
Hs = torch.linspace(-1., 1., H)
image_coords = torch.stack(
torch.meshgrid(Hs, Ws), dim=-1).to(device)
image_coord_embeddings = self.img_coord_embed(image_coords)
image_embeddings += image_coord_embeddings[None].permute(0, 3, 1, 2)
        # Keep the (B, C, H, W) layout here; the reshape below is effectively a no-op
sequential_context_embeddings = image_embeddings.reshape(
B, C, H, W)
return (global_context_embedding, sequential_context_embeddings)
def forward(self, context, img_metas=None, multi_scale=False):
        '''
        Args:
            context (dict): holds 'bev_embeddings', a BEV feature map of
                shape [B, C, H, W].
            img_metas: image meta information (unused here).
        Outs:
            outputs (list[dict]): one dict per decoder layer with
                lines (Tensor): [bs, num_query, bbox_size * coord_dim].
                scores (Tensor): [bs, num_query, cls_out_channels].
                embeddings (Tensor): [bs, num_query, bbox_size, embed_dims].
        '''
(global_context_embedding, sequential_context_embeddings) =\
self._prepare_context(context)
x = sequential_context_embeddings
B, C, H, W = x.shape
query_embedding = self.query_embedding.weight[None,:,None].repeat(B, 1, self.bbox_size, 1)
bbox_embed = self.bbox_embedding.weight
query_embedding = query_embedding + bbox_embed[None,None]
query_embedding = query_embedding.view(B, -1, C*2)
img_masks = x.new_zeros((B, H, W))
pos_embed = self.positional_encoding(img_masks)
# outs_dec: [nb_dec, bs, num_query, embed_dim]
hs, init_reference, inter_references = self.transformer(
[x,],
[img_masks.type(torch.bool)],
query_embedding,
[pos_embed],
            reg_branches=self.pre_branches['reg'] if self.iterative else None,  # noqa:E501
cls_branches= None, # noqa:E501
)
outs_dec = hs.permute(0, 2, 1, 3)
outputs = []
        for i, query_feat in enumerate(outs_dec):
if i == 0:
reference = init_reference
else:
reference = inter_references[i - 1]
outputs.append(self.get_prediction(i,query_feat,reference))
return outputs
def get_prediction(self, level, query_feat, reference):
bs, num_query, h = query_feat.shape
query_feat = query_feat.view(bs, -1, self.bbox_size,h)
ocls = self.pre_branches['cls'][level](query_feat.flatten(-2))
# ocls = ocls.mean(-2)
reference = inverse_sigmoid(reference)
reference = reference.view(bs, -1, self.bbox_size,self.coord_dim)
tmp = self.pre_branches['reg'][level](query_feat)
tmp[...,:self.kp_coord_dim] = tmp[...,:self.kp_coord_dim] + reference[...,:self.kp_coord_dim]
lines = tmp.sigmoid() # bs, num_query, self.bbox_size,2
lines = lines * self.canvas_size[:self.coord_dim]
lines = lines.flatten(-2)
return dict(
lines=lines, # [bs, num_query, bboxsize*2]
scores=ocls, # [bs, num_query, num_class]
embeddings= query_feat, # [bs, num_query, bbox_size, h]
)
@force_fp32(apply_to=('score_pred', 'lines_pred', 'gt_lines'))
def _get_target_single(self,
score_pred,
lines_pred,
gt_labels,
gt_lines,
gt_bboxes_ignore=None):
"""
Compute regression and classification targets for one image.
Outputs from a single decoder layer of a single feature level are used.
Args:
            score_pred (Tensor): Box score logits from a single decoder layer
                for one image. Shape [num_query, cls_out_channels].
lines_pred (Tensor):
shape [num_query, num_points, 2].
gt_lines (Tensor):
shape [num_gt, num_points, 2].
gt_labels (torch.LongTensor)
shape [num_gt, ]
Returns:
tuple[Tensor]: a tuple containing the following for one image.
- labels (LongTensor): Labels of each image.
shape [num_query, 1]
- label_weights (Tensor]): Label weights of each image.
shape [num_query, 1]
- lines_target (Tensor): Lines targets of each image.
shape [num_query, num_points, 2]
- lines_weights (Tensor): Lines weights of each image.
shape [num_query, num_points, 2]
- pos_inds (Tensor): Sampled positive indices for each image.
- neg_inds (Tensor): Sampled negative indices for each image.
"""
num_pred_lines = len(lines_pred)
# assigner and sampler
assign_result = self.assigner.assign(preds=dict(lines=lines_pred, scores=score_pred,),
gts=dict(lines=gt_lines,
labels=gt_labels, ),
gt_bboxes_ignore=gt_bboxes_ignore)
sampling_result = self.sampler.sample(
assign_result, lines_pred, gt_lines)
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
pos_gt_inds = sampling_result.pos_assigned_gt_inds
        # label targets: fill with the background label first (1 when separate_detect, num_classes otherwise)
if self.separate_detect:
labels = gt_lines.new_full((num_pred_lines, ), 1, dtype=torch.long)
else:
labels = gt_lines.new_full(
(num_pred_lines, ), self.num_classes, dtype=torch.long)
labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
label_weights = gt_lines.new_ones(num_pred_lines)
        # bbox targets: when discrete, lines_pred's last dimension is the vocabulary,
        # which the ground truth does not have.
if self.discrete_output:
lines_target = torch.zeros_like(lines_pred[..., 0]).long()
lines_weights = torch.zeros_like(lines_pred[..., 0])
else:
lines_target = torch.zeros_like(lines_pred)
lines_weights = torch.zeros_like(lines_pred)
lines_target[pos_inds] = sampling_result.pos_gt_bboxes.type(
lines_target.dtype)
lines_weights[pos_inds] = 1.0
n = lines_weights.sum(-1, keepdim=True)
lines_weights = lines_weights / n.masked_fill(n == 0, 1)
return (labels, label_weights, lines_target, lines_weights,
pos_inds, neg_inds, pos_gt_inds)
# @force_fp32(apply_to=('preds', 'gts'))
def get_targets(self, preds, gts, gt_bboxes_ignore_list=None):
"""
        Compute regression and classification targets for a batch of images.
Outputs from a single decoder layer of a single feature level are used.
Args:
cls_scores_list (list[Tensor]): Box score logits from a single
decoder layer for each image with shape [num_query,
cls_out_channels].
lines_preds_list (list[Tensor]): [num_query, num_points, 2].
gt_lines_list (list[Tensor]): Ground truth lines for each image
with shape (num_gts, num_points, 2)
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
gt_bboxes_ignore_list (list[Tensor], optional): Bounding
boxes which can be ignored for each image. Default None.
Returns:
tuple: a tuple containing the following targets.
- labels_list (list[Tensor]): Labels for all images.
- label_weights_list (list[Tensor]): Label weights for all \
images.
- lines_targets_list (list[Tensor]): Lines targets for all \
images.
- lines_weight_list (list[Tensor]): Lines weights for all \
images.
- num_total_pos (int): Number of positive samples in all \
images.
- num_total_neg (int): Number of negative samples in all \
images.
"""
assert gt_bboxes_ignore_list is None, \
'Only supports for gt_bboxes_ignore setting to None.'
# format the inputs
if self.separate_detect:
bbox = [b[m] for b, m in zip(gts['bbox'], gts['bbox_mask'])]
class_label = torch.zeros_like(gts['bbox_mask']).long()
class_label = [b[m] for b, m in zip(class_label, gts['bbox_mask'])]
else:
class_label = gts['class_label']
bbox = gts['bbox']
if self.discrete_output:
lines_pred = preds['lines'].logits
else:
lines_pred = preds['lines']
bbox = [b.float() for b in bbox]
(labels_list, label_weights_list,
lines_targets_list, lines_weights_list,
pos_inds_list, neg_inds_list,pos_gt_inds_list) = multi_apply(
self._get_target_single,
preds['scores'], lines_pred,
class_label, bbox,
gt_bboxes_ignore=gt_bboxes_ignore_list)
num_total_pos = sum((inds.numel() for inds in pos_inds_list))
num_total_neg = sum((inds.numel() for inds in neg_inds_list))
new_gts = dict(
labels=labels_list,
label_weights=label_weights_list,
bboxs=lines_targets_list,
bboxs_weights=lines_weights_list,
)
return new_gts, num_total_pos, num_total_neg, pos_inds_list, pos_gt_inds_list
# @force_fp32(apply_to=('preds', 'gts'))
def loss_single(self,
preds: dict,
gts: dict,
gt_bboxes_ignore_list=None,
reduction='none'):
"""
Loss function for outputs from a single decoder layer of a single
feature level.
Args:
cls_scores (Tensor): Box score logits from a single decoder layer
for all images. Shape [bs, num_query, cls_out_channels].
lines_preds (Tensor):
shape [bs, num_query, num_points, 2].
gt_lines_list (list[Tensor]):
with shape (num_gts, num_points, 2)
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
gt_bboxes_ignore_list (list[Tensor], optional): Bounding
boxes which can be ignored for each image. Default None.
Returns:
dict[str, Tensor]: A dictionary of loss components for outputs from
a single decoder layer.
"""
# Get target for each sample
new_gts, num_total_pos, num_total_neg, pos_inds_list, pos_gt_inds_list =\
self.get_targets(preds, gts, gt_bboxes_ignore_list)
        # Batch all targets
for k, v in new_gts.items():
new_gts[k] = torch.stack(v, dim=0)
# construct weighted avg_factor to match with the official DETR repo
cls_avg_factor = num_total_pos * 1.0 + \
num_total_neg * self.bg_cls_weight
if self.sync_cls_avg_factor:
cls_avg_factor = reduce_mean(
preds['scores'].new_tensor([cls_avg_factor]))
cls_avg_factor = max(cls_avg_factor, 1)
# Classification loss
if self.separate_detect:
loss_cls = self.bce_loss(
preds['scores'], new_gts['labels'], new_gts['label_weights'], cls_avg_factor)
else:
            # flatten predictions and targets so the class dimension is last for the classification loss
cls_scores = preds['scores'].reshape(-1, self.cls_out_channels)
cls_labels = new_gts['labels'].reshape(-1)
cls_weights = new_gts['label_weights'].reshape(-1)
loss_cls = self.loss_cls(
cls_scores, cls_labels, cls_weights, avg_factor=cls_avg_factor)
        # Compute the average number of gt boxes across all GPUs, for
        # normalization purposes
num_total_pos = loss_cls.new_tensor([num_total_pos])
num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()
# position NLL loss
if self.discrete_output:
loss_reg = -(preds['lines'].log_prob(new_gts['bboxs']) *
new_gts['bboxs_weights']).sum()/(num_total_pos)
else:
loss_reg = self.reg_loss(
preds['lines'], new_gts['bboxs'], new_gts['bboxs_weights'], avg_factor=num_total_pos)
loss_dict = dict(
cls=loss_cls,
reg=loss_reg,
)
return loss_dict, pos_inds_list, pos_gt_inds_list
@force_fp32(apply_to=('gt_lines_list', 'preds_dicts'))
def loss(self,
gts: dict,
preds_dicts: dict,
gt_bboxes_ignore=None,
reduction='mean'):
"""
Loss Function.
Args:
gt_lines_list (list[Tensor]): Ground truth lines for each image
with shape (num_gts, num_points, 2)
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
preds_dicts:
all_cls_scores (Tensor): Classification score of all
decoder layers, has shape
[nb_dec, bs, num_query, cls_out_channels].
all_lines_preds (Tensor):
[nb_dec, bs, num_query, num_points, 2].
gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
which can be ignored for each image. Default None.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
assert gt_bboxes_ignore is None, \
f'{self.__class__.__name__} only supports ' \
f'for gt_bboxes_ignore setting to None.'
        # There may be multiple decoder layers
losses, pos_inds_lists, pos_gt_inds_lists = multi_apply(
self.loss_single,
preds_dicts,
gts=gts,
gt_bboxes_ignore_list=gt_bboxes_ignore,
reduction=reduction)
# Format the losses
loss_dict = dict()
# loss from the last decoder layer
for k, v in losses[-1].items():
loss_dict[k] = v
# Loss from other decoder layers
num_dec_layer = 0
for loss in losses[:-1]:
for k, v in loss.items():
loss_dict[f'd{num_dec_layer}.{k}'] = v
num_dec_layer += 1
return loss_dict, pos_inds_lists, pos_gt_inds_lists
def post_process(self, preds_dicts: list, **kwargs):
        '''
        Args:
            preds_dicts (list[dict]): one dict per decoder layer with
                scores (Tensor): [bs, num_query, cls_out_channels].
                lines (Tensor): [bs, num_query, bbox_size * coord_dim].
        Outs:
            result_dict (dict): batched detections with per-sample lists
                'bbox', 'scores', 'labels' and flattened tensors
                'bbox_flat', 'lines_cls', 'lines_bs_idx' used to condition
                the polyline generator.
        '''
preds = preds_dicts[-1]
batched_cls_scores = preds['scores']
batched_lines_preds = preds['lines']
batch_size = batched_cls_scores.size(0)
device = batched_cls_scores.device
result_dict = {
'bbox': [],
'scores': [],
'labels': [],
'bbox_flat': [],
'lines_cls': [],
'lines_bs_idx': [],
}
for i in range(batch_size):
cls_scores = batched_cls_scores[i]
det_preds = batched_lines_preds[i]
max_num = self.max_lines
if self.loss_cls.use_sigmoid:
cls_scores = cls_scores.sigmoid()
scores, valid_idx = cls_scores.view(-1).topk(max_num)
det_labels = valid_idx % self.num_classes
valid_idx = valid_idx // self.num_classes
det_preds = det_preds[valid_idx]
else:
scores, det_labels = F.softmax(cls_scores, dim=-1)[..., :-1].max(-1)
scores, valid_idx = scores.topk(max_num)
det_preds = det_preds[valid_idx]
det_labels = det_labels[valid_idx]
nline = len(valid_idx)
result_dict['bbox'].append(det_preds)
result_dict['scores'].append(scores)
result_dict['labels'].append(det_labels)
result_dict['lines_bs_idx'].extend([i]*nline)
        # flatten boxes for the downstream polyline generator
_bboxs = torch.cat(result_dict['bbox'], dim=0)
# quantize the data
result_dict['bbox_flat'] = torch.round(_bboxs).type(torch.int32)
result_dict['lines_cls'] = torch.cat(
result_dict['labels'], dim=0).long()
result_dict['lines_bs_idx'] = torch.tensor(
result_dict['lines_bs_idx'], device=device).long()
return result_dict
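    # Decoding note for the sigmoid branch above: scores are flattened over
    # (num_query, num_classes) before topk, so for a flat index `idx`
    #   label = idx % num_classes   and   query = idx // num_classes.
    # E.g. with num_classes = 3, idx = 7 -> query 2, label 1.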
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from mmdet.models import HEADS
from .detgen_utils.causal_trans import (CausalTransformerDecoder,
CausalTransformerDecoderLayer)
from .detgen_utils.utils import (dequantize_verts, generate_square_subsequent_mask,
quantize_verts, top_k_logits, top_p_logits)
from mmcv.runner import force_fp32, auto_fp16
@HEADS.register_module(force=True)
class PolylineGenerator(nn.Module):
"""
Autoregressive generative model of n-gon meshes.
Operates on sets of input vertices as well as flattened face sequences with
new face and stopping tokens:
[f_0^0, f_0^1, f_0^2, NEW, f_1^0, f_1^1, ..., STOP]
Input vertices are encoded using a Transformer encoder.
Input face sequences are embedded and tagged with learned position indicators,
as well as their corresponding vertex embeddings. A transformer decoder
outputs a pointer which is compared to each vertex embedding to obtain a
distribution over vertex indices.
"""
def __init__(self,
in_channels,
encoder_config,
decoder_config,
class_conditional=True,
num_classes=55,
decoder_cross_attention=True,
use_discrete_vertex_embeddings=True,
condition_points_num=3,
coord_dim=2,
canvas_size=(400, 200),
max_seq_length=500,
name='gen_model'):
"""Initializes FaceModel.
Args:
encoder_config: Dictionary with TransformerEncoder config.
decoder_config: Dictionary with TransformerDecoder config.
class_conditional: If True, then condition on learned class embeddings.
num_classes: Number of classes to condition on.
            decoder_cross_attention: If True, use cross attention from decoder
                queries into encoder outputs.
use_discrete_vertex_embeddings: If True, use discrete vertex embeddings.
max_seq_length: Maximum face sequence length. Used for learned position
embeddings.
name: Name of variable scope
"""
super(PolylineGenerator, self).__init__()
self.embedding_dim = decoder_config['layer_config']['d_model']
self.class_conditional = class_conditional
self.num_classes = num_classes
self.max_seq_length = max_seq_length
self.decoder_cross_attention = decoder_cross_attention
self.use_discrete_vertex_embeddings = use_discrete_vertex_embeddings
self.condition_points_num = condition_points_num
self.fp16_enabled = False
self.coord_dim = coord_dim # if we use xyz else 2 when we use xy
        self.kp_coord_dim = coord_dim if coord_dim==2 else 2  # keypoint conditions are always 2-D (x, y)
self.register_buffer('canvas_size', torch.tensor(canvas_size))
# initialize the model
self._project_to_logits = nn.Linear(
self.embedding_dim,
max(canvas_size) + 1, # + 1 for stopping token. use_bias=True,
)
self.input_proj = nn.Conv2d(
in_channels, self.embedding_dim, kernel_size=1)
decoder_layer = CausalTransformerDecoderLayer(
**decoder_config.pop('layer_config'))
self.decoder = CausalTransformerDecoder(
decoder_layer, **decoder_config)
self._init_embedding()
self.init_weights()
def _init_embedding(self):
if self.class_conditional:
self.label_embed = nn.Embedding(
self.num_classes, self.embedding_dim)
self.coord_embed = nn.Embedding(self.coord_dim, self.embedding_dim)
self.pos_embeddings = nn.Embedding(
self.max_seq_length, self.embedding_dim)
        # indicates which keypoint of the condition box a token belongs to (e.g. start vs. end of the line)
self.bbox_context_embed = \
nn.Embedding(self.condition_points_num, self.embedding_dim)
self.img_coord_embed = nn.Linear(2, self.embedding_dim)
        # initialize the vertex embedding
if self.use_discrete_vertex_embeddings:
self.vertex_embed = nn.Embedding(
max(self.canvas_size) + 1, self.embedding_dim)
else:
self.vertex_embed = nn.Linear(1, self.embedding_dim)
def init_weights(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def _embed_kps(self, bbox):
bbox_len = bbox.shape[-1]
# Bbox_context
bbox_embedding = self.bbox_context_embed(
(torch.arange(bbox_len, device=bbox.device) / self.kp_coord_dim).floor().long())
# Coord indicators (x, y)
coord_embeddings = self.coord_embed(
torch.arange(bbox_len, device=bbox.device) % self.kp_coord_dim)
# Discrete vertex value embeddings
vert_embeddings = self.vertex_embed(bbox)
return vert_embeddings + (bbox_embedding+coord_embeddings)[None]
def _prepare_context(self, batch, context):
"""Prepare class label and vertex context."""
global_context_embedding = None
if self.class_conditional:
global_context_embedding = self.label_embed(batch['lines_cls'])
bbox_embeddings = self._embed_kps(batch['bbox_flat'])
if global_context_embedding is not None:
global_context_embedding = torch.cat(
[global_context_embedding[:, None], bbox_embeddings], dim=1)
# Pass images through encoder
image_embeddings = assign_bev(
context['bev_embeddings'], batch['lines_bs_idx'])
image_embeddings = self.input_proj(image_embeddings)
device = image_embeddings.device
# Add 2D coordinate grid embedding
H, W = image_embeddings.shape[2:]
Ws = torch.linspace(-1., 1., W)
Hs = torch.linspace(-1., 1., H)
image_coords = torch.stack(
torch.meshgrid(Hs, Ws), dim=-1).to(device)
image_coord_embeddings = self.img_coord_embed(image_coords)
image_embeddings += image_coord_embeddings[None].permute(0, 3, 1, 2)
# Reshape spatial grid to sequence
B = image_embeddings.shape[0]
sequential_context_embeddings = image_embeddings.reshape(
B, self.embedding_dim, -1).permute(0, 2, 1)
return (global_context_embedding,
sequential_context_embeddings)
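    # The returned global context is the condition prefix
    # [class_embedding, keypoint_token_0, ..., keypoint_token_{k-1}], and the
    # sequential context is the flattened BEV grid (B, H*W, C) used for
    # cross attention in the decoder.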
def _embed_inputs(self, seqs, condition_embedding=None):
"""Embeds face sequences and adds within and between face positions.
Args:
seq: B, seqlen=vlen*3,
condition_embedding: B, [c,xs,ys,xe,ye](5), h
Returns:
embeddings: B, seqlen, h
"""
B, seq_len = seqs.shape[:2]
# Position embeddings
pos_embeddings = self.pos_embeddings(
(torch.arange(seq_len, device=seqs.device) / self.coord_dim).floor().long()) # seq_len, h
# Coord indicators (x, y, z(optional))
coord_embeddings = self.coord_embed(
torch.arange(seq_len, device=seqs.device) % self.coord_dim)
# Discrete vertex value embeddings
vert_embeddings = self.vertex_embed(seqs)
# Aggregate embeddings
embeddings = vert_embeddings + \
(coord_embeddings+pos_embeddings)[None]
embeddings = torch.cat([condition_embedding, embeddings], dim=1)
return embeddings
def forward(self, batch: dict, **kwargs):
"""
Pass batch through face model and get log probabilities.
Args:
batch: Dictionary containing:
'vertices_dequantized': Tensor of shape [batch_size, num_vertices, 3].
'faces': int32 tensor of shape [batch_size, seq_length] with flattened
faces.
'vertices_mask': float32 tensor with shape
[batch_size, num_vertices] that masks padded elements in 'vertices'.
"""
if self.training:
return self.forward_train(batch, **kwargs)
else:
return self.inference(batch, **kwargs)
def sperate_forward(self, batch, context, **kwargs):
polyline_length = batch['polyline_masks'].sum(-1)
c1, c2, revert_idx, size = get_chunk_idx(polyline_length)
sizes = [size, polyline_length.max()]
polyline_logits = []
for c_idx, size in zip([c1,c2], sizes):
new_batch = assign_batch(batch,c_idx, size)
_poly_logits = self._forward_train(new_batch,context,**kwargs)
polyline_logits.append(_poly_logits)
        # may improve the speed by reducing padding
for i, (_poly_logits, size) in enumerate(zip(polyline_logits, sizes)):
if size < sizes[1]:
_poly_logits = F.pad(_poly_logits, (0,0,0,sizes[1]-size), "constant", 0)
polyline_logits[i] = _poly_logits
polyline_logits = torch.cat(polyline_logits,0)
polyline_logits = polyline_logits[revert_idx]
cat_dist = Categorical(logits=polyline_logits)
return {'polylines':cat_dist}
def forward_train(self, batch: dict, context: dict, **kwargs):
"""
Returns:
pred_dist: Categorical predictive distribution with batch shape
[batch_size, seq_length].
"""
        # Teacher forcing on GT polylines. The full-batch path below is kept
        # for reference but disabled; the length-chunked path is always used.
        if False:
polyline_logits = self._forward_train(batch, context, **kwargs)
cat_dist = Categorical(logits=polyline_logits)
return {'polylines':cat_dist}
else:
return self.sperate_forward(batch, context, **kwargs)
def _forward_train(self, batch: dict, context: dict, **kwargs):
"""
Returns:
pred_dist: Categorical predictive distribution with batch shape
[batch_size, seq_length].
"""
        # teacher forcing on the GT polyline tokens
global_context, seq_context = self._prepare_context(
batch, context)
logits = self.body(
# Last element not used for preds
batch['polylines'][:, :-1],
global_context_embedding=global_context,
sequential_context_embeddings=seq_context,
return_logits=True,
is_training=self.training)
return logits
@force_fp32(apply_to=('global_context_embedding','sequential_context_embeddings','cache'))
def body(self,
seqs,
global_context_embedding=None,
sequential_context_embeddings=None,
temperature=1.,
top_k=0,
top_p=1.,
cache=None,
return_logits=False,
is_training=True):
"""
Outputs categorical dist for vertex indices.
Body of the face model
"""
# Embed inputs
condition_len = global_context_embedding.shape[1]
decoder_inputs = self._embed_inputs(
seqs, global_context_embedding)
# Pass through Transformer decoder
        # the memory-efficient decoder only supports the sequence-first layout
decoder_inputs = decoder_inputs.transpose(0, 1)
if sequential_context_embeddings is not None:
sequential_context_embeddings = sequential_context_embeddings.transpose(
0, 1)
causal_msk = None
if is_training:
causal_msk = generate_square_subsequent_mask(
decoder_inputs.shape[0], condition_len=condition_len, device=decoder_inputs.device)
decoder_outputs, cache = self.decoder(
tgt=decoder_inputs,
cache=cache,
memory=sequential_context_embeddings,
causal_mask=causal_msk,
)
decoder_outputs = decoder_outputs.transpose(0, 1)
        # keep only the predicted part of the sequence (drop the condition prefix)
decoder_outputs = decoder_outputs[:, condition_len-1:]
# Get logits and optionally process for sampling
logits = self._project_to_logits(decoder_outputs)
        # y mask: restrict y-coordinate logits to the valid canvas range
        _vert_mask = torch.arange(logits.shape[-1], device=logits.device)
        vertices_mask_y = (_vert_mask < self.canvas_size[1]+1)
        vertices_mask_y[0] = False  # y positions cannot emit the stop token
logits[:, 1::self.coord_dim] = logits[:, 1::self.coord_dim] * \
vertices_mask_y - ~vertices_mask_y*1e9
if self.coord_dim > 2:
            # z mask: restrict z-coordinate logits to the valid canvas range
            _vert_mask = torch.arange(logits.shape[-1], device=logits.device)
            vertices_mask_z = (_vert_mask < self.canvas_size[2]+1)
            vertices_mask_z[0] = False  # z positions cannot emit the stop token
logits[:, 2::self.coord_dim] = logits[:, 2::self.coord_dim] * \
vertices_mask_z - ~vertices_mask_z*1e9
logits = logits/temperature
logits = top_k_logits(logits, top_k)
logits = top_p_logits(logits, top_p)
if return_logits:
return logits
cat_dist = Categorical(logits=logits)
return cat_dist, cache
@force_fp32(apply_to=('pred'))
def loss(self, gt: dict, pred: dict):
weight = gt['polyline_weights']
mask = gt['polyline_masks']
loss = -torch.sum(
pred['polylines'].log_prob(gt['polylines']) * mask * weight)/weight.sum()
return {'seq': loss}
def inference(self,
batch: dict,
context: dict,
max_sample_length=None,
temperature=1.,
top_k=0,
top_p=1.,
only_return_complete=False,
gt_condition=False,
**kwargs):
"""Sample from face model using caching.
Args:
context: Dictionary of context, including 'vertices' and 'vertices_mask'.
See _prepare_context for details.
max_sample_length: Maximum length of sampled vertex sequences. Sequences
that do not complete are truncated.
temperature: Scalar softmax temperature > 0.
top_k: Number of tokens to keep for top-k sampling.
top_p: Proportion of probability mass to keep for top-p sampling.
only_return_complete: If True, only return completed samples. Otherwise
return all samples along with completed indicator.
Returns:
outputs: Output dictionary with fields:
'completed': Boolean tensor of shape [num_samples]. If True then
corresponding sample completed within max_sample_length.
'faces': Tensor of samples with shape [num_samples, num_verts, 3].
'valid_polyline_len': Tensor indicating number of vertices for each
example in padded vertex samples.
"""
# prepare the conditional variable
global_context, seq_context = self._prepare_context(
batch, context)
device = global_context.device
batch_size = global_context.shape[0]
# While loop sampling with caching
samples = torch.empty(
[batch_size, 0], dtype=torch.int32, device=device)
max_sample_length = max_sample_length or self.max_seq_length
seq_len = max_sample_length*self.coord_dim+1
cache = None
decoded_tokens = \
torch.zeros((batch_size,seq_len),
device=device,dtype=torch.long)
remain_idx = torch.arange(batch_size, device=device)
for i in range(seq_len):
# While-loop body for autoregression calculation.
pred_dist, cache = self.body(
samples,
global_context_embedding=global_context,
sequential_context_embeddings=seq_context,
cache=cache,
temperature=temperature,
top_k=top_k,
top_p=top_p,
is_training=False)
samples = pred_dist.sample()
decoded_tokens[remain_idx,i] = samples[:,-1]
# Stopping conditions for autoregressive calculation.
if not (decoded_tokens[:,:i+1] != 0).all(-1).any():
break
            # update state: keep only samples whose newest token is non-zero (i.e. not yet stopped)
valid_idx = (samples[:,-1] != 0).nonzero(as_tuple=True)[0]
remain_idx = remain_idx[valid_idx]
cache = cache[:,:,valid_idx]
global_context = global_context[valid_idx]
seq_context = seq_context[valid_idx]
samples = samples[valid_idx]
# decoded_tokens = torch.cat(decoded_tokens,dim=1)
decoded_tokens = decoded_tokens[:,:i+1]
outputs = self.post_process(decoded_tokens, seq_len,
device, only_return_complete)
return outputs
def post_process(self, polyline,
max_seq_len=None,
device=None,
only_return_complete=True):
        '''
        Format the sampled token sequences: find each sequence's stopping
        point, build validity masks, and pad to the maximum length.
        '''
# Record completed samples
complete_samples = (polyline == 0).any(-1)
        # Find the stopping point of each sequence
sample_seq_length = polyline.shape[-1]
_polyline_mask = torch.arange(sample_seq_length)[None].to(device)
# Get largest stopping point for incomplete samples.
valid_polyline_len = torch.full_like(polyline[:,0], sample_seq_length)
zero_inds = (polyline == 0).type(torch.int32).argmax(-1)
# Real length
valid_polyline_len[complete_samples] = zero_inds[complete_samples] + 1
polyline_mask = _polyline_mask < valid_polyline_len[:, None]
        # Mask tokens beyond the stopping token with zeros
polyline = polyline*polyline_mask
# Pad to maximum size with zeros
pad_size = max_seq_len - sample_seq_length
polyline = F.pad(polyline, (0, pad_size))
# polyline_mask = F.pad(polyline_mask, (0, pad_size))
# XXX
# if only_return_complete:
# polyline = polyline[complete_samples]
# valid_polyline_len = valid_polyline_len[complete_samples]
# context = tf.nest.map_structure(
# lambda x: tf.boolean_mask(x, complete_samples), context)
# complete_samples = complete_samples[complete_samples]
# outputs
outputs = {
'completed': complete_samples,
'polylines': polyline,
'polyline_masks': polyline_mask,
}
return outputs
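# Example of the masking in post_process: for a sampled row [5, 7, 0, 3] the
# stop token (0) first appears at index 2, so valid_polyline_len = 3, the mask
# is [1, 1, 1, 0] and the returned tokens become [5, 7, 0, 0].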
def find_best_sperate_plan(idx,array):
h = array[-1] - array[idx]
w = idx
cost = h*w
return cost
def get_chunk_idx(polyline_length):
_polyline_length, polyline_length_idx = torch.sort(polyline_length)
costs = []
for i in range(len(_polyline_length)):
cost = find_best_sperate_plan(i,_polyline_length)
costs.append(cost)
seperate_point = torch.stack(costs).argmax()
chunk1 = polyline_length_idx[:seperate_point+1]
chunk2 = polyline_length_idx[seperate_point+1:]
revert_idx = torch.argsort(polyline_length_idx)
return chunk1, chunk2, revert_idx, _polyline_length[seperate_point]
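# Chunking heuristic: with lengths sorted ascending, find_best_sperate_plan
# scores a split index by cost = (max_len - len[idx]) * idx, i.e. roughly the
# padding avoided if the first idx sequences are padded only to len[idx]
# instead of max_len. For sorted lengths [3, 4, 10, 12] the costs are
# [0, 8, 4, 0], so the split is after index 1: chunk1 = the two short lines
# padded to length 4, chunk2 = the two long lines padded to length 12.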
def assign_bev(feat, idx):
return feat[idx]
def assign_batch(batch, idx, size):
new_batch = {}
for k,v in batch.items():
new_batch[k] = v[idx]
if new_batch[k].ndim > 1:
new_batch[k] = new_batch[k][:,:size]
return new_batch
import torch
from torch import nn as nn
from torch.nn import functional as F
from mmdet.models.losses import l1_loss
from mmdet.models.losses.utils import weighted_loss
import mmcv
from mmdet.models.builder import LOSSES
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def smooth_l1_loss(pred, target, beta=1.0):
"""Smooth L1 loss.
Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning target of the prediction.
beta (float, optional): The threshold in the piecewise function.
Defaults to 1.0.
Returns:
torch.Tensor: Calculated loss
"""
assert beta > 0
if target.numel() == 0:
return pred.sum() * 0
assert pred.size() == target.size()
diff = torch.abs(pred - target)
loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
diff - 0.5 * beta)
return loss
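# Piecewise definition used above: loss = 0.5 * d**2 / beta for d < beta,
# and d - 0.5 * beta otherwise, where d = |pred - target|. With beta = 1.0,
# d = 0.5 gives 0.125 and d = 2.0 gives 1.5.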
@LOSSES.register_module()
class LinesLoss(nn.Module):
def __init__(self, reduction='mean', loss_weight=1.0, beta=0.5):
"""
L1 loss. The same as the smooth L1 loss
Args:
reduction (str, optional): The method to reduce the loss.
Options are "none", "mean" and "sum".
loss_weight (float, optional): The weight of loss.
"""
super(LinesLoss, self).__init__()
self.reduction = reduction
self.loss_weight = loss_weight
self.beta = beta
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
pred (torch.Tensor): The prediction.
shape: [bs, ...]
target (torch.Tensor): The learning target of the prediction.
shape: [bs, ...]
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
it's useful when the predictions are not all valid.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Defaults to None.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss = smooth_l1_loss(
pred, target, weight, reduction=reduction, avg_factor=avg_factor, beta=self.beta)
return loss*self.loss_weight
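# Minimal usage sketch (the config values are illustrative assumptions, not
# the repository's actual settings): the detector builds this loss with
# mmdet's build_loss, e.g.
#   loss_reg = dict(type='LinesLoss', reduction='mean',
#                   loss_weight=0.1, beta=0.5)
#   reg_loss = build_loss(loss_reg)
#   loss = reg_loss(pred_lines, gt_lines, line_weights, avg_factor=num_pos)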
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def bce(pred, label, class_weight=None):
"""
pred: B,nquery,npts
label: B,nquery,npts
"""
if label.numel() == 0:
return pred.sum() * 0
assert pred.size() == label.size()
loss = F.binary_cross_entropy_with_logits(
pred, label.float(), pos_weight=class_weight, reduction='none')
return loss
@LOSSES.register_module()
class MasksLoss(nn.Module):
def __init__(self, reduction='mean', loss_weight=1.0):
super(MasksLoss, self).__init__()
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
xxx
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss = bce(pred, target, weight, reduction=reduction,
avg_factor=avg_factor)
return loss*self.loss_weight
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def ce(pred, label, class_weight=None):
"""
pred: B*nquery,npts
label: B*nquery,
"""
if label.numel() == 0:
return pred.sum() * 0
loss = F.cross_entropy(
pred, label, weight=class_weight, reduction='none')
return loss
@LOSSES.register_module()
class LenLoss(nn.Module):
def __init__(self, reduction='mean', loss_weight=1.0):
super(LenLoss, self).__init__()
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
xxx
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss = ce(pred, target, weight, reduction=reduction,
avg_factor=avg_factor)
return loss*self.loss_weight
from abc import ABCMeta, abstractmethod
import torch.nn as nn
from mmcv.runner import auto_fp16
from mmcv.utils import print_log
from mmdet.utils import get_root_logger
from mmdet3d.models.builder import DETECTORS
MAPPERS = DETECTORS
class BaseMapper(nn.Module, metaclass=ABCMeta):
"""Base class for mappers."""
def __init__(self):
super(BaseMapper, self).__init__()
self.fp16_enabled = False
@property
def with_neck(self):
"""bool: whether the detector has a neck"""
return hasattr(self, 'neck') and self.neck is not None
# TODO: these properties need to be carefully handled
# for both single stage & two stage detectors
@property
def with_shared_head(self):
"""bool: whether the detector has a shared head in the RoI Head"""
return hasattr(self, 'roi_head') and self.roi_head.with_shared_head
@property
def with_bbox(self):
"""bool: whether the detector has a bbox head"""
return ((hasattr(self, 'roi_head') and self.roi_head.with_bbox)
or (hasattr(self, 'bbox_head') and self.bbox_head is not None))
@property
def with_mask(self):
"""bool: whether the detector has a mask head"""
return ((hasattr(self, 'roi_head') and self.roi_head.with_mask)
or (hasattr(self, 'mask_head') and self.mask_head is not None))
#@abstractmethod
def extract_feat(self, imgs):
"""Extract features from images."""
pass
def forward_train(self, *args, **kwargs):
pass
#@abstractmethod
def simple_test(self, img, img_metas, **kwargs):
pass
#@abstractmethod
def aug_test(self, imgs, img_metas, **kwargs):
"""Test function with test time augmentation."""
pass
def init_weights(self, pretrained=None):
"""Initialize the weights in detector.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if pretrained is not None:
logger = get_root_logger()
print_log(f'load model from: {pretrained}', logger=logger)
def forward_test(self, *args, **kwargs):
"""
Args:
"""
if True:
self.simple_test()
else:
self.aug_test()
# @auto_fp16(apply_to=('img', ))
def forward(self, *args, return_loss=True, **kwargs):
"""Calls either :func:`forward_train` or :func:`forward_test` depending
on whether ``return_loss`` is ``True``.
Note this setting will change the expected inputs. When
``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
        and List[dict]), and when ``return_loss=False``, img and img_meta
should be double nested (i.e. List[Tensor], List[List[dict]]), with
the outer list indicating test time augmentations.
"""
if return_loss:
return self.forward_train(*args, **kwargs)
else:
kwargs.pop('rescale')
return self.forward_test(*args, **kwargs)
def train_step(self, data_dict, optimizer):
"""The iteration step during training.
This method defines an iteration step during training, except for the
back propagation and optimizer updating, which are done in an optimizer
hook. Note that in some complicated cases or models, the whole process
including back propagation and optimizer updating is also defined in
this method, such as GAN.
Args:
data_dict (dict): The output of dataloader.
optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
runner is passed to ``train_step()``. This argument is unused
and reserved.
Returns:
dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \
``num_samples``.
- ``loss`` is a tensor for back propagation, which can be a \
weighted sum of multiple losses.
- ``log_vars`` contains all the variables to be sent to the
logger.
- ``num_samples`` indicates the batch size (when the model is \
DDP, it means the batch size on each GPU), which is used for \
averaging the logs.
"""
loss, log_vars, num_samples = self(**data_dict)
outputs = dict(
loss=loss, log_vars=log_vars, num_samples=num_samples)
return outputs
def val_step(self, data, optimizer):
"""The iteration step during validation.
This method shares the same signature as :func:`train_step`, but used
during val epochs. Note that the evaluation after training epochs is
not implemented with this method, but an evaluation hook.
"""
loss, log_vars, num_samples = self(**data)
outputs = dict(
loss=loss, log_vars=log_vars, num_samples=num_samples)
return outputs
def show_result(self,
**kwargs):
img = None
return img
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torchvision.models.resnet import resnet18, resnet50
from mmdet3d.models.builder import (build_backbone, build_head,
build_neck)
from .base_mapper import BaseMapper, MAPPERS
@MAPPERS.register_module()
class VectorMapNet(BaseMapper):
def __init__(self,
backbone_cfg=dict(),
head_cfg=dict(
vert_net_cfg=dict(),
face_net_cfg=dict(),
),
neck_input_channels=128,
neck_cfg=None,
with_auxiliary_head=False,
only_det=False,
train_cfg=None,
test_cfg=None,
pretrained=None,
model_name=None, **kwargs):
super(VectorMapNet, self).__init__()
#Attribute
self.model_name = model_name
self.last_epoch = None
self.only_det = only_det
self.backbone = build_backbone(backbone_cfg)
if neck_cfg is not None:
self.neck_neck = build_backbone(neck_cfg.backbone)
self.neck_neck.conv1 = nn.Conv2d(
neck_input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.neck_project = build_neck(neck_cfg.neck)
self.neck = self.multiscale_neck
else:
trunk = resnet18(pretrained=False, zero_init_residual=True)
self.neck = nn.Sequential(
nn.Conv2d(neck_input_channels, 64, kernel_size=(7, 7), stride=(
2, 2), padding=(3, 3), bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1,
dilation=1, ceil_mode=False),
trunk.layer1,
nn.Conv2d(64, 128, kernel_size=1, bias=False),
)
# BEV
if hasattr(self.backbone,'bev_w'):
self.bev_w = self.backbone.bev_w
self.bev_h = self.backbone.bev_h
self.head = build_head(head_cfg)
def multiscale_neck(self, bev_embedding):
multi_feat = self.neck_neck(bev_embedding)
multi_feat = self.neck_project(multi_feat)
return multi_feat
def forward_train(self, img, polys, points=None, img_metas=None, **kwargs):
        '''
        Args:
            img: torch.Tensor of shape [B, N, 3, H, W]
                N: number of cams
            polys (list[dict]): one dict per sample with the tokenized map
                annotations (see format_det / format_gen for the expected keys).
                len(polys) = batch_size
            img_metas:
                img_metas['lidar2img']: [B, N, 4, 4]
        Out:
            loss, log_vars, num_sample
        '''
# prepare labels and images
batch, img, img_metas, valid_idx, points = self.batch_data(
polys, img, img_metas, img.device, points)
        # corner case: if no sample has valid vectors, fall back to the batch cached from the previous step
if self.last_epoch is None:
self.last_epoch = [batch, img, img_metas, valid_idx, points]
if len(valid_idx)==0:
batch, img, img_metas, valid_idx, points = self.last_epoch
else:
del self.last_epoch
self.last_epoch = [batch, img, img_metas, valid_idx, points]
# Backbone
_bev_feats = self.backbone(img, img_metas=img_metas, points=points)
img_shape = \
[_bev_feats.shape[2:] for i in range(_bev_feats.shape[0])]
# Neck
bev_feats = self.neck(_bev_feats)
preds_dict, losses_dict = \
self.head(batch,
context={
'bev_embeddings': bev_feats,
'batch_input_shape': _bev_feats.shape[2:],
'img_shape': img_shape,
'raw_bev_embeddings': _bev_feats},
only_det=self.only_det)
# format outputs
loss = 0
for name, var in losses_dict.items():
loss = loss + var
# update the log
log_vars = {k: v.item() for k, v in losses_dict.items()}
log_vars.update({'total': loss.item()})
num_sample = img.size(0)
return loss, log_vars, num_sample
@torch.no_grad()
def forward_test(self, img, polys=None, points=None, img_metas=None, **kwargs):
'''
inference pipeline
'''
# prepare labels and images
token = []
for img_meta in img_metas:
token.append(img_meta['token'])
_bev_feats = self.backbone(img, img_metas, points=points)
img_shape = [_bev_feats.shape[2:] for i in range(_bev_feats.shape[0])]
# Neck
bev_feats = self.neck(_bev_feats)
context = {'bev_embeddings': bev_feats,
'batch_input_shape': _bev_feats.shape[2:],
'img_shape': img_shape, # XXX
'raw_bev_embeddings': _bev_feats}
preds_dict = self.head(batch={},
context=context,
condition_on_det=True,
gt_condition=False,
only_det=self.only_det)
        # corner case: the head found no map elements in this batch
if preds_dict is None:
return [None]
results_list = self.head.post_process(preds_dict, token, only_det=self.only_det)
return results_list
def batch_data(self, polys, imgs, img_metas, device, points=None):
        # filter out samples that contain no vectors
valid_idx = [i for i in range(len(polys)) if len(polys[i])]
imgs = imgs[valid_idx]
img_metas = [img_metas[i] for i in valid_idx]
polys = [polys[i] for i in valid_idx]
if points is not None:
points = [points[i] for i in valid_idx]
points = self.batch_points(points)
if len(valid_idx) == 0:
return None, None, None, valid_idx, None
batch = {}
batch['det'] = format_det(polys,device)
batch['gen'] = format_gen(polys,device)
return batch, imgs, img_metas, valid_idx, points
def batch_points(self, points):
pad_points = pad_sequence(points, batch_first=True)
points_mask = torch.zeros_like(pad_points[:,:,0]).bool()
for i in range(len(points)):
valid_num = points[i].shape[0]
points_mask[i][:valid_num] = True
return (pad_points, points_mask)
def format_det(polys, device):
batch = {
'class_label':[],
'batch_idx':[],
'bbox': [],
}
for batch_idx, poly in enumerate(polys):
keypoint_label = torch.from_numpy(poly['det_label']).to(device)
keypoint = torch.from_numpy(poly['keypoint']).to(device)
batch['class_label'].append(keypoint_label)
batch['bbox'].append(keypoint)
return batch
def format_gen(polys,device):
line_cls = []
polylines, polyline_masks, polyline_weights = [], [], []
bbox, line_cls, line_bs_idx = [], [], []
for batch_idx, poly in enumerate(polys):
# convert to cuda tensor
for k in poly.keys():
if isinstance(poly[k],np.ndarray):
poly[k] = torch.from_numpy(poly[k]).to(device)
else:
poly[k] = [torch.from_numpy(v).to(device) for v in poly[k]]
line_cls += poly['gen_label']
line_bs_idx += [batch_idx]*len(poly['gen_label'])
# condition
bbox += poly['qkeypoint']
# out
polylines += poly['polylines']
polyline_masks += poly['polyline_masks']
polyline_weights += poly['polyline_weights']
batch = {}
batch['lines_bs_idx'] = torch.tensor(
line_bs_idx, dtype=torch.long, device=device)
batch['lines_cls'] = torch.tensor(
line_cls, dtype=torch.long, device=device)
batch['bbox_flat'] = torch.stack(bbox, 0)
# padding
batch['polylines'] = pad_sequence(polylines, batch_first=True)
batch['polyline_masks'] = pad_sequence(polyline_masks, batch_first=True)
batch['polyline_weights'] = pad_sequence(polyline_weights, batch_first=True)
return batch
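# Expected per-sample `poly` dict (as consumed by format_det / format_gen):
#   'det_label', 'keypoint'                      -> detection targets
#   'gen_label', 'qkeypoint'                     -> generator conditions
#   'polylines', 'polyline_masks',
#   'polyline_weights'                           -> generator token targets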
from .deformable_transformer import DeformableDetrTransformer_, DeformableDetrTransformerDecoder_
from .base_transformer import PlaceHolderEncoder
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init, constant_init
from mmcv.cnn.bricks.registry import (ATTENTION,
TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import (MultiScaleDeformableAttention,
TransformerLayerSequence,
build_transformer_layer_sequence)
from mmcv.runner.base_module import BaseModule
from mmdet.models.utils.builder import TRANSFORMER
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class PlaceHolderEncoder(nn.Module):
def __init__(self, *args, embed_dims=None, **kwargs):
super(PlaceHolderEncoder, self).__init__()
self.embed_dims = embed_dims
def forward(self, *args, query=None, **kwargs):
return query
# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_norm_layer, xavier_init
from mmcv.cnn.bricks.registry import (TRANSFORMER_LAYER,
TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import (BaseTransformerLayer,
TransformerLayerSequence,
build_transformer_layer_sequence)
from mmcv.runner.base_module import BaseModule
from torch.nn.init import normal_
from mmdet.models.utils.builder import TRANSFORMER
from mmdet.models.utils.transformer import Transformer
try:
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
except ImportError:
warnings.warn(
'`MultiScaleDeformableAttention` in MMCV has been moved to '
'`mmcv.ops.multi_scale_deform_attn`, please update your MMCV')
from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention
from .fp16_dattn import MultiScaleDeformableAttentionFp16
def inverse_sigmoid(x, eps=1e-5):
"""Inverse function of sigmoid.
Args:
x (Tensor): The tensor to apply the inverse sigmoid to.
eps (float): A small epsilon to avoid numerical overflow.
Defaults to 1e-5.
Returns:
Tensor: The inverse sigmoid of the input, with the same
shape as the input.
"""
x = x.clamp(min=0, max=1)
x1 = x.clamp(min=eps)
x2 = (1 - x).clamp(min=eps)
return torch.log(x1 / x2)
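# Sanity-check sketch (illustrative, not part of the original file): for x
# strictly inside (eps, 1 - eps) this is the exact logit, so
# torch.sigmoid(inverse_sigmoid(torch.tensor([0.25, 0.5, 0.9]))) recovers
# [0.25, 0.5, 0.9] up to floating-point error; values outside [0, 1] are
# clamped first, so the round trip is only approximate there.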
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DeformableDetrTransformerDecoder_(TransformerLayerSequence):
"""Implements the decoder in DETR transformer.
Args:
return_intermediate (bool): Whether to return intermediate outputs.
coder_norm_cfg (dict): Config of last normalization layer. Default:
`LN`.
"""
def __init__(self, *args,
return_intermediate=False, coord_dim=2, kp_coord_dim=2, **kwargs):
super(DeformableDetrTransformerDecoder_, self).__init__(*args, **kwargs)
self.return_intermediate = return_intermediate
self.coord_dim = coord_dim
self.kp_coord_dim = kp_coord_dim
def forward(self,
query,
*args,
reference_points=None,
valid_ratios=None,
reg_branches=None,
**kwargs):
"""Forward function for `TransformerDecoder`.
Args:
query (Tensor): Input query with shape
`(num_query, bs, embed_dims)`.
reference_points (Tensor): The reference points used as
offset anchors. Has shape (bs, num_query, 4) when
as_two_stage is True, otherwise (bs, num_query, 2).
valid_ratios (Tensor): The ratios of valid
points on the feature map, has shape
(bs, num_levels, 2).
reg_branches (obj:`nn.ModuleList`): Used for
refining the regression results. Only passed
when with_box_refine is True, otherwise `None`.
Returns:
Tensor: Results with shape [1, num_query, bs, embed_dims] when
return_intermediate is `False`, otherwise it has shape
[num_layers, num_query, bs, embed_dims].
"""
output = query
intermediate = []
intermediate_reference_points = []
for lid, layer in enumerate(self.layers):
reference_points_input = \
reference_points[:, :, None,:self.kp_coord_dim] * \
valid_ratios[:, None,:,:self.kp_coord_dim]
# if reference_points.shape[-1] == 3 and self.kp_coord_dim==2:
output = layer(
output,
*args,
reference_points=reference_points_input[...,:self.kp_coord_dim],
**kwargs)
output = output.permute(1, 0, 2)
if reg_branches is not None:
tmp = reg_branches[lid](output)
new_reference_points = tmp
new_reference_points[..., :self.kp_coord_dim] = tmp[
..., :self.kp_coord_dim] + inverse_sigmoid(reference_points)
new_reference_points = new_reference_points.sigmoid()
if reference_points.shape[-1] == 3 and self.kp_coord_dim==2:
reference_points[...,-1] = tmp[...,-1].sigmoid().detach()
reference_points[...,:self.coord_dim] = new_reference_points.detach()
output = output.permute(1, 0, 2)
if self.return_intermediate:
intermediate.append(output)
intermediate_reference_points.append(reference_points)
if self.return_intermediate:
return torch.stack(intermediate), torch.stack(
intermediate_reference_points)
return output, reference_points
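# Refinement recap (comments only): each decoder layer predicts an offset
# with reg_branches[lid], adds it to inverse_sigmoid(reference_points) and
# maps the sum back through sigmoid, so the reference points stay in [0, 1]
# while being refined layer by layer; the detached result seeds the next
# layer's deformable sampling.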
@TRANSFORMER.register_module()
class DeformableDetrTransformer_(Transformer):
"""Implements the DeformableDETR transformer.
Args:
as_two_stage (bool): Generate query from encoder features.
Default: False.
num_feature_levels (int): Number of feature maps from FPN.
Default: 1.
two_stage_num_proposals (int): Number of proposals when set
`as_two_stage` as True. Default: 300.
"""
def __init__(self,
as_two_stage=False,
num_feature_levels=1,
two_stage_num_proposals=300,
coord_dim=2,
**kwargs):
super(DeformableDetrTransformer_, self).__init__(**kwargs)
self.as_two_stage = as_two_stage
self.num_feature_levels = num_feature_levels
self.two_stage_num_proposals = two_stage_num_proposals
self.embed_dims = self.encoder.embed_dims
self.coord_dim = coord_dim
self.init_layers()
def init_layers(self):
"""Initialize layers of the DeformableDetrTransformer."""
self.level_embeds = nn.Parameter(
torch.Tensor(self.num_feature_levels, self.embed_dims))
if self.as_two_stage:
self.enc_output = nn.Linear(self.embed_dims, self.embed_dims)
self.enc_output_norm = nn.LayerNorm(self.embed_dims)
self.pos_trans = nn.Linear(self.embed_dims * 2,
self.embed_dims * 2)
self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2)
else:
self.reference_points_embed = nn.Linear(self.embed_dims, self.coord_dim)
def init_weights(self):
"""Initialize the transformer weights."""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if isinstance(m, MultiScaleDeformableAttention):
m.init_weights()
elif isinstance(m,MultiScaleDeformableAttentionFp16):
m.init_weights()
if not self.as_two_stage:
xavier_init(self.reference_points_embed, distribution='uniform', bias=0.)
normal_(self.level_embeds)
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
"""Get the reference points used in decoder.
Args:
spatial_shapes (Tensor): The shape of all
feature maps, has shape (num_level, 2).
valid_ratios (Tensor): The ratios of valid
points on the feature map, has shape
(bs, num_levels, 2).
device (obj:`device`): The device where
reference_points should be.
Returns:
Tensor: reference points used in decoder, has \
shape (bs, num_keys, num_levels, 2).
"""
reference_points_list = []
for lvl, (H, W) in enumerate(spatial_shapes):
# TODO check this 0.5
ref_y, ref_x = torch.meshgrid(
torch.linspace(
0.5, H - 0.5, H, dtype=torch.float32, device=device),
torch.linspace(
0.5, W - 0.5, W, dtype=torch.float32, device=device))
ref_y = ref_y.reshape(-1)[None] / (
valid_ratios[:, None, lvl, 1] * H)
ref_x = ref_x.reshape(-1)[None] / (
valid_ratios[:, None, lvl, 0] * W)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
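# Shape sketch (assumed toy sizes): with a single level of spatial shape
# (50, 25) and valid_ratios of shape (bs, 1, 2), the returned tensor has
# shape (bs, 50 * 25, 1, 2), i.e. one normalized (x, y) per key per level.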
def get_valid_ratio(self, mask):
"""Get the valid radios of feature maps of all level."""
_, H, W = mask.shape
valid_H = torch.sum(~mask[:, :, 0], 1)
valid_W = torch.sum(~mask[:, 0, :], 1)
valid_ratio_h = valid_H.float() / H
valid_ratio_w = valid_W.float() / W
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
return valid_ratio
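# Example (illustrative assumption): for a (1, 100, 100) mask whose bottom
# and right halves are padded (True), valid_H = valid_W = 50 and the method
# returns tensor([[0.5, 0.5]]), ordered as (w, h).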
def get_proposal_pos_embed(self,
proposals,
num_pos_feats=128,
temperature=10000):
"""Get the position embedding of proposal."""
scale = 2 * math.pi
dim_t = torch.arange(
num_pos_feats, dtype=torch.float32, device=proposals.device)
dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats)
# N, L, 4
proposals = proposals.sigmoid() * scale
# N, L, 4, 128
pos = proposals[:, :, :, None] / dim_t
# N, L, 4, 64, 2
pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()),
dim=4).flatten(2)
return pos
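# Dimension note: with the default num_pos_feats=128 each of the 4 proposal
# coordinates is expanded into 128 interleaved sin/cos features and
# flattened, giving a (N, L, 512) embedding, which matches pos_trans
# (embed_dims * 2) when embed_dims is 256.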
def forward(self,
mlvl_feats,
mlvl_masks,
query_embed,
mlvl_pos_embeds,
reg_branches=None,
cls_branches=None,
**kwargs):
"""Forward function for `Transformer`.
Args:
mlvl_feats (list(Tensor)): Input queries from
different level. Each element has shape
[bs, embed_dims, h, w].
mlvl_masks (list(Tensor)): The key_padding_mask from
different level used for encoder and decoder,
each element has shape [bs, h, w].
query_embed (Tensor): The query embedding for decoder,
with shape [num_query, c].
mlvl_pos_embeds (list(Tensor)): The positional encoding
of feats from different level, has the shape
[bs, embed_dims, h, w].
reg_branches (obj:`nn.ModuleList`): Regression heads for
feature maps from each decoder layer. Only would
be passed when
`with_box_refine` is True. Default to None.
cls_branches (obj:`nn.ModuleList`): Classification heads
for feature maps from each decoder layer. Only would
be passed when `as_two_stage`
is True. Default to None.
Returns:
tuple[Tensor]: results of decoder containing the following tensor.
- inter_states: Outputs from decoder. If
return_intermediate_dec is True output has shape \
(num_dec_layers, bs, num_query, embed_dims), else has \
shape (1, bs, num_query, embed_dims).
- init_reference_out: The initial value of reference \
points, has shape (bs, num_queries, 4).
- inter_references_out: The internal value of reference \
points in decoder, has shape \
(num_dec_layers, bs,num_query, embed_dims)
- enc_outputs_class: The classification score of \
proposals generated from \
encoder's feature maps, has shape \
(batch, h*w, num_classes). \
Only would be returned when `as_two_stage` is True, \
otherwise None.
- enc_outputs_coord_unact: The regression results \
generated from encoder's feature maps, has shape \
(batch, h*w, 4). Only would \
be returned when `as_two_stage` is True, \
otherwise None.
"""
assert self.as_two_stage or query_embed is not None
feat_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (feat, mask, pos_embed) in enumerate(
zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
bs, c, h, w = feat.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
feat = feat.flatten(2).transpose(1, 2)
mask = mask.flatten(1)
pos_embed = pos_embed.flatten(2).transpose(1, 2)
lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
lvl_pos_embed_flatten.append(lvl_pos_embed)
feat_flatten.append(feat)
mask_flatten.append(mask)
feat_flatten = torch.cat(feat_flatten, 1)
mask_flatten = torch.cat(mask_flatten, 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
spatial_shapes = torch.as_tensor(
spatial_shapes, dtype=torch.long, device=feat_flatten.device)
level_start_index = torch.cat((spatial_shapes.new_zeros(
(1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack(
[self.get_valid_ratio(m) for m in mlvl_masks], 1)
# reference_points = \
# self.get_reference_points(spatial_shapes,
# valid_ratios,
# device=feat.device)
feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims)
# lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(
# 1, 0, 2) # (H*W, bs, embed_dims)
# memory = self.encoder(
# query=feat_flatten,
# key=None,
# value=None,
# query_pos=lvl_pos_embed_flatten,
# query_key_padding_mask=mask_flatten,
# spatial_shapes=spatial_shapes,
# reference_points=reference_points,
# level_start_index=level_start_index,
# valid_ratios=valid_ratios,
# **kwargs)
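# NOTE: unlike the reference DeformableDETR, the encoder call above is left
# commented out and the flattened input features are used directly as the
# decoder memory, presumably because they are already encoded upstream.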
memory = feat_flatten.permute(1, 0, 2)
bs, _, c = memory.shape
query_pos, query = torch.split(query_embed, c, dim=-1)
reference_points = self.reference_points_embed(query_pos).sigmoid()
init_reference_out = reference_points
# decoder
query = query.permute(1, 0, 2)
memory = memory.permute(1, 0, 2)
query_pos = query_pos.permute(1, 0, 2)
inter_states, inter_references = self.decoder(
query=query,
key=None,
value=memory,
query_pos=query_pos,
key_padding_mask=mask_flatten,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
reg_branches=reg_branches,
**kwargs)
inter_references_out = inter_references
return inter_states, init_reference_out, inter_references_out
\ No newline at end of file
import warnings
try:
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
except ImportError:
warnings.warn(
'`MultiScaleDeformableAttention` in MMCV has been moved to '
'`mmcv.ops.multi_scale_deform_attn`, please update your MMCV')
from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention
from mmcv.runner import force_fp32, auto_fp16
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.cnn.bricks.transformer import build_attention
import math
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.function import Function, once_differentiable
from mmcv import deprecated_api_warning
from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.runner import BaseModule
from mmcv.utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
from torch.cuda.amp import custom_bwd, custom_fwd
@ATTENTION.register_module()
class MultiScaleDeformableAttentionFp16(BaseModule):
def __init__(self, attn_cfg=None,init_cfg=None,**kwarg):
super(MultiScaleDeformableAttentionFp16,self).__init__(init_cfg)
# import ipdb; ipdb.set_trace()
self.deformable_attention = build_attention(attn_cfg)
self.deformable_attention.init_weights()
self.fp16_enabled = False
@force_fp32(apply_to=('query', 'key', 'value', 'query_pos', 'reference_points','identity'))
def forward(self, query,
key=None,
value=None,
identity=None,
query_pos=None,
key_padding_mask=None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
**kwargs):
# import ipdb; ipdb.set_trace()
return self.deformable_attention(query,
key=key,
value=value,
identity=identity,
query_pos=query_pos,
key_padding_mask=key_padding_mask,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,**kwargs)
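# The force_fp32 decorator above casts the listed tensor arguments back to
# float32 (once fp16 is enabled on the model) before delegating to the
# wrapped attention, so the CUDA ms_deform_attn kernel always runs in fp32
# under mixed-precision training. A hedged config sketch (values are
# assumptions):
#   attn_cfgs=dict(type='MultiScaleDeformableAttentionFp16',
#                  attn_cfg=dict(type='MultiScaleDeformableAttention',
#                                embed_dims=256, num_levels=1))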
class MultiScaleDeformableAttnFunctionFp32(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, value, value_spatial_shapes, value_level_start_index,
sampling_locations, attention_weights, im2col_step):
"""GPU version of multi-scale deformable attention.
Args:
value (Tensor): The value has shape
(bs, num_keys, num_heads, embed_dims//num_heads)
value_spatial_shapes (Tensor): Spatial shape of
each feature map, has shape (num_levels, 2),
last dimension 2 represent (h, w)
sampling_locations (Tensor): The location of sampling points,
has shape
(bs ,num_queries, num_heads, num_levels, num_points, 2),
the last dimension 2 represent (x, y).
attention_weights (Tensor): The weight of sampling points used
when calculating the attention, has shape
(bs, num_queries, num_heads, num_levels, num_points).
im2col_step (int): The step used in image to column.
Returns:
Tensor: has shape (bs, num_queries, embed_dims)
"""
ctx.im2col_step = im2col_step
output = ext_module.ms_deform_attn_forward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
im2col_step=ctx.im2col_step)
ctx.save_for_backward(value, value_spatial_shapes,
value_level_start_index, sampling_locations,
attention_weights)
return output
@staticmethod
@once_differentiable
@custom_bwd
def backward(ctx, grad_output):
"""GPU version of backward function.
Args:
grad_output (Tensor): Gradient
of output tensor of forward.
Returns:
Tuple[Tensor]: Gradient
of input tensors in forward.
"""
value, value_spatial_shapes, value_level_start_index,\
sampling_locations, attention_weights = ctx.saved_tensors
grad_value = torch.zeros_like(value)
grad_sampling_loc = torch.zeros_like(sampling_locations)
grad_attn_weight = torch.zeros_like(attention_weights)
ext_module.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output.contiguous(),
grad_value,
grad_sampling_loc,
grad_attn_weight,
im2col_step=ctx.im2col_step)
return grad_value, None, None, \
grad_sampling_loc, grad_attn_weight, None
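# AMP note: custom_fwd(cast_inputs=torch.float32) casts floating-point
# tensor inputs to float32 when autocast is active, and custom_bwd runs the
# backward under the same autocast state as the forward, keeping the CUDA
# ms_deform_attn kernels in fp32 during mixed-precision training.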
def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
sampling_locations, attention_weights):
"""CPU version of multi-scale deformable attention.
Args:
value (Tensor): The value has shape
(bs, num_keys, num_heads, embed_dims//num_heads)
value_spatial_shapes (Tensor): Spatial shape of
each feature map, has shape (num_levels, 2),
last dimension 2 represent (h, w)
sampling_locations (Tensor): The location of sampling points,
has shape
(bs ,num_queries, num_heads, num_levels, num_points, 2),
the last dimension 2 represent (x, y).
attention_weights (Tensor): The weight of sampling points used
when calculating the attention, has shape
(bs, num_queries, num_heads, num_levels, num_points).
Returns:
Tensor: has shape (bs, num_queries, embed_dims)
"""
bs, _, num_heads, embed_dims = value.shape
_, num_queries, num_heads, num_levels, num_points, _ =\
sampling_locations.shape
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes],
dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level, (H_, W_) in enumerate(value_spatial_shapes):
# bs, H_*W_, num_heads, embed_dims ->
# bs, H_*W_, num_heads*embed_dims ->
# bs, num_heads*embed_dims, H_*W_ ->
# bs*num_heads, embed_dims, H_, W_
value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(
bs * num_heads, embed_dims, H_, W_)
# bs, num_queries, num_heads, num_points, 2 ->
# bs, num_heads, num_queries, num_points, 2 ->
# bs*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :,
level].transpose(1, 2).flatten(0, 1)
# bs*num_heads, embed_dims, num_queries, num_points
sampling_value_l_ = F.grid_sample(
value_l_,
sampling_grid_l_,
mode='bilinear',
padding_mode='zeros',
align_corners=False)
sampling_value_list.append(sampling_value_l_)
# (bs, num_queries, num_heads, num_levels, num_points) ->
# (bs, num_heads, num_queries, num_levels, num_points) ->
# (bs, num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(
bs * num_heads, 1, num_queries, num_levels * num_points)
output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) *
attention_weights).sum(-1).view(bs, num_heads * embed_dims,
num_queries)
return output.transpose(1, 2).contiguous()
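# Shape sketch (toy numbers, purely illustrative): with bs=2, one level of
# spatial shape (8, 8), num_heads=4, 16 dims per head, num_queries=10 and
# num_points=4:
#   value               (2, 64, 4, 16)
#   sampling_locations  (2, 10, 4, 1, 4, 2)   normalized to [0, 1]
#   attention_weights   (2, 10, 4, 1, 4)
# the function returns a tensor of shape (2, 10, 64).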
@ATTENTION.register_module()
class MultiScaleDeformableAttentionFP32(BaseModule):
"""An attention module used in Deformable-Detr. `Deformable DETR:
Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
num_heads (int): Parallel attention heads. Default: 8.
num_levels (int): The number of feature map used in
Attention. Default: 4.
num_points (int): The number of sampling points for
each query in each head. Default: 4.
im2col_step (int): The step used in image_to_column.
Default: 64.
dropout (float): A Dropout layer on `inp_identity`.
Default: 0.1.
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default to False.
norm_cfg (dict): Config dict for normalization layer.
Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims=256,
num_heads=8,
num_levels=4,
num_points=4,
im2col_step=64,
dropout=0.1,
batch_first=False,
norm_cfg=None,
init_cfg=None):
super().__init__(init_cfg)
if embed_dims % num_heads != 0:
raise ValueError(f'embed_dims must be divisible by num_heads, '
f'but got {embed_dims} and {num_heads}')
dim_per_head = embed_dims // num_heads
self.norm_cfg = norm_cfg
self.dropout = nn.Dropout(dropout)
self.batch_first = batch_first
# you'd better set dim_per_head to a power of 2
# which is more efficient in the CUDA implementation
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError(
'invalid input for _is_power_of_2: {} (type: {})'.format(
n, type(n)))
return (n & (n - 1) == 0) and n != 0
if not _is_power_of_2(dim_per_head):
warnings.warn(
"You'd better set embed_dims in "
'MultiScaleDeformAttention to make '
'the dimension of each attention head a power of 2 '
'which is more efficient in our CUDA implementation.')
self.im2col_step = im2col_step
self.embed_dims = embed_dims
self.num_levels = num_levels
self.num_heads = num_heads
self.num_points = num_points
self.sampling_offsets = nn.Linear(
embed_dims, num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dims,
num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dims, embed_dims)
self.output_proj = nn.Linear(embed_dims, embed_dims)
self.init_weights()
def init_weights(self):
"""Default initialization for Parameters of Module."""
constant_init(self.sampling_offsets, 0.)
thetas = torch.arange(
self.num_heads,
dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init /
grid_init.abs().max(-1, keepdim=True)[0]).view(
self.num_heads, 1, 1,
2).repeat(1, self.num_levels, self.num_points, 1)
for i in range(self.num_points):
grid_init[:, :, i, :] *= i + 1
self.sampling_offsets.bias.data = grid_init.view(-1)
constant_init(self.attention_weights, val=0., bias=0.)
xavier_init(self.value_proj, distribution='uniform', bias=0.)
xavier_init(self.output_proj, distribution='uniform', bias=0.)
self._is_init = True
@deprecated_api_warning({'residual': 'identity'},
cls_name='MultiScaleDeformableAttention')
def forward(self,
query,
key=None,
value=None,
identity=None,
query_pos=None,
key_padding_mask=None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
**kwargs):
"""Forward Function of MultiScaleDeformAttention.
Args:
query (Tensor): Query of Transformer with shape
(num_query, bs, embed_dims).
key (Tensor): The key tensor with shape
`(num_key, bs, embed_dims)`.
value (Tensor): The value tensor with shape
`(num_key, bs, embed_dims)`.
identity (Tensor): The tensor used for addition, with the
same shape as `query`. Default None. If None,
`query` will be used.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`. Default
None.
reference_points (Tensor): The normalized reference
points with shape (bs, num_query, num_levels, 2),
all elements are in the range [0, 1], top-left (0, 0),
bottom-right (1, 1), including the padding area;
or (N, Length_{query}, num_levels, 4), where the two
additional dimensions (w, h) form reference boxes.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_key].
spatial_shapes (Tensor): Spatial shape of features in
different levels. With shape (num_levels, 2),
last dimension represents (h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape ``(num_levels, )`` and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
if value is None:
value = query
if identity is None:
identity = query
if query_pos is not None:
query = query + query_pos
if not self.batch_first:
# change to (bs, num_query ,embed_dims)
query = query.permute(1, 0, 2)
value = value.permute(1, 0, 2)
bs, num_query, _ = query.shape
bs, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
value = self.value_proj(value)
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], 0.0)
value = value.view(bs, num_value, self.num_heads, -1)
sampling_offsets = self.sampling_offsets(query).view(
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
attention_weights = self.attention_weights(query).view(
bs, num_query, self.num_heads, self.num_levels * self.num_points)
attention_weights = attention_weights.softmax(-1)
attention_weights = attention_weights.view(bs, num_query,
self.num_heads,
self.num_levels,
self.num_points)
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack(
[spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets \
/ offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.num_points \
* reference_points[:, :, None, :, None, 2:] \
* 0.5
else:
raise ValueError(
f'Last dim of reference_points must be'
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
if torch.cuda.is_available():
output = MultiScaleDeformableAttnFunctionFp32.apply(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
else:
# the pure-PyTorch fallback takes neither level_start_index nor
# im2col_step; it derives per-level splits from spatial_shapes
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, sampling_locations, attention_weights)
output = self.output_proj(output)
if not self.batch_first:
# (num_query, bs ,embed_dims)
output = output.permute(1, 0, 2)
return self.dropout(output) + identity
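# Minimal usage sketch (an assumption-laden example, not from the original
# repo); shapes follow the non-batch_first convention documented above, and
# all tensors (and the module) should live on the same device:
#   attn = MultiScaleDeformableAttentionFP32(embed_dims=256, num_heads=8,
#                                            num_levels=1, num_points=4)
#   query = torch.randn(100, 2, 256)            # (num_query, bs, embed_dims)
#   value = torch.randn(32 * 32, 2, 256)        # one 32x32 level, flattened
#   spatial_shapes = torch.tensor([[32, 32]])
#   level_start_index = torch.tensor([0])
#   reference_points = torch.rand(2, 100, 1, 2)
#   out = attn(query, value=value, reference_points=reference_points,
#              spatial_shapes=spatial_shapes,
#              level_start_index=level_start_index)  # -> (100, 2, 256)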
\ No newline at end of file