Commit f3b13cad authored by yeshenglong1

Update README.md

parent 0797920d
import copy
import torch
import torch.nn as nn
from mmcv.cnn import Linear, bias_init_with_prob, build_activation_layer
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmcv.runner import force_fp32
from mmdet.models import HEADS, build_head, build_loss
from mmdet.models.utils import build_transformer
from mmdet.models.utils.transformer import inverse_sigmoid
from .base_map_head import BaseMapHead
import numpy as np
from ..augmentation.sythesis_det import NoiseSythesis


@HEADS.register_module(force=True)
class DGHead(BaseMapHead):

    def __init__(self,
                 det_net_cfg=dict(),
                 gen_net_cfg=dict(),
                 loss_vert=dict(),
                 loss_face=dict(),
                 max_num_vertices=90,
                 top_p_gen_model=0.9,
                 sync_cls_avg_factor=True,
                 augmentation=False,
                 augmentation_kwargs=None,
                 joint_training=False,
                 **kwargs):
        super().__init__()
        # Heads
        self.det_net = build_head(det_net_cfg)
        self.gen_net = build_head(gen_net_cfg)
        self.coord_dim = self.gen_net.coord_dim
        # Loss params
        self.bg_cls_weight = 1.0
        self.sync_cls_avg_factor = sync_cls_avg_factor
        self.max_num_vertices = max_num_vertices
        self.top_p_gen_model = top_p_gen_model
        self.fp16_enabled = False
        self.augmentation = None
        if augmentation:
            augmentation_kwargs.update({'canvas_size': gen_net_cfg.canvas_size})
            self.augmentation = NoiseSythesis(**augmentation_kwargs)
        self.joint_training = joint_training

    def forward(self, batch, img_metas=None, **kwargs):
        '''
        Args:
        Returns:
            outs (Dict):
        '''
        if self.training:
            return self.forward_train(batch, **kwargs)
        else:
            return self.inference(batch, **kwargs)

    def forward_train(self, batch: dict, context: dict, only_det=False, **kwargs):
        '''We use a teacher-forcing strategy.'''
        bbox_dict = self.det_net(context=context)
        outs = dict(
            bbox=bbox_dict,
        )
        losses_dict, det_match_idxs, det_match_gt_idxs = \
            self.loss_det(batch, outs)
        if only_det:
            return outs, losses_dict
        if self.augmentation is not None:
            polylines, bbox_flat = \
                self.augmentation(batch['gen'], simple_aug=True)
            if bbox_flat is None:
                bbox_flat = batch['gen']['bbox_flat']
            gen_input = dict(
                lines_bs_idx=batch['gen']['lines_bs_idx'],
                lines_cls=batch['gen']['lines_cls'],
                bbox_flat=bbox_flat,
                polylines=polylines,
                polyline_masks=batch['gen']['polyline_masks'],
            )
        else:
            gen_input = batch['gen']
        if self.joint_training:
            # for downstream polyline
            if 'lines' in bbox_dict[-1]:
                # for fixed anchor
                pred_bbox = bbox_dict[-1]['lines'].detach()
            elif 'bboxs' in bbox_dict[-1]:
                # for rpv
                pred_bbox = bbox_dict[-1]['bboxs'].detach()
            else:
                raise NotImplementedError
            # change back to the original gt order
            det_match_idx = det_match_idxs[-1]
            det_match_gt_idx = det_match_gt_idxs[-1]
            _bboxs = []
            for i, (match_idx, bbox) in enumerate(zip(det_match_idx, pred_bbox)):
                _bboxs.append(bbox[match_idx])
                _bboxs[-1] = _bboxs[-1][torch.argsort(det_match_gt_idx[i])]
            _bboxs = torch.cat(_bboxs, dim=0)
            # quantize the data
            _bboxs = \
                torch.round(_bboxs).type(torch.int32)
            # gen_input['bbox_flat'] = _bboxs
            remain_idx = torch.randperm(_bboxs.shape[0])[:int(_bboxs.shape[0] * 0.2)]
            # for data efficiency
            for k in gen_input.keys():
                if k == 'bbox_flat':
                    gen_input[k] = torch.cat((_bboxs, gen_input[k][remain_idx]), dim=0)
                else:
                    gen_input[k] = torch.cat((gen_input[k], gen_input[k][remain_idx]), dim=0)
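            # Note (added for clarity): under joint training the generator is
            # conditioned on the detector's matched, re-quantized boxes, while
            # remain_idx also replays roughly 20% of the ground-truth-conditioned
            # rows; every tensor in gen_input (and, below, in batch['gen'])
            # grows by the same appended rows, so indices stay aligned.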
        if isinstance(context['bev_embeddings'], tuple):
            context['bev_embeddings'] = context['bev_embeddings'][0]
        poly_dict = self.gen_net(gen_input, context=context)
        outs.update(dict(
            polylines=poly_dict,
        ))
        if self.joint_training:
            for k in batch['gen'].keys():
                batch['gen'][k] = \
                    torch.cat((batch['gen'][k], batch['gen'][k][remain_idx]), dim=0)
        gen_losses_dict = \
            self.loss_gen(batch, outs)
        losses_dict.update(gen_losses_dict)
        return outs, losses_dict

    def loss_det(self, gt: dict, pred: dict):
        loss_dict = {}
        # det
        det_loss_dict, det_match_idx, det_match_gt_idx = \
            self.det_net.loss(gt['det'], pred['bbox'])
        for k, v in det_loss_dict.items():
            loss_dict['det_' + k] = v
        return loss_dict, det_match_idx, det_match_gt_idx

    def loss_gen(self, gt: dict, pred: dict):
        loss_dict = {}
        # gen
        gen_loss_dict = self.gen_net.loss(gt['gen'], pred['polylines'])
        for k, v in gen_loss_dict.items():
            loss_dict['gen_' + k] = v
        return loss_dict

    def loss(self, gt: dict, pred: dict):
        pass

    @torch.no_grad()
    def inference(self, batch: dict = {}, context: dict = {}, gt_condition=False, **kwargs):
        '''
        num_samples_batch: number of samples per batch (batch size)
        '''
        outs = {}
        bbox_dict = self.det_net(context=context)
        bbox_dict = self.det_net.post_process(bbox_dict)
        outs.update(bbox_dict)
        if len(outs['lines_bs_idx']) == 0:
            return None
        if isinstance(context['bev_embeddings'], tuple):
            context['bev_embeddings'] = context['bev_embeddings'][0]
        poly_dict = self.gen_net(outs,
                                 context=context,
                                 # max_sample_length=self.max_num_vertices,
                                 max_sample_length=64,
                                 top_p=self.top_p_gen_model,
                                 gt_condition=gt_condition)
        outs.update(poly_dict)
        return outs

    def post_process(self, preds: dict, tokens, gts: dict = None, **kwargs):
        '''
        Args:
            XXX
        Outs:
            XXX
        '''
        range_size = self.gen_net.canvas_size.cpu().numpy()
        coord_dim = self.gen_net.coord_dim
        gen_net_name = self.gen_net.name if hasattr(self.gen_net, 'name') else 'gen'
        ret_list = []
        for batch_idx in range(len(tokens)):
            ret_dict_single = {}
            # bbox
            det_gt = None
            if gts is not None:
                det_gt, rec_groundtruth = pack_groundtruth(
                    batch_idx, gts, tokens, range_size, gen_net_name,
                    coord_dim=coord_dim)
            bbox_res = {
                # 'bboxes': preds['bbox'][batch_idx].detach().cpu().numpy(),
                # 'det_gt': det_gt,
                'token': tokens[batch_idx],
                'scores': preds['scores'][batch_idx].detach().cpu().numpy(),
                'labels': preds['labels'][batch_idx].detach().cpu().numpy(),
            }
            ret_dict_single.update(bbox_res)
            # for gen results
            batch2seq = np.nonzero(
                preds['lines_bs_idx'].cpu().numpy() == batch_idx)[0]
            ret_dict_single.update({
                'nline': len(batch2seq),
                'vectors': []
            })
            for i in batch2seq:
                pre = preds['polylines'][i].detach().cpu().numpy()
                pre_msk = preds['polyline_masks'][i].detach().cpu().numpy()
                valid_idx = np.nonzero(pre_msk)[0][:-1]
                # From [200, 1] to [199, 0] to (1, 0)
                line = (pre[valid_idx].reshape(-1, coord_dim) - 1) / (range_size - 1)
                ret_dict_single['vectors'].append(line)
            # if gts is not None:
            #     ret_dict_single['groundTruth'] = rec_groundtruth
            ret_list.append(ret_dict_single)
        return ret_list


def pack_groundtruth(batch_idx, gts, tokens, range_size, gen_net_name='gen', coord_dim=2):
    if 'keypoints' in gts['det']:
        gt_bbox = \
            gts['det']['keypoints'][batch_idx].detach().cpu().numpy()
    else:
        gt_bbox = \
            gts['det']['bbox'][batch_idx].detach().cpu().numpy()
    det_gt = {
        'labels': gts['det']['class_label'][batch_idx].detach().cpu().numpy(),
        'bboxes': gt_bbox,
    }
    batch2seq = np.nonzero(
        gts['gen']['lines_bs_idx'].cpu().numpy() == batch_idx)[0]
    ret_groundtruth = {
        'token': tokens[batch_idx],
        'nline': len(batch2seq),
        'labels': gts['gen']['lines_cls'][batch2seq].detach().cpu().numpy(),
        'lines': [],
    }
    for i in batch2seq:
        gt_line = \
            gts['gen']['polylines'].detach().cpu().numpy()[i]
        gt_msk = gts['gen']['polyline_masks'].detach().cpu().numpy()[i]
        if gen_net_name == 'gen_gmm':
            valid_idx = np.nonzero(gt_msk)[0]
        else:
            valid_idx = np.nonzero(gt_msk)[0][:-1]
        # From [200, 1] to [199, 0] to (1, 0)
        line = (gt_line[valid_idx].reshape(-1, coord_dim) - 1) / (range_size - 1)
        ret_groundtruth['lines'].append(line)
    return det_gt, ret_groundtruth
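

# Taken together, DGHead is a two-stage pipeline: det_net proposes map-element
# boxes/keypoints from BEV features, and gen_net autoregressively decodes a
# polyline per kept detection. A minimal usage sketch (illustrative only:
# `head` is assumed to be a built DGHead and `batch` a collated training
# sample; the tensor shape below is made up):
#
#     context = dict(bev_embeddings=torch.randn(2, 128, 50, 25))  # (B, C, H, W)
#     head.train()
#     outs, losses = head(batch, context=context)   # -> forward_train
#     head.eval()
#     results = head(batch={}, context=context)     # -> inference; returns None
#                                                   #    if no line is kept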


import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, Linear
from mmcv.runner import force_fp32
from torch.distributions.categorical import Categorical
from mmdet.core import (multi_apply, build_assigner, build_sampler,
                        reduce_mean)
from mmdet.models import HEADS
from .detr_bbox import DETRBboxHead
from mmdet.models.utils.transformer import inverse_sigmoid
from mmdet.models import build_loss
from mmcv.cnn import Linear, build_activation_layer, bias_init_with_prob
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmdet.models.utils import build_transformer


@HEADS.register_module(force=True)
class MapElementDetector(nn.Module):

    def __init__(self,
                 canvas_size=(400, 200),
                 discrete_output=False,
                 separate_detect=False,
                 mode='xyxy',
                 bbox_size=None,
                 coord_dim=2,
                 kp_coord_dim=2,
                 num_classes=3,
                 in_channels=128,
                 num_query=100,
                 max_lines=50,
                 score_thre=0.2,
                 num_reg_fcs=2,
                 num_points=100,
                 iterative=False,
                 patch_size=None,
                 sync_cls_avg_factor=True,
                 transformer: dict = None,
                 positional_encoding: dict = None,
                 loss_cls: dict = None,
                 loss_reg: dict = None,
                 train_cfg: dict = None):
        super().__init__()
        assigner = train_cfg['assigner']
        self.assigner = build_assigner(assigner)
        # DETR sampling=False, so use PseudoSampler
        sampler_cfg = dict(type='PseudoSampler')
        self.sampler = build_sampler(sampler_cfg, context=self)
        self.train_cfg = train_cfg
        self.max_lines = max_lines
        self.score_thre = score_thre
        self.num_query = num_query
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_points = num_points
        # branch
        # if loss_cls.use_sigmoid:
        if loss_cls['use_sigmoid']:
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1
        self.iterative = iterative
        self.num_reg_fcs = num_reg_fcs
        self._build_transformer(transformer, positional_encoding)
        # loss params
        self.loss_cls = build_loss(loss_cls)
        self.bg_cls_weight = 0.1
        if self.loss_cls.use_sigmoid:
            self.bg_cls_weight = 0.0
        self.sync_cls_avg_factor = sync_cls_avg_factor
        self.reg_loss = build_loss(loss_reg)
        self.separate_detect = separate_detect
        self.discrete_output = discrete_output
        self.bbox_size = 3 if mode == 'sce' else 2
        if bbox_size is not None:
            self.bbox_size = bbox_size
        self.coord_dim = coord_dim  # for xyz
        self.kp_coord_dim = kp_coord_dim
        self.register_buffer('canvas_size', torch.tensor(canvas_size))
        # add reg and cls heads for each decoder layer
        self._init_layers()
        self._init_branch()
        self.init_weights()
        self._init_embedding()

    def _init_layers(self):
        """Initialize some layers."""
        self.input_proj = Conv2d(
            self.in_channels, self.embed_dims, kernel_size=1)
        # query_pos_embed & query_embed
        self.query_embedding = nn.Embedding(self.num_query,
                                            self.embed_dims)

    def _init_embedding(self):
        self.label_embed = nn.Embedding(
            self.num_classes, self.embed_dims)
        self.img_coord_embed = nn.Linear(2, self.embed_dims)
        # query_pos_embed & query_embed
        self.query_embedding = nn.Embedding(self.num_query,
                                            self.embed_dims * 2)
        # for bbox parameters xstart, ystart, xend, yend
        self.bbox_embedding = nn.Embedding(self.bbox_size,
                                           self.embed_dims * 2)

    def _init_branch(self):
        """Initialize classification branch and regression branch of head."""
        fc_cls = Linear(self.embed_dims * self.bbox_size, self.cls_out_channels)
        # fc_cls = Linear(self.embed_dims, self.cls_out_channels)
        reg_branch = []
        for _ in range(self.num_reg_fcs):
            reg_branch.append(Linear(self.embed_dims, self.embed_dims))
            reg_branch.append(nn.LayerNorm(self.embed_dims))
            reg_branch.append(nn.ReLU())
        if self.discrete_output:
            reg_branch.append(nn.Linear(
                self.embed_dims, max(self.canvas_size), bias=True))
        else:
            reg_branch.append(nn.Linear(
                self.embed_dims, self.coord_dim, bias=True))
        reg_branch = nn.Sequential(*reg_branch)
        # add sigmoid or not

        def _get_clones(module, N):
            return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

        num_pred = self.transformer.decoder.num_layers
        if self.iterative:
            fc_cls = _get_clones(fc_cls, num_pred)
            reg_branch = _get_clones(reg_branch, num_pred)
        else:
            reg_branch = nn.ModuleList(
                [reg_branch for _ in range(num_pred)])
            fc_cls = nn.ModuleList(
                [fc_cls for _ in range(num_pred)])
        self.pre_branches = nn.ModuleDict([
            ('cls', fc_cls),
            ('reg', reg_branch),
        ])

    def init_weights(self):
        """Initialize weights of the DeformDETR head."""
        for p in self.input_proj.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        self.transformer.init_weights()
        # init prediction branches
        for k, v in self.pre_branches.items():
            for param in v.parameters():
                if param.dim() > 1:
                    nn.init.xavier_uniform_(param)
        # focal loss init
        if self.loss_cls.use_sigmoid:
            bias_init = bias_init_with_prob(0.01)
            # for the last layer
            if isinstance(self.pre_branches['cls'], nn.ModuleList):
                for m in self.pre_branches['cls']:
                    nn.init.constant_(m.bias, bias_init)
            else:
                m = self.pre_branches['cls']
                nn.init.constant_(m.bias, bias_init)

    def _build_transformer(self, transformer, positional_encoding):
        # transformer
        self.act_cfg = transformer.get('act_cfg',
                                       dict(type='ReLU', inplace=True))
        self.activate = build_activation_layer(self.act_cfg)
        self.positional_encoding = build_positional_encoding(
            positional_encoding)
        self.transformer = build_transformer(transformer)
        self.embed_dims = self.transformer.embed_dims

    def _prepare_context(self, context):
        """Prepare class label and vertex context."""
        global_context_embedding = None
        image_embeddings = context['bev_embeddings']
        image_embeddings = self.input_proj(
            image_embeddings)  # only changes the feature size
        # Pass images through the encoder
        device = image_embeddings.device
        # Add a 2D coordinate-grid embedding
        B, C, H, W = image_embeddings.shape
        Ws = torch.linspace(-1., 1., W)
        Hs = torch.linspace(-1., 1., H)
        image_coords = torch.stack(
            torch.meshgrid(Hs, Ws), dim=-1).to(device)
        image_coord_embeddings = self.img_coord_embed(image_coords)
        image_embeddings += image_coord_embeddings[None].permute(0, 3, 1, 2)
        # Keep the spatial (B, C, H, W) layout for the transformer input
        sequential_context_embeddings = image_embeddings.reshape(
            B, C, H, W)
        return (global_context_embedding, sequential_context_embeddings)
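    # Shape walk-through of the grid embedding above (illustrative numbers):
    # with H=2, W=3, torch.meshgrid(Hs, Ws) yields two (2, 3) grids, stacked
    # into image_coords of shape (2, 3, 2); img_coord_embed maps the last dim
    # to C, and permute(0, 3, 1, 2) gives (1, C, 2, 3), which broadcast-adds
    # onto the (B, C, H, W) feature map.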

    def forward(self, context, img_metas=None, multi_scale=False):
        '''
        Args:
            bev_feature (List[Tensor]): shape [B, C, H, W]
                feature in bev view
            img_metas
        Outs:
            preds_dict (Dict):
                all_cls_scores (Tensor): Classification score of all
                    decoder layers, has shape
                    [nb_dec, bs, num_query, cls_out_channels].
                all_lines_preds (Tensor):
                    [nb_dec, bs, num_query, num_points, 2].
        '''
        (global_context_embedding, sequential_context_embeddings) = \
            self._prepare_context(context)
        x = sequential_context_embeddings
        B, C, H, W = x.shape
        query_embedding = self.query_embedding.weight[None, :, None].repeat(
            B, 1, self.bbox_size, 1)
        bbox_embed = self.bbox_embedding.weight
        query_embedding = query_embedding + bbox_embed[None, None]
        query_embedding = query_embedding.view(B, -1, C * 2)
        img_masks = x.new_zeros((B, H, W))
        pos_embed = self.positional_encoding(img_masks)
        # outs_dec: [nb_dec, bs, num_query, embed_dim]
        hs, init_reference, inter_references = self.transformer(
            [x, ],
            [img_masks.type(torch.bool)],
            query_embedding,
            [pos_embed],
            reg_branches=self.reg_branches if self.iterative else None,  # noqa:E501
            cls_branches=None,  # noqa:E501
        )
        outs_dec = hs.permute(0, 2, 1, 3)
        outputs = []
        for i, query_feat in enumerate(outs_dec):
            if i == 0:
                reference = init_reference
            else:
                reference = inter_references[i - 1]
            outputs.append(self.get_prediction(i, query_feat, reference))
        return outputs

    def get_prediction(self, level, query_feat, reference):
        bs, num_query, h = query_feat.shape
        query_feat = query_feat.view(bs, -1, self.bbox_size, h)
        ocls = self.pre_branches['cls'][level](query_feat.flatten(-2))
        # ocls = ocls.mean(-2)
        reference = inverse_sigmoid(reference)
        reference = reference.view(bs, -1, self.bbox_size, self.coord_dim)
        tmp = self.pre_branches['reg'][level](query_feat)
        tmp[..., :self.kp_coord_dim] = \
            tmp[..., :self.kp_coord_dim] + reference[..., :self.kp_coord_dim]
        lines = tmp.sigmoid()  # bs, num_query, self.bbox_size, 2
        lines = lines * self.canvas_size[:self.coord_dim]
        lines = lines.flatten(-2)
        return dict(
            lines=lines,  # [bs, num_query, bbox_size*2]
            scores=ocls,  # [bs, num_query, num_class]
            embeddings=query_feat,  # [bs, num_query, bbox_size, h]
        )
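    # Note (added for clarity): each bbox corner is its own decoder token, so
    # the transformer runs with num_query * bbox_size queries; get_prediction
    # folds the sequence back to (bs, num_query, bbox_size, h) and classifies
    # each box from its flattened corner features.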

    @force_fp32(apply_to=('score_pred', 'lines_pred', 'gt_lines'))
    def _get_target_single(self,
                           score_pred,
                           lines_pred,
                           gt_labels,
                           gt_lines,
                           gt_bboxes_ignore=None):
        """
        Compute regression and classification targets for one image.
        Outputs from a single decoder layer of a single feature level are used.
        Args:
            score_pred (Tensor): Box score logits from a single decoder layer
                for one image. Shape [num_query, cls_out_channels].
            lines_pred (Tensor):
                shape [num_query, num_points, 2].
            gt_lines (Tensor):
                shape [num_gt, num_points, 2].
            gt_labels (torch.LongTensor):
                shape [num_gt, ].
        Returns:
            tuple[Tensor]: a tuple containing the following for one image.
                - labels (LongTensor): Labels of each image.
                    shape [num_query, 1]
                - label_weights (Tensor): Label weights of each image.
                    shape [num_query, 1]
                - lines_target (Tensor): Lines targets of each image.
                    shape [num_query, num_points, 2]
                - lines_weights (Tensor): Lines weights of each image.
                    shape [num_query, num_points, 2]
                - pos_inds (Tensor): Sampled positive indices for each image.
                - neg_inds (Tensor): Sampled negative indices for each image.
        """
        num_pred_lines = len(lines_pred)
        # assigner and sampler
        assign_result = self.assigner.assign(preds=dict(lines=lines_pred,
                                                        scores=score_pred),
                                             gts=dict(lines=gt_lines,
                                                      labels=gt_labels),
                                             gt_bboxes_ignore=gt_bboxes_ignore)
        sampling_result = self.sampler.sample(
            assign_result, lines_pred, gt_lines)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        pos_gt_inds = sampling_result.pos_assigned_gt_inds
        # label targets; 0: foreground, 1: background
        if self.separate_detect:
            labels = gt_lines.new_full((num_pred_lines, ), 1, dtype=torch.long)
        else:
            labels = gt_lines.new_full(
                (num_pred_lines, ), self.num_classes, dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_lines.new_ones(num_pred_lines)
        # bbox targets: lines_pred's last dimension is the vocabulary,
        # and the ground truth does not have this dimension.
        if self.discrete_output:
            lines_target = torch.zeros_like(lines_pred[..., 0]).long()
            lines_weights = torch.zeros_like(lines_pred[..., 0])
        else:
            lines_target = torch.zeros_like(lines_pred)
            lines_weights = torch.zeros_like(lines_pred)
        lines_target[pos_inds] = sampling_result.pos_gt_bboxes.type(
            lines_target.dtype)
        lines_weights[pos_inds] = 1.0
        n = lines_weights.sum(-1, keepdim=True)
        lines_weights = lines_weights / n.masked_fill(n == 0, 1)
        return (labels, label_weights, lines_target, lines_weights,
                pos_inds, neg_inds, pos_gt_inds)

    # @force_fp32(apply_to=('preds', 'gts'))
    def get_targets(self, preds, gts, gt_bboxes_ignore_list=None):
        """
        Compute regression and classification targets for a batch of images.
        Outputs from a single decoder layer of a single feature level are used.
        Args:
            cls_scores_list (list[Tensor]): Box score logits from a single
                decoder layer for each image with shape [num_query,
                cls_out_channels].
            lines_preds_list (list[Tensor]): [num_query, num_points, 2].
            gt_lines_list (list[Tensor]): Ground truth lines for each image
                with shape (num_gts, num_points, 2).
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (num_gts, ).
            gt_bboxes_ignore_list (list[Tensor], optional): Bounding
                boxes which can be ignored for each image. Default None.
        Returns:
            tuple: a tuple containing the following targets.
                - labels_list (list[Tensor]): Labels for all images.
                - label_weights_list (list[Tensor]): Label weights for all
                  images.
                - lines_targets_list (list[Tensor]): Lines targets for all
                  images.
                - lines_weights_list (list[Tensor]): Lines weights for all
                  images.
                - num_total_pos (int): Number of positive samples in all
                  images.
                - num_total_neg (int): Number of negative samples in all
                  images.
        """
        assert gt_bboxes_ignore_list is None, \
            'Only supports gt_bboxes_ignore set to None.'
        # format the inputs
        if self.separate_detect:
            bbox = [b[m] for b, m in zip(gts['bbox'], gts['bbox_mask'])]
            class_label = torch.zeros_like(gts['bbox_mask']).long()
            class_label = [b[m] for b, m in zip(class_label, gts['bbox_mask'])]
        else:
            class_label = gts['class_label']
            bbox = gts['bbox']
        if self.discrete_output:
            lines_pred = preds['lines'].logits
        else:
            lines_pred = preds['lines']
        bbox = [b.float() for b in bbox]
        (labels_list, label_weights_list,
         lines_targets_list, lines_weights_list,
         pos_inds_list, neg_inds_list, pos_gt_inds_list) = multi_apply(
             self._get_target_single,
             preds['scores'], lines_pred,
             class_label, bbox,
             gt_bboxes_ignore=gt_bboxes_ignore_list)
        num_total_pos = sum(inds.numel() for inds in pos_inds_list)
        num_total_neg = sum(inds.numel() for inds in neg_inds_list)
        new_gts = dict(
            labels=labels_list,
            label_weights=label_weights_list,
            bboxs=lines_targets_list,
            bboxs_weights=lines_weights_list,
        )
        return new_gts, num_total_pos, num_total_neg, pos_inds_list, pos_gt_inds_list

    # @force_fp32(apply_to=('preds', 'gts'))
    def loss_single(self,
                    preds: dict,
                    gts: dict,
                    gt_bboxes_ignore_list=None,
                    reduction='none'):
        """
        Loss function for outputs from a single decoder layer of a single
        feature level.
        Args:
            cls_scores (Tensor): Box score logits from a single decoder layer
                for all images. Shape [bs, num_query, cls_out_channels].
            lines_preds (Tensor):
                shape [bs, num_query, num_points, 2].
            gt_lines_list (list[Tensor]):
                with shape (num_gts, num_points, 2).
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (num_gts, ).
            gt_bboxes_ignore_list (list[Tensor], optional): Bounding
                boxes which can be ignored for each image. Default None.
        Returns:
            dict[str, Tensor]: A dictionary of loss components for outputs from
                a single decoder layer.
        """
        # Get targets for each sample
        new_gts, num_total_pos, num_total_neg, pos_inds_list, pos_gt_inds_list = \
            self.get_targets(preds, gts, gt_bboxes_ignore_list)
        # Batch all the data
        for k, v in new_gts.items():
            new_gts[k] = torch.stack(v, dim=0)
        # construct weighted avg_factor to match the official DETR repo
        cls_avg_factor = num_total_pos * 1.0 + \
            num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(
                preds['scores'].new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)
        # Classification loss
        if self.separate_detect:
            loss_cls = self.bce_loss(
                preds['scores'], new_gts['labels'], new_gts['label_weights'],
                cls_avg_factor)
        else:
            # the loss expects the class dim in the second position, so
            # flatten the predictions accordingly.
            cls_scores = preds['scores'].reshape(-1, self.cls_out_channels)
            cls_labels = new_gts['labels'].reshape(-1)
            cls_weights = new_gts['label_weights'].reshape(-1)
            loss_cls = self.loss_cls(
                cls_scores, cls_labels, cls_weights, avg_factor=cls_avg_factor)
        # Compute the average number of gt boxes across all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()
        # position NLL loss
        if self.discrete_output:
            loss_reg = -(preds['lines'].log_prob(new_gts['bboxs']) *
                         new_gts['bboxs_weights']).sum() / num_total_pos
        else:
            loss_reg = self.reg_loss(
                preds['lines'], new_gts['bboxs'], new_gts['bboxs_weights'],
                avg_factor=num_total_pos)
        loss_dict = dict(
            cls=loss_cls,
            reg=loss_reg,
        )
        return loss_dict, pos_inds_list, pos_gt_inds_list

    @force_fp32(apply_to=('gt_lines_list', 'preds_dicts'))
    def loss(self,
             gts: dict,
             preds_dicts: dict,
             gt_bboxes_ignore=None,
             reduction='mean'):
        """
        Loss function.
        Args:
            gt_lines_list (list[Tensor]): Ground truth lines for each image
                with shape (num_gts, num_points, 2).
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (num_gts, ).
            preds_dicts:
                all_cls_scores (Tensor): Classification score of all
                    decoder layers, has shape
                    [nb_dec, bs, num_query, cls_out_channels].
                all_lines_preds (Tensor):
                    [nb_dec, bs, num_query, num_points, 2].
            gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
                which can be ignored for each image. Default None.
        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert gt_bboxes_ignore is None, \
            f'{self.__class__.__name__} only supports ' \
            f'gt_bboxes_ignore set to None.'
        # There may be outputs from multiple decoder layers.
        losses, pos_inds_lists, pos_gt_inds_lists = multi_apply(
            self.loss_single,
            preds_dicts,
            gts=gts,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            reduction=reduction)
        # Format the losses
        loss_dict = dict()
        # loss from the last decoder layer
        for k, v in losses[-1].items():
            loss_dict[k] = v
        # losses from the other decoder layers
        num_dec_layer = 0
        for loss in losses[:-1]:
            for k, v in loss.items():
                loss_dict[f'd{num_dec_layer}.{k}'] = v
            num_dec_layer += 1
        return loss_dict, pos_inds_lists, pos_gt_inds_lists

    def post_process(self, preds_dicts: list, **kwargs):
        '''
        Args:
            preds_dicts:
                scores (Tensor): Classification score of all
                    decoder layers, has shape
                    [nb_dec, bs, num_query, cls_out_channels].
                lines (Tensor):
                    [nb_dec, bs, num_query, bbox parameters(4)].
        Outs:
            ret_list (List[Dict]) with length as bs:
                list of result dicts for each sample in the batch
                XXX
        '''
        preds = preds_dicts[-1]
        batched_cls_scores = preds['scores']
        batched_lines_preds = preds['lines']
        batch_size = batched_cls_scores.size(0)
        device = batched_cls_scores.device
        result_dict = {
            'bbox': [],
            'scores': [],
            'labels': [],
            'bbox_flat': [],
            'lines_cls': [],
            'lines_bs_idx': [],
        }
        for i in range(batch_size):
            cls_scores = batched_cls_scores[i]
            det_preds = batched_lines_preds[i]
            max_num = self.max_lines
            if self.loss_cls.use_sigmoid:
                cls_scores = cls_scores.sigmoid()
                scores, valid_idx = cls_scores.view(-1).topk(max_num)
                det_labels = valid_idx % self.num_classes
                valid_idx = valid_idx // self.num_classes
                det_preds = det_preds[valid_idx]
            else:
                scores, det_labels = F.softmax(cls_scores, dim=-1)[..., :-1].max(-1)
                scores, valid_idx = scores.topk(max_num)
                det_preds = det_preds[valid_idx]
                det_labels = det_labels[valid_idx]
            nline = len(valid_idx)
            result_dict['bbox'].append(det_preds)
            result_dict['scores'].append(scores)
            result_dict['labels'].append(det_labels)
            result_dict['lines_bs_idx'].extend([i] * nline)
        # for downstream polyline
        _bboxs = torch.cat(result_dict['bbox'], dim=0)
        # quantize the data
        result_dict['bbox_flat'] = torch.round(_bboxs).type(torch.int32)
        result_dict['lines_cls'] = torch.cat(
            result_dict['labels'], dim=0).long()
        result_dict['lines_bs_idx'] = torch.tensor(
            result_dict['lines_bs_idx'], device=device).long()
        return result_dict
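

# The sigmoid branch above selects boxes by flattening the
# (num_query, num_classes) score grid; a small worked example of that index
# arithmetic (values made up):
#
#     scores = torch.tensor([[0.1, 0.9, 0.2],    # query 0
#                            [0.8, 0.3, 0.4]])   # query 1
#     top, flat_idx = scores.view(-1).topk(2)    # top=[0.9, 0.8], flat_idx=[1, 3]
#     labels = flat_idx % 3                      # -> [1, 0] (class within query)
#     query_idx = flat_idx // 3                  # -> [0, 1] (owning query)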


import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from mmdet.models import HEADS
from .detgen_utils.causal_trans import (CausalTransformerDecoder,
                                        CausalTransformerDecoderLayer)
from .detgen_utils.utils import (dequantize_verts, generate_square_subsequent_mask,
                                 quantize_verts, top_k_logits, top_p_logits)
from mmcv.runner import force_fp32, auto_fp16


@HEADS.register_module(force=True)
class PolylineGenerator(nn.Module):
    """
    Autoregressive generative model of n-gon meshes.
    Operates on sets of input vertices as well as flattened face sequences with
    new-face and stopping tokens:
        [f_0^0, f_0^1, f_0^2, NEW, f_1^0, f_1^1, ..., STOP]
    Input vertices are encoded using a Transformer encoder.
    Input face sequences are embedded and tagged with learned position
    indicators, as well as their corresponding vertex embeddings. A transformer
    decoder outputs a pointer which is compared to each vertex embedding to
    obtain a distribution over vertex indices.
    """
def __init__(self, def __init__(self,
in_channels, in_channels,
encoder_config, encoder_config,
decoder_config, decoder_config,
class_conditional=True, class_conditional=True,
num_classes=55, num_classes=55,
decoder_cross_attention=True, decoder_cross_attention=True,
use_discrete_vertex_embeddings=True, use_discrete_vertex_embeddings=True,
condition_points_num=3, condition_points_num=3,
coord_dim=2, coord_dim=2,
canvas_size=(400, 200), canvas_size=(400, 200),
max_seq_length=500, max_seq_length=500,
name='gen_model'): name='gen_model'):
"""Initializes FaceModel. """Initializes FaceModel.
Args: Args:
encoder_config: Dictionary with TransformerEncoder config. encoder_config: Dictionary with TransformerEncoder config.
decoder_config: Dictionary with TransformerDecoder config. decoder_config: Dictionary with TransformerDecoder config.
class_conditional: If True, then condition on learned class embeddings. class_conditional: If True, then condition on learned class embeddings.
num_classes: Number of classes to condition on. num_classes: Number of classes to condition on.
decoder_cross_attention: If True, the use cross attention from decoder decoder_cross_attention: If True, the use cross attention from decoder
querys into encoder outputs. querys into encoder outputs.
use_discrete_vertex_embeddings: If True, use discrete vertex embeddings. use_discrete_vertex_embeddings: If True, use discrete vertex embeddings.
max_seq_length: Maximum face sequence length. Used for learned position max_seq_length: Maximum face sequence length. Used for learned position
embeddings. embeddings.
name: Name of variable scope name: Name of variable scope
""" """
super(PolylineGenerator, self).__init__() super(PolylineGenerator, self).__init__()
self.embedding_dim = decoder_config['layer_config']['d_model'] self.embedding_dim = decoder_config['layer_config']['d_model']
self.class_conditional = class_conditional self.class_conditional = class_conditional
self.num_classes = num_classes self.num_classes = num_classes
self.max_seq_length = max_seq_length self.max_seq_length = max_seq_length
self.decoder_cross_attention = decoder_cross_attention self.decoder_cross_attention = decoder_cross_attention
self.use_discrete_vertex_embeddings = use_discrete_vertex_embeddings self.use_discrete_vertex_embeddings = use_discrete_vertex_embeddings
self.condition_points_num = condition_points_num self.condition_points_num = condition_points_num
self.fp16_enabled = False self.fp16_enabled = False
self.coord_dim = coord_dim # if we use xyz else 2 when we use xy self.coord_dim = coord_dim # if we use xyz else 2 when we use xy
self.kp_coord_dim = coord_dim if coord_dim==2 else 2 # XXX self.kp_coord_dim = coord_dim if coord_dim==2 else 2 # XXX
self.register_buffer('canvas_size', torch.tensor(canvas_size)) self.register_buffer('canvas_size', torch.tensor(canvas_size))
# initialize the model # initialize the model
self._project_to_logits = nn.Linear( self._project_to_logits = nn.Linear(
self.embedding_dim, self.embedding_dim,
max(canvas_size) + 1, # + 1 for stopping token. use_bias=True, max(canvas_size) + 1, # + 1 for stopping token. use_bias=True,
) )
self.input_proj = nn.Conv2d( self.input_proj = nn.Conv2d(
in_channels, self.embedding_dim, kernel_size=1) in_channels, self.embedding_dim, kernel_size=1)
decoder_layer = CausalTransformerDecoderLayer( decoder_layer = CausalTransformerDecoderLayer(
**decoder_config.pop('layer_config')) **decoder_config.pop('layer_config'))
self.decoder = CausalTransformerDecoder( self.decoder = CausalTransformerDecoder(
decoder_layer, **decoder_config) decoder_layer, **decoder_config)
self._init_embedding() self._init_embedding()
self.init_weights() self.init_weights()
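    # A minimal construction sketch (illustrative only; the real configs live
    # in the project's config files, and the CausalTransformerDecoderLayer
    # kwargs may differ). It shows the nesting __init__ expects: 'layer_config'
    # is popped from decoder_config and the rest is forwarded to the decoder.
    #
    #   gen_net = PolylineGenerator(
    #       in_channels=128,
    #       encoder_config=dict(),
    #       decoder_config=dict(
    #           layer_config=dict(d_model=256, nhead=8),
    #           num_layers=6),
    #       coord_dim=2,
    #       canvas_size=(400, 200))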
    def _init_embedding(self):
        if self.class_conditional:
            self.label_embed = nn.Embedding(
                self.num_classes, self.embedding_dim)
        self.coord_embed = nn.Embedding(self.coord_dim, self.embedding_dim)
        self.pos_embeddings = nn.Embedding(
            self.max_seq_length, self.embedding_dim)
        # indicates the role of each condition position, e.g. whether it is
        # the start or the end point of the line
        self.bbox_context_embed = \
            nn.Embedding(self.condition_points_num, self.embedding_dim)
        self.img_coord_embed = nn.Linear(2, self.embedding_dim)
        # initialize the vertex embedding
        if self.use_discrete_vertex_embeddings:
            self.vertex_embed = nn.Embedding(
                max(self.canvas_size) + 1, self.embedding_dim)
        else:
            self.vertex_embed = nn.Linear(1, self.embedding_dim)
    def init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _embed_kps(self, bbox):
        bbox_len = bbox.shape[-1]
        # Bbox context
        bbox_embedding = self.bbox_context_embed(
            (torch.arange(bbox_len, device=bbox.device) / self.kp_coord_dim).floor().long())
        # Coord indicators (x, y)
        coord_embeddings = self.coord_embed(
            torch.arange(bbox_len, device=bbox.device) % self.kp_coord_dim)
        # Discrete vertex value embeddings
        vert_embeddings = self.vertex_embed(bbox)
        return vert_embeddings + (bbox_embedding + coord_embeddings)[None]
    def _prepare_context(self, batch, context):
        """Prepare class label and vertex context."""
        global_context_embedding = None
        if self.class_conditional:
            global_context_embedding = self.label_embed(batch['lines_cls'])
        bbox_embeddings = self._embed_kps(batch['bbox_flat'])
        if global_context_embedding is not None:
            global_context_embedding = torch.cat(
                [global_context_embedding[:, None], bbox_embeddings], dim=1)
        # Pass images through encoder
        image_embeddings = assign_bev(
            context['bev_embeddings'], batch['lines_bs_idx'])
        image_embeddings = self.input_proj(image_embeddings)
        device = image_embeddings.device
        # Add 2D coordinate grid embedding
        H, W = image_embeddings.shape[2:]
        Ws = torch.linspace(-1., 1., W)
        Hs = torch.linspace(-1., 1., H)
        image_coords = torch.stack(
            torch.meshgrid(Hs, Ws), dim=-1).to(device)
        image_coord_embeddings = self.img_coord_embed(image_coords)
        image_embeddings += image_coord_embeddings[None].permute(0, 3, 1, 2)
        # Reshape spatial grid to sequence
        B = image_embeddings.shape[0]
        sequential_context_embeddings = image_embeddings.reshape(
            B, self.embedding_dim, -1).permute(0, 2, 1)
        return (global_context_embedding,
                sequential_context_embeddings)
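    # Shape walkthrough for _prepare_context (assuming BEV features of spatial
    # size H x W): input_proj maps [B, C, H, W] -> [B, D, H, W]; a learned
    # embedding of the normalized (h, w) coordinate grid is added per location;
    # the grid is then flattened to a sequence of shape [B, H*W, D] that the
    # decoder cross-attends into.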
    def _embed_inputs(self, seqs, condition_embedding=None):
        """Embeds vertex sequences and adds positional and coordinate embeddings.
        Args:
            seqs: B, seqlen=vlen*coord_dim
            condition_embedding: B, [c,xs,ys,xe,ye](5), h
        Returns:
            embeddings: B, condition_len+seqlen, h
        """
        B, seq_len = seqs.shape[:2]
        # Position embeddings
        pos_embeddings = self.pos_embeddings(
            (torch.arange(seq_len, device=seqs.device) / self.coord_dim).floor().long())  # seq_len, h
        # Coord indicators (x, y, z (optional))
        coord_embeddings = self.coord_embed(
            torch.arange(seq_len, device=seqs.device) % self.coord_dim)
        # Discrete vertex value embeddings
        vert_embeddings = self.vertex_embed(seqs)
        # Aggregate embeddings
        embeddings = vert_embeddings + \
            (coord_embeddings + pos_embeddings)[None]
        embeddings = torch.cat([condition_embedding, embeddings], dim=1)
        return embeddings
    def forward(self, batch: dict, **kwargs):
        """
        Pass the batch through the polyline model and get log probabilities.
        Args:
            batch: Dictionary containing:
                'polylines': int tensor of shape [batch_size, seq_length] with
                    flattened, quantized vertex sequences.
                'polyline_masks': tensor of shape [batch_size, seq_length]
                    that masks padded elements in 'polylines'.
                'lines_cls', 'bbox_flat', 'lines_bs_idx': conditioning inputs,
                    see _prepare_context.
        """
        if self.training:
            return self.forward_train(batch, **kwargs)
        else:
            return self.inference(batch, **kwargs)
    def sperate_forward(self, batch, context, **kwargs):
        # Split the batch into a short-sequence chunk and a long-sequence
        # chunk so that each chunk is padded only to its own maximum length.
        polyline_length = batch['polyline_masks'].sum(-1)
        c1, c2, revert_idx, size = get_chunk_idx(polyline_length)
        sizes = [size, polyline_length.max()]
        polyline_logits = []
        for c_idx, size in zip([c1, c2], sizes):
            new_batch = assign_batch(batch, c_idx, size)
            _poly_logits = self._forward_train(new_batch, context, **kwargs)
            polyline_logits.append(_poly_logits)
        # pad the short chunk back to the full length, then restore the
        # original sample order (this could be sped up further)
        for i, (_poly_logits, size) in enumerate(zip(polyline_logits, sizes)):
            if size < sizes[1]:
                _poly_logits = F.pad(_poly_logits, (0, 0, 0, sizes[1] - size), "constant", 0)
                polyline_logits[i] = _poly_logits
        polyline_logits = torch.cat(polyline_logits, 0)
        polyline_logits = polyline_logits[revert_idx]
        cat_dist = Categorical(logits=polyline_logits)
        return {'polylines': cat_dist}
    def forward_train(self, batch: dict, context: dict, **kwargs):
        """
        Returns:
            pred_dist: Categorical predictive distribution with batch shape
                [batch_size, seq_length].
        """
        # We use the gt vertices (teacher forcing). sperate_forward chunks the
        # batch by sequence length instead of running _forward_train on the
        # fully padded batch.
        return self.sperate_forward(batch, context, **kwargs)
    def _forward_train(self, batch: dict, context: dict, **kwargs):
        """
        Returns:
            pred_dist: Categorical predictive distribution with batch shape
                [batch_size, seq_length].
        """
        # we use the gt vertices
        global_context, seq_context = self._prepare_context(
            batch, context)
        logits = self.body(
            # Last element not used for preds
            batch['polylines'][:, :-1],
            global_context_embedding=global_context,
            sequential_context_embeddings=seq_context,
            return_logits=True,
            is_training=self.training)
        return logits
    @force_fp32(apply_to=('global_context_embedding', 'sequential_context_embeddings', 'cache'))
    def body(self,
             seqs,
             global_context_embedding=None,
             sequential_context_embeddings=None,
             temperature=1.,
             top_k=0,
             top_p=1.,
             cache=None,
             return_logits=False,
             is_training=True):
        """
        Outputs a categorical distribution over vertex indices.
        Body of the polyline model.
        """
        # Embed inputs
        condition_len = global_context_embedding.shape[1]
        decoder_inputs = self._embed_inputs(
            seqs, global_context_embedding)
        # Pass through Transformer decoder.
        # Our memory-efficient decoder only supports the sequence-first layout.
        decoder_inputs = decoder_inputs.transpose(0, 1)
        if sequential_context_embeddings is not None:
            sequential_context_embeddings = sequential_context_embeddings.transpose(
                0, 1)
        causal_msk = None
        if is_training:
            causal_msk = generate_square_subsequent_mask(
                decoder_inputs.shape[0], condition_len=condition_len, device=decoder_inputs.device)
        decoder_outputs, cache = self.decoder(
            tgt=decoder_inputs,
            cache=cache,
            memory=sequential_context_embeddings,
            causal_mask=causal_msk,
        )
        decoder_outputs = decoder_outputs.transpose(0, 1)
        # keep only the positions that predict the output sequence
        decoder_outputs = decoder_outputs[:, condition_len - 1:]
        # Get logits and optionally process for sampling
        logits = self._project_to_logits(decoder_outputs)
        # y mask: y positions may only emit values within canvas_size[1],
        # and never the stopping token (index 0)
        _vert_mask = torch.arange(logits.shape[-1], device=logits.device)
        vertices_mask_y = (_vert_mask < self.canvas_size[1] + 1)
        vertices_mask_y[0] = False
        logits[:, 1::self.coord_dim] = logits[:, 1::self.coord_dim] * \
            vertices_mask_y - ~vertices_mask_y * 1e9
        if self.coord_dim > 2:
            # z mask: likewise, z positions never emit the stopping token
            _vert_mask = torch.arange(logits.shape[-1], device=logits.device)
            vertices_mask_z = (_vert_mask < self.canvas_size[2] + 1)
            vertices_mask_z[0] = False
            logits[:, 2::self.coord_dim] = logits[:, 2::self.coord_dim] * \
                vertices_mask_z - ~vertices_mask_z * 1e9
        logits = logits / temperature
        logits = top_k_logits(logits, top_k)
        logits = top_p_logits(logits, top_p)
        if return_logits:
            return logits
        cat_dist = Categorical(logits=logits)
        return cat_dist, cache
    @force_fp32(apply_to=('pred', ))
    def loss(self, gt: dict, pred: dict):
        weight = gt['polyline_weights']
        mask = gt['polyline_masks']
        # masked, weighted negative log-likelihood over the token sequence
        loss = -torch.sum(
            pred['polylines'].log_prob(gt['polylines']) * mask * weight) / weight.sum()
        return {'seq': loss}
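    # A minimal sketch of the loss on toy tensors (illustrative only):
    #   dist = Categorical(logits=torch.randn(1, 4, 401))   # B=1, seq=4
    #   gt = {'polylines': torch.tensor([[7, 3, 0, 0]]),
    #         'polyline_masks': torch.tensor([[1., 1., 1., 0.]]),
    #         'polyline_weights': torch.tensor([[1., 1., 1., 0.]])}
    #   loss = -(dist.log_prob(gt['polylines']) * gt['polyline_masks']
    #            * gt['polyline_weights']).sum() / gt['polyline_weights'].sum()
    # Padded positions contribute nothing; the stop token (0) is supervised
    # wherever its mask and weight are nonzero.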
    def inference(self,
                  batch: dict,
                  context: dict,
                  max_sample_length=None,
                  temperature=1.,
                  top_k=0,
                  top_p=1.,
                  only_return_complete=False,
                  gt_condition=False,
                  **kwargs):
        """Sample from the polyline model using caching.
        Args:
            batch, context: conditioning inputs; see _prepare_context for
                details.
            max_sample_length: Maximum length of sampled vertex sequences.
                Sequences that do not complete are truncated.
            temperature: Scalar softmax temperature > 0.
            top_k: Number of tokens to keep for top-k sampling.
            top_p: Proportion of probability mass to keep for top-p sampling.
            only_return_complete: If True, only return completed samples.
                Otherwise return all samples along with a completed indicator.
        Returns:
            outputs: Output dictionary with fields:
                'completed': Boolean tensor of shape [num_samples]. True if
                    the corresponding sample completed within
                    max_sample_length.
                'polylines': Tensor of sampled token sequences, zero-padded.
                'polyline_masks': Boolean mask of the valid positions in each
                    sample.
        """
        # prepare the conditional variables
        global_context, seq_context = self._prepare_context(
            batch, context)
        device = global_context.device
        batch_size = global_context.shape[0]
        # Loop sampling with caching
        samples = torch.empty(
            [batch_size, 0], dtype=torch.int32, device=device)
        max_sample_length = max_sample_length or self.max_seq_length
        seq_len = max_sample_length * self.coord_dim + 1
        cache = None
        decoded_tokens = \
            torch.zeros((batch_size, seq_len),
                        device=device, dtype=torch.long)
        remain_idx = torch.arange(batch_size, device=device)
        for i in range(seq_len):
            # Loop body for the autoregressive calculation.
            pred_dist, cache = self.body(
                samples,
                global_context_embedding=global_context,
                sequential_context_embeddings=seq_context,
                cache=cache,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                is_training=False)
            samples = pred_dist.sample()
            decoded_tokens[remain_idx, i] = samples[:, -1]
            # Stopping condition: break once every sample has emitted a 0.
            if not (decoded_tokens[:, :i + 1] != 0).all(-1).any():
                break
            # drop finished samples: a newly emitted 0 is the stopping token
            valid_idx = (samples[:, -1] != 0).nonzero(as_tuple=True)[0]
            remain_idx = remain_idx[valid_idx]
            cache = cache[:, :, valid_idx]
            global_context = global_context[valid_idx]
            seq_context = seq_context[valid_idx]
            samples = samples[valid_idx]
        decoded_tokens = decoded_tokens[:, :i + 1]
        outputs = self.post_process(decoded_tokens, seq_len,
                                    device, only_return_complete)
        return outputs
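    # Sampling bookkeeping sketch (illustrative): with batch_size=3, if sample
    # 1 emits 0 at step i, remain_idx goes [0, 1, 2] -> [0, 2]; subsequent
    # steps run only on the two unfinished samples, while decoded_tokens keeps
    # writing into the original row positions via remain_idx.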
    def post_process(self, polyline,
                     max_seq_len=None,
                     device=None,
                     only_return_complete=True):
        '''
        format the predictions and compute the validity mask
        '''
        # Record completed samples (those that emitted a stopping token)
        complete_samples = (polyline == 0).any(-1)
        # Find the valid sequence length
        sample_seq_length = polyline.shape[-1]
        _polyline_mask = torch.arange(sample_seq_length)[None].to(device)
        # Incomplete samples keep the full length.
        valid_polyline_len = torch.full_like(polyline[:, 0], sample_seq_length)
        zero_inds = (polyline == 0).type(torch.int32).argmax(-1)
        # Real length for completed samples (stopping token included)
        valid_polyline_len[complete_samples] = zero_inds[complete_samples] + 1
        polyline_mask = _polyline_mask < valid_polyline_len[:, None]
        # Mask tokens beyond the stopping token with zeros
        polyline = polyline * polyline_mask
        # Pad to maximum size with zeros
        pad_size = max_seq_len - sample_seq_length
        polyline = F.pad(polyline, (0, pad_size))
        # XXX: filtering by `only_return_complete` is currently disabled;
        # all samples are returned together with the 'completed' indicator.
        outputs = {
            'completed': complete_samples,
            'polylines': polyline,
            'polyline_masks': polyline_mask,
        }
        return outputs
def find_best_sperate_plan(idx, array):
    # roughly the padding saved by cutting the sorted lengths after position
    # `idx`: the sequences before the cut no longer need to be padded up to
    # the global maximum (array[-1]) but only to array[idx]
    h = array[-1] - array[idx]
    w = idx
    cost = h * w
    return cost

def get_chunk_idx(polyline_length):
    # sort by length and pick the split point that maximizes the padding saved
    _polyline_length, polyline_length_idx = torch.sort(polyline_length)
    costs = []
    for i in range(len(_polyline_length)):
        cost = find_best_sperate_plan(i, _polyline_length)
        costs.append(cost)
    seperate_point = torch.stack(costs).argmax()
    chunk1 = polyline_length_idx[:seperate_point + 1]
    chunk2 = polyline_length_idx[seperate_point + 1:]
    revert_idx = torch.argsort(polyline_length_idx)
    return chunk1, chunk2, revert_idx, _polyline_length[seperate_point]
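# Usage sketch (illustrative): for lengths [3, 40, 5, 38] the sorted lengths
# are [3, 5, 38, 40] and the saved-padding scores are [0, 35, 4, 0], so the
# split lands after the short sequences: chunk1 holds the indices of [3, 5]
# (padded to 5) and chunk2 those of [38, 40] (padded to 40); revert_idx
# restores the original sample order after the chunks are concatenated.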
def assign_bev(feat, idx):
    return feat[idx]

def assign_batch(batch, idx, size):
    new_batch = {}
    for k, v in batch.items():
        new_batch[k] = v[idx]
        if new_batch[k].ndim > 1:
            new_batch[k] = new_batch[k][:, :size]
    return new_batch
import torch
from torch import nn as nn
from torch.nn import functional as F
from mmdet.models.losses import l1_loss
from mmdet.models.losses.utils import weighted_loss
import mmcv
from mmdet.models.builder import LOSSES
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def smooth_l1_loss(pred, target, beta=1.0):
    """Smooth L1 loss.
    Args:
        pred (torch.Tensor): The prediction.
        target (torch.Tensor): The learning target of the prediction.
        beta (float, optional): The threshold in the piecewise function.
            Defaults to 1.0.
    Returns:
        torch.Tensor: Calculated loss
    """
    assert beta > 0
    if target.numel() == 0:
        return pred.sum() * 0
    assert pred.size() == target.size()
    diff = torch.abs(pred - target)
    loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
                       diff - 0.5 * beta)
    return loss
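# Quick numeric check (illustrative): with beta=1.0,
#   diff = 0.5  -> loss = 0.5 * 0.5**2 / 1.0 = 0.125   (quadratic region)
#   diff = 2.0  -> loss = 2.0 - 0.5 * 1.0   = 1.5      (linear region)
# so the loss is quadratic for |pred - target| < beta and linear beyond it.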
@LOSSES.register_module()
class LinesLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0, beta=0.5):
        """
        Smooth L1 loss over line coordinates.
        Args:
            reduction (str, optional): The method to reduce the loss.
                Options are "none", "mean" and "sum".
            loss_weight (float, optional): The weight of loss.
            beta (float, optional): The threshold between the quadratic and
                linear regions of the smooth L1 loss.
        """
        super(LinesLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.beta = beta

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.
        Args:
            pred (torch.Tensor): The prediction.
                shape: [bs, ...]
            target (torch.Tensor): The learning target of the prediction.
                shape: [bs, ...]
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None. It is useful when not all
                predictions are valid.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = smooth_l1_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor, beta=self.beta)
        return loss * self.loss_weight
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def bce(pred, label, class_weight=None):
    """
    Binary cross entropy with logits.
    pred: B, nquery, npts
    label: B, nquery, npts
    """
    if label.numel() == 0:
        return pred.sum() * 0
    assert pred.size() == label.size()
    loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')
    return loss

@LOSSES.register_module()
class MasksLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MasksLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.
        Args:
            pred/target (torch.Tensor): mask logits and binary labels of the
                same shape; weight, avg_factor and reduction_override follow
                the usual mmdet loss conventions.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = bce(pred, target, weight, reduction=reduction,
                   avg_factor=avg_factor)
        return loss * self.loss_weight
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def ce(pred, label, class_weight=None):
    """
    Cross entropy.
    pred: B*nquery, npts
    label: B*nquery,
    """
    if label.numel() == 0:
        return pred.sum() * 0
    loss = F.cross_entropy(
        pred, label, weight=class_weight, reduction='none')
    return loss

@LOSSES.register_module()
class LenLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(LenLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.
        Args:
            pred/target (torch.Tensor): length logits and integer length
                labels; weight, avg_factor and reduction_override follow the
                usual mmdet loss conventions.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = ce(pred, target, weight, reduction=reduction,
                  avg_factor=avg_factor)
        return loss * self.loss_weight
from abc import ABCMeta, abstractmethod

import torch.nn as nn
from mmcv.runner import auto_fp16
from mmcv.utils import print_log
from mmdet.utils import get_root_logger
from mmdet3d.models.builder import DETECTORS

MAPPERS = DETECTORS

class BaseMapper(nn.Module, metaclass=ABCMeta):
    """Base class for mappers."""

    def __init__(self):
        super(BaseMapper, self).__init__()
        self.fp16_enabled = False

    @property
    def with_neck(self):
        """bool: whether the detector has a neck"""
        return hasattr(self, 'neck') and self.neck is not None

    # TODO: these properties need to be carefully handled
    # for both single stage & two stage detectors
    @property
    def with_shared_head(self):
        """bool: whether the detector has a shared head in the RoI Head"""
        return hasattr(self, 'roi_head') and self.roi_head.with_shared_head

    @property
    def with_bbox(self):
        """bool: whether the detector has a bbox head"""
        return ((hasattr(self, 'roi_head') and self.roi_head.with_bbox)
                or (hasattr(self, 'bbox_head') and self.bbox_head is not None))

    @property
    def with_mask(self):
        """bool: whether the detector has a mask head"""
        return ((hasattr(self, 'roi_head') and self.roi_head.with_mask)
                or (hasattr(self, 'mask_head') and self.mask_head is not None))

    # @abstractmethod
    def extract_feat(self, imgs):
        """Extract features from images."""
        pass
    def forward_train(self, *args, **kwargs):
        pass

    # @abstractmethod
    def simple_test(self, img, img_metas, **kwargs):
        pass

    # @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Test function with test time augmentation."""
        pass

    def init_weights(self, pretrained=None):
        """Initialize the weights in the detector.
        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if pretrained is not None:
            logger = get_root_logger()
            print_log(f'load model from: {pretrained}', logger=logger)

    def forward_test(self, *args, **kwargs):
        """Dispatch test calls.
        Test-time augmentation (`aug_test`) is not wired up yet, so all test
        calls go through `simple_test`.
        """
        return self.simple_test(*args, **kwargs)
    # @auto_fp16(apply_to=('img', ))
    def forward(self, *args, return_loss=True, **kwargs):
        """Calls either :func:`forward_train` or :func:`forward_test` depending
        on whether ``return_loss`` is ``True``.
        Note this setting will change the expected inputs. When
        ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
        and List[dict]), and when ``return_loss=False``, img and img_meta
        should be double nested (i.e. List[Tensor], List[List[dict]]), with
        the outer list indicating test time augmentations.
        """
        if return_loss:
            return self.forward_train(*args, **kwargs)
        else:
            kwargs.pop('rescale', None)
            return self.forward_test(*args, **kwargs)
    def train_step(self, data_dict, optimizer):
        """The iteration step during training.
        This method defines an iteration step during training, except for the
        back propagation and optimizer updating, which are done in an optimizer
        hook. Note that in some complicated cases or models, the whole process
        including back propagation and optimizer updating is also defined in
        this method, such as GAN.
        Args:
            data_dict (dict): The output of dataloader.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                runner is passed to ``train_step()``. This argument is unused
                and reserved.
        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \
                ``num_samples``.
                - ``loss`` is a tensor for back propagation, which can be a \
                    weighted sum of multiple losses.
                - ``log_vars`` contains all the variables to be sent to the
                    logger.
                - ``num_samples`` indicates the batch size (when the model is \
                    DDP, it means the batch size on each GPU), which is used for \
                    averaging the logs.
        """
        loss, log_vars, num_samples = self(**data_dict)
        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=num_samples)
        return outputs

    def val_step(self, data, optimizer):
        """The iteration step during validation.
        This method shares the same signature as :func:`train_step`, but used
        during val epochs. Note that the evaluation after training epochs is
        not implemented with this method, but an evaluation hook.
        """
        loss, log_vars, num_samples = self(**data)
        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=num_samples)
        return outputs
    def show_result(self,
                    **kwargs):
        img = None
        return img
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torchvision.models.resnet import resnet18, resnet50

from mmdet3d.models.builder import (build_backbone, build_head,
                                    build_neck)
from .base_mapper import BaseMapper, MAPPERS
@MAPPERS.register_module()
class VectorMapNet(BaseMapper):

    def __init__(self,
                 backbone_cfg=dict(),
                 head_cfg=dict(
                     vert_net_cfg=dict(),
                     face_net_cfg=dict(),
                 ),
                 neck_input_channels=128,
                 neck_cfg=None,
                 with_auxiliary_head=False,
                 only_det=False,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 model_name=None, **kwargs):
        super(VectorMapNet, self).__init__()
        # Attributes
        self.model_name = model_name
        self.last_epoch = None
        self.only_det = only_det
        self.backbone = build_backbone(backbone_cfg)
        if neck_cfg is not None:
            self.neck_neck = build_backbone(neck_cfg.backbone)
            self.neck_neck.conv1 = nn.Conv2d(
                neck_input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.neck_project = build_neck(neck_cfg.neck)
            self.neck = self.multiscale_neck
        else:
            trunk = resnet18(pretrained=False, zero_init_residual=True)
            self.neck = nn.Sequential(
                nn.Conv2d(neck_input_channels, 64, kernel_size=(7, 7), stride=(
                    2, 2), padding=(3, 3), bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1,
                             dilation=1, ceil_mode=False),
                trunk.layer1,
                nn.Conv2d(64, 128, kernel_size=1, bias=False),
            )
        # BEV
        if hasattr(self.backbone, 'bev_w'):
            self.bev_w = self.backbone.bev_w
            self.bev_h = self.backbone.bev_h
        self.head = build_head(head_cfg)

    def multiscale_neck(self, bev_embedding):
        multi_feat = self.neck_neck(bev_embedding)
        multi_feat = self.neck_project(multi_feat)
        return multi_feat
    def forward_train(self, img, polys, points=None, img_metas=None, **kwargs):
        '''
        Args:
            img: torch.Tensor of shape [B, N, 3, H, W]
                N: number of cams
            polys: list[list[Tuple(lines, length, label)]]
                - lines: np.array of shape [num_points, 2].
                - length: int
                - label: int
                len(polys) = batch_size
                len(polys[_b]) = num of lines in sample _b
            img_metas:
                img_metas['lidar2img']: [B, N, 4, 4]
        Out:
            loss, log_vars, num_sample
        '''
        # prepare labels and images
        batch, img, img_metas, valid_idx, points = self.batch_data(
            polys, img, img_metas, img.device, points)
        # corner case: if no sample in the batch contains any vectors, fall
        # back to the batch cached from the previous step so training does
        # not fail
        if self.last_epoch is None:
            self.last_epoch = [batch, img, img_metas, valid_idx, points]
        if len(valid_idx) == 0:
            batch, img, img_metas, valid_idx, points = self.last_epoch
        else:
            del self.last_epoch
            self.last_epoch = [batch, img, img_metas, valid_idx, points]
        # Backbone
        _bev_feats = self.backbone(img, img_metas=img_metas, points=points)
        img_shape = \
            [_bev_feats.shape[2:] for i in range(_bev_feats.shape[0])]
        # Neck
        bev_feats = self.neck(_bev_feats)
        preds_dict, losses_dict = \
            self.head(batch,
                      context={
                          'bev_embeddings': bev_feats,
                          'batch_input_shape': _bev_feats.shape[2:],
                          'img_shape': img_shape,
                          'raw_bev_embeddings': _bev_feats},
                      only_det=self.only_det)
        # format outputs
        loss = 0
        for name, var in losses_dict.items():
            loss = loss + var
        # update the log
        log_vars = {k: v.item() for k, v in losses_dict.items()}
        log_vars.update({'total': loss.item()})
        num_sample = img.size(0)
        return loss, log_vars, num_sample
    @torch.no_grad()
    def forward_test(self, img, polys=None, points=None, img_metas=None, **kwargs):
        '''
        inference pipeline
        '''
        # prepare labels and images
        token = []
        for img_meta in img_metas:
            token.append(img_meta['token'])
        _bev_feats = self.backbone(img, img_metas, points=points)
        img_shape = [_bev_feats.shape[2:] for i in range(_bev_feats.shape[0])]
        # Neck
        bev_feats = self.neck(_bev_feats)
        context = {'bev_embeddings': bev_feats,
                   'batch_input_shape': _bev_feats.shape[2:],
                   'img_shape': img_shape,  # XXX
                   'raw_bev_embeddings': _bev_feats}
        preds_dict = self.head(batch={},
                               context=context,
                               condition_on_det=True,
                               gt_condition=False,
                               only_det=self.only_det)
        # hard-coded guard: the head returns None when nothing is detected
        if preds_dict is None:
            return [None]
        results_list = self.head.post_process(preds_dict, token, only_det=self.only_det)
        return results_list
    def batch_data(self, polys, imgs, img_metas, device, points=None):
        # filter out samples that contain no vectors
        valid_idx = [i for i in range(len(polys)) if len(polys[i])]
        imgs = imgs[valid_idx]
        img_metas = [img_metas[i] for i in valid_idx]
        polys = [polys[i] for i in valid_idx]
        if points is not None:
            points = [points[i] for i in valid_idx]
            points = self.batch_points(points)
        if len(valid_idx) == 0:
            return None, None, None, valid_idx, None
        batch = {}
        batch['det'] = format_det(polys, device)
        batch['gen'] = format_gen(polys, device)
        return batch, imgs, img_metas, valid_idx, points

    def batch_points(self, points):
        pad_points = pad_sequence(points, batch_first=True)
        points_mask = torch.zeros_like(pad_points[:, :, 0]).bool()
        for i in range(len(points)):
            valid_num = points[i].shape[0]
            points_mask[i][:valid_num] = True
        return (pad_points, points_mask)
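# Illustrative behaviour of batch_points (not part of the pipeline): for two
# point clouds with 3 and 5 points, pad_sequence gives pad_points of shape
# [2, 5, C], and points_mask becomes [[True, True, True, False, False],
# [True, True, True, True, True]], marking the real points in each sample.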
def format_det(polys, device):
    batch = {
        'class_label': [],
        'batch_idx': [],  # kept for interface compatibility; not populated here
        'bbox': [],
    }
    for batch_idx, poly in enumerate(polys):
        keypoint_label = torch.from_numpy(poly['det_label']).to(device)
        keypoint = torch.from_numpy(poly['keypoint']).to(device)
        batch['class_label'].append(keypoint_label)
        batch['bbox'].append(keypoint)
    return batch
def format_gen(polys, device):
    polylines, polyline_masks, polyline_weights = [], [], []
    bbox, line_cls, line_bs_idx = [], [], []
    for batch_idx, poly in enumerate(polys):
        # convert to cuda tensors
        for k in poly.keys():
            if isinstance(poly[k], np.ndarray):
                poly[k] = torch.from_numpy(poly[k]).to(device)
            else:
                poly[k] = [torch.from_numpy(v).to(device) for v in poly[k]]
        line_cls += poly['gen_label']
        line_bs_idx += [batch_idx] * len(poly['gen_label'])
        # conditioning inputs
        bbox += poly['qkeypoint']
        # outputs
        polylines += poly['polylines']
        polyline_masks += poly['polyline_masks']
        polyline_weights += poly['polyline_weights']
    batch = {}
    batch['lines_bs_idx'] = torch.tensor(
        line_bs_idx, dtype=torch.long, device=device)
    batch['lines_cls'] = torch.tensor(
        line_cls, dtype=torch.long, device=device)
    batch['bbox_flat'] = torch.stack(bbox, 0)
    # padding
    batch['polylines'] = pad_sequence(polylines, batch_first=True)
    batch['polyline_masks'] = pad_sequence(polyline_masks, batch_first=True)
    batch['polyline_weights'] = pad_sequence(polyline_weights, batch_first=True)
    return batch
from .deformable_transformer import DeformableDetrTransformer_, DeformableDetrTransformerDecoder_
from .base_transformer import PlaceHolderEncoder
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init, constant_init
from mmcv.cnn.bricks.registry import (ATTENTION,
                                      TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import (MultiScaleDeformableAttention,
                                         TransformerLayerSequence,
                                         build_transformer_layer_sequence)
from mmcv.runner.base_module import BaseModule
from mmdet.models.utils.builder import TRANSFORMER

@TRANSFORMER_LAYER_SEQUENCE.register_module()
class PlaceHolderEncoder(nn.Module):

    def __init__(self, *args, embed_dims=None, **kwargs):
        super(PlaceHolderEncoder, self).__init__()
        self.embed_dims = embed_dims

    def forward(self, *args, query=None, **kwargs):
        return query
# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings

import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_norm_layer, xavier_init
from mmcv.cnn.bricks.registry import (TRANSFORMER_LAYER,
                                      TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import (BaseTransformerLayer,
                                         TransformerLayerSequence,
                                         build_transformer_layer_sequence)
from mmcv.runner.base_module import BaseModule
from torch.nn.init import normal_

from mmdet.models.utils.builder import TRANSFORMER
from mmdet.models.utils.transformer import Transformer

try:
    from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
except ImportError:
    warnings.warn(
        '`MultiScaleDeformableAttention` in MMCV has been moved to '
        '`mmcv.ops.multi_scale_deform_attn`, please update your MMCV')
    from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention

from .fp16_dattn import MultiScaleDeformableAttentionFp16
def inverse_sigmoid(x, eps=1e-5): def inverse_sigmoid(x, eps=1e-5):
"""Inverse function of sigmoid. """Inverse function of sigmoid.
Args: Args:
x (Tensor): The tensor to do the x (Tensor): The tensor to do the
inverse. inverse.
eps (float): EPS avoid numerical eps (float): EPS avoid numerical
overflow. Defaults 1e-5. overflow. Defaults 1e-5.
Returns: Returns:
Tensor: The x has passed the inverse Tensor: The x has passed the inverse
function of sigmoid, has same function of sigmoid, has same
shape with input. shape with input.
""" """
x = x.clamp(min=0, max=1) x = x.clamp(min=0, max=1)
x1 = x.clamp(min=eps) x1 = x.clamp(min=eps)
x2 = (1 - x).clamp(min=eps) x2 = (1 - x).clamp(min=eps)
return torch.log(x1 / x2) return torch.log(x1 / x2)
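# Sanity check (illustrative, not part of the original file): inverse_sigmoid
# is the clamped logit, so sigmoid recovers the input away from 0 and 1:
#   x = torch.tensor([0.1, 0.5, 0.9])
#   assert torch.allclose(inverse_sigmoid(x).sigmoid(), x, atol=1e-4)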
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DeformableDetrTransformerDecoder_(TransformerLayerSequence):
    """Implements the decoder in DETR transformer.

    Args:
        return_intermediate (bool): Whether to return intermediate outputs.
        coder_norm_cfg (dict): Config of the last normalization layer.
            Default: `LN`.
    """

    def __init__(self, *args,
                 return_intermediate=False, coord_dim=2, kp_coord_dim=2, **kwargs):
        super(DeformableDetrTransformerDecoder_, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate
        self.coord_dim = coord_dim
        self.kp_coord_dim = kp_coord_dim

    def forward(self,
                query,
                *args,
                reference_points=None,
                valid_ratios=None,
                reg_branches=None,
                **kwargs):
        """Forward function for `TransformerDecoder`.

        Args:
            query (Tensor): Input query with shape
                `(num_query, bs, embed_dims)`.
            reference_points (Tensor): The reference points for the offsets,
                with shape (bs, num_query, 4) when as_two_stage is True,
                otherwise (bs, num_query, 2).
            valid_ratios (Tensor): The ratios of valid points on the
                feature map, with shape (bs, num_levels, 2).
            reg_branches (obj:`nn.ModuleList`): Used for refining the
                regression results. Only passed when with_box_refine is
                True, otherwise `None`.

        Returns:
            Tensor: Results with shape [1, num_query, bs, embed_dims] when
                return_intermediate is `False`, otherwise it has shape
                [num_layers, num_query, bs, embed_dims].
        """
        output = query
        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            reference_points_input = \
                reference_points[:, :, None, :self.kp_coord_dim] * \
                valid_ratios[:, None, :, :self.kp_coord_dim]
            output = layer(
                output,
                *args,
                reference_points=reference_points_input[..., :self.kp_coord_dim],
                **kwargs)
            output = output.permute(1, 0, 2)
            if reg_branches is not None:
                tmp = reg_branches[lid](output)
                new_reference_points = tmp
                # Refine only the first kp_coord_dim coordinates; slicing the
                # reference points here keeps the shapes consistent when they
                # carry an extra (e.g. height) dimension.
                new_reference_points[..., :self.kp_coord_dim] = tmp[
                    ..., :self.kp_coord_dim] + inverse_sigmoid(
                        reference_points[..., :self.kp_coord_dim])
                new_reference_points = new_reference_points.sigmoid()
                if reference_points.shape[-1] == 3 and self.kp_coord_dim == 2:
                    reference_points[..., -1] = tmp[..., -1].sigmoid().detach()
                reference_points[..., :self.coord_dim] = new_reference_points.detach()
            output = output.permute(1, 0, 2)
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)
        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)
        return output, reference_points
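# Each decoder layer above performs one step of iterative refinement: the
# layer's regression branch predicts a delta in logit space, which is added
# to inverse_sigmoid(reference_points) and squashed back through sigmoid, so
# the reference points stay in [0, 1] while moving a little per layer. The
# .detach() calls stop gradients from flowing across layers through the
# reference points, matching the Deformable DETR recipe.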
@TRANSFORMER.register_module()
class DeformableDetrTransformer_(Transformer):
    """Implements the DeformableDETR transformer.

    Args:
        as_two_stage (bool): Generate query from encoder features.
            Default: False.
        num_feature_levels (int): Number of feature maps from FPN.
            Default: 4.
        two_stage_num_proposals (int): Number of proposals when
            `as_two_stage` is True. Default: 300.
    """

    def __init__(self,
                 as_two_stage=False,
                 num_feature_levels=1,
                 two_stage_num_proposals=300,
                 coord_dim=2,
                 **kwargs):
        super(DeformableDetrTransformer_, self).__init__(**kwargs)
        self.as_two_stage = as_two_stage
        self.num_feature_levels = num_feature_levels
        self.two_stage_num_proposals = two_stage_num_proposals
        self.embed_dims = self.encoder.embed_dims
        self.coord_dim = coord_dim
        self.init_layers()

    def init_layers(self):
        """Initialize layers of the DeformableDetrTransformer."""
        self.level_embeds = nn.Parameter(
            torch.Tensor(self.num_feature_levels, self.embed_dims))
        if self.as_two_stage:
            self.enc_output = nn.Linear(self.embed_dims, self.embed_dims)
            self.enc_output_norm = nn.LayerNorm(self.embed_dims)
            self.pos_trans = nn.Linear(self.embed_dims * 2,
                                       self.embed_dims * 2)
            self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2)
        else:
            self.reference_points_embed = nn.Linear(self.embed_dims, self.coord_dim)

    def init_weights(self):
        """Initialize the transformer weights."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()
            elif isinstance(m, MultiScaleDeformableAttentionFp16):
                m.init_weights()
        if not self.as_two_stage:
            xavier_init(self.reference_points_embed, distribution='uniform', bias=0.)
        normal_(self.level_embeds)

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Get the reference points used in the decoder.

        Args:
            spatial_shapes (Tensor): The shape of all feature maps,
                has shape (num_level, 2).
            valid_ratios (Tensor): The ratios of valid points on the
                feature map, has shape (bs, num_levels, 2).
            device (obj:`device`): The device where reference_points
                should be.

        Returns:
            Tensor: reference points used in decoder, has \
                shape (bs, num_keys, num_levels, 2).
        """
        reference_points_list = []
        for lvl, (H, W) in enumerate(spatial_shapes):
            # TODO check this 0.5
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(
                    0.5, H - 0.5, H, dtype=torch.float32, device=device),
                torch.linspace(
                    0.5, W - 0.5, W, dtype=torch.float32, device=device))
            ref_y = ref_y.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 1] * H)
            ref_x = ref_x.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 0] * W)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points
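    # get_reference_points builds one normalized (x, y) point per pixel
    # center of every feature level (hence the 0.5 offsets), divided by the
    # valid ratios so padded regions map outside [0, 1]. As an illustration,
    # for a single 2x3 level with full valid ratios the x coordinates are
    # (0.5/3, 1.5/3, 2.5/3), repeated for each of the two rows.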
    def get_valid_ratio(self, mask):
        """Get the valid ratios of feature maps of all levels."""
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def get_proposal_pos_embed(self,
                               proposals,
                               num_pos_feats=128,
                               temperature=10000):
        """Get the position embedding of proposals."""
        scale = 2 * math.pi
        dim_t = torch.arange(
            num_pos_feats, dtype=torch.float32, device=proposals.device)
        dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats)
        # N, L, 4
        proposals = proposals.sigmoid() * scale
        # N, L, 4, 128
        pos = proposals[:, :, :, None] / dim_t
        # N, L, 4, 64, 2
        pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()),
                          dim=4).flatten(2)
        return pos
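    # This is the standard sinusoidal encoding applied per proposal
    # coordinate: each of the 4 box parameters is squashed to (0, 2*pi) and
    # expanded into num_pos_feats interleaved sin/cos channels, so with the
    # default num_pos_feats=128 the result has shape (N, L, 4 * 128).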
    def forward(self,
                mlvl_feats,
                mlvl_masks,
                query_embed,
                mlvl_pos_embeds,
                reg_branches=None,
                cls_branches=None,
                **kwargs):
        """Forward function for `Transformer`.

        Args:
            mlvl_feats (list(Tensor)): Input queries from different levels.
                Each element has shape [bs, embed_dims, h, w].
            mlvl_masks (list(Tensor)): The key_padding_mask from different
                levels used for encoder and decoder, each element has shape
                [bs, h, w].
            query_embed (Tensor): The query embedding for decoder,
                with shape [num_query, c].
            mlvl_pos_embeds (list(Tensor)): The positional encoding of feats
                from different levels, each with shape [bs, embed_dims, h, w].
            reg_branches (obj:`nn.ModuleList`): Regression heads for feature
                maps from each decoder layer. Only passed when
                `with_box_refine` is True. Default to None.
            cls_branches (obj:`nn.ModuleList`): Classification heads for
                feature maps from each decoder layer. Only passed when
                `as_two_stage` is True. Default to None.

        Returns:
            tuple[Tensor]: results of decoder containing the following tensors.

                - inter_states: Outputs from decoder. If
                    return_intermediate_dec is True output has shape \
                    (num_dec_layers, bs, num_query, embed_dims), else has \
                    shape (1, bs, num_query, embed_dims).
                - init_reference_out: The initial value of reference \
                    points, has shape (bs, num_queries, 4).
                - inter_references_out: The internal value of reference \
                    points in decoder, has shape \
                    (num_dec_layers, bs, num_query, embed_dims)
                - enc_outputs_class: The classification score of proposals \
                    generated from encoder's feature maps, has shape \
                    (batch, h*w, num_classes). Only returned when \
                    `as_two_stage` is True, otherwise None.
                - enc_outputs_coord_unact: The regression results generated \
                    from encoder's feature maps, has shape (batch, h*w, 4). \
                    Only returned when `as_two_stage` is True, \
                    otherwise None.
        """
        assert self.as_two_stage or query_embed is not None
        feat_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (feat, mask, pos_embed) in enumerate(
                zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
            bs, c, h, w = feat.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            feat = feat.flatten(2).transpose(1, 2)
            mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            feat_flatten.append(feat)
            mask_flatten.append(mask)
        feat_flatten = torch.cat(feat_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=feat_flatten.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack(
            [self.get_valid_ratio(m) for m in mlvl_masks], 1)
        # The encoder stage is bypassed: the flattened backbone features
        # (bs, H*W, embed_dims) are used directly as the decoder memory.
        memory = feat_flatten
        bs, _, c = memory.shape
        query_pos, query = torch.split(query_embed, c, dim=-1)
        reference_points = self.reference_points_embed(query_pos).sigmoid()
        init_reference_out = reference_points
        # decoder
        query = query.permute(1, 0, 2)
        memory = memory.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=memory,
            query_pos=query_pos,
            key_padding_mask=mask_flatten,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            reg_branches=reg_branches,
            **kwargs)
        inter_references_out = inter_references
        return inter_states, init_reference_out, inter_references_out
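# Shape walkthrough (illustrative sizes, not from the repo): with a single
# 32x88 feature level and embed_dims=256, feat_flatten is (bs, 2816, 256);
# query_embed of width 512 splits into query_pos and query of width 256
# each; and the decoder returns inter_states with shape
# (num_layers, num_query, bs, 256) when return_intermediate is enabled.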
import math
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.function import Function, once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd

from mmcv import deprecated_api_warning
from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.cnn.bricks.transformer import build_attention
from mmcv.runner import force_fp32, auto_fp16
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import ext_loader

try:
    from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
except ImportError:
    warnings.warn(
        '`MultiScaleDeformableAttention` in MMCV has been moved to '
        '`mmcv.ops.multi_scale_deform_attn`, please update your MMCV')
    from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention

ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
@ATTENTION.register_module()
class MultiScaleDeformableAttentionFp16(BaseModule):
    def __init__(self, attn_cfg=None, init_cfg=None, **kwarg):
        super(MultiScaleDeformableAttentionFp16, self).__init__(init_cfg)
        self.deformable_attention = build_attention(attn_cfg)
        self.deformable_attention.init_weights()
        self.fp16_enabled = False

    @force_fp32(apply_to=('query', 'key', 'value', 'query_pos',
                          'reference_points', 'identity'))
    def forward(self, query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
        return self.deformable_attention(query,
                                         key=key,
                                         value=value,
                                         identity=identity,
                                         query_pos=query_pos,
                                         key_padding_mask=key_padding_mask,
                                         reference_points=reference_points,
                                         spatial_shapes=spatial_shapes,
                                         level_start_index=level_start_index,
                                         **kwargs)
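# MultiScaleDeformableAttentionFp16 is a thin wrapper that makes deformable
# attention safe under mixed precision: @force_fp32 casts the listed tensor
# arguments back to float32 before delegating to the wrapped attention, since
# the deformable-attention CUDA kernel is numerically fragile in half
# precision.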
class MultiScaleDeformableAttnFunctionFp32(Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights, im2col_step):
        """GPU version of multi-scale deformable attention.

        Args:
            value (Tensor): The value has shape
                (bs, num_keys, num_heads, embed_dims//num_heads)
            value_spatial_shapes (Tensor): Spatial shape of
                each feature map, has shape (num_levels, 2),
                last dimension 2 represents (h, w)
            sampling_locations (Tensor): The location of sampling points,
                has shape
                (bs, num_queries, num_heads, num_levels, num_points, 2),
                the last dimension 2 represents (x, y).
            attention_weights (Tensor): The weight of sampling points used
                when calculating the attention, has shape
                (bs, num_queries, num_heads, num_levels, num_points).
            im2col_step (Tensor): The step used in image to column.

        Returns:
            Tensor: has shape (bs, num_queries, embed_dims)
        """
        ctx.im2col_step = im2col_step
        output = ext_module.ms_deform_attn_forward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            im2col_step=ctx.im2col_step)
        ctx.save_for_backward(value, value_spatial_shapes,
                              value_level_start_index, sampling_locations,
                              attention_weights)
        return output

    @staticmethod
    @once_differentiable
    @custom_bwd
    def backward(ctx, grad_output):
        """GPU version of the backward function.

        Args:
            grad_output (Tensor): Gradient of the output tensor of forward.

        Returns:
            Tuple[Tensor]: Gradients of the input tensors in forward.
        """
        value, value_spatial_shapes, value_level_start_index,\
            sampling_locations, attention_weights = ctx.saved_tensors
        grad_value = torch.zeros_like(value)
        grad_sampling_loc = torch.zeros_like(sampling_locations)
        grad_attn_weight = torch.zeros_like(attention_weights)
        ext_module.ms_deform_attn_backward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            grad_output.contiguous(),
            grad_value,
            grad_sampling_loc,
            grad_attn_weight,
            im2col_step=ctx.im2col_step)
        return grad_value, None, None, \
            grad_sampling_loc, grad_attn_weight, None
def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
                                        sampling_locations, attention_weights):
    """CPU version of multi-scale deformable attention.

    Args:
        value (Tensor): The value has shape
            (bs, num_keys, num_heads, embed_dims//num_heads)
        value_spatial_shapes (Tensor): Spatial shape of
            each feature map, has shape (num_levels, 2),
            last dimension 2 represents (h, w)
        sampling_locations (Tensor): The location of sampling points,
            has shape
            (bs, num_queries, num_heads, num_levels, num_points, 2),
            the last dimension 2 represents (x, y).
        attention_weights (Tensor): The weight of sampling points used
            when calculating the attention, has shape
            (bs, num_queries, num_heads, num_levels, num_points).

    Returns:
        Tensor: has shape (bs, num_queries, embed_dims)
    """
    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ =\
        sampling_locations.shape
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes],
                             dim=1)
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for level, (H_, W_) in enumerate(value_spatial_shapes):
        # bs, H_*W_, num_heads, embed_dims ->
        # bs, H_*W_, num_heads*embed_dims ->
        # bs, num_heads*embed_dims, H_*W_ ->
        # bs*num_heads, embed_dims, H_, W_
        value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(
            bs * num_heads, embed_dims, H_, W_)
        # bs, num_queries, num_heads, num_points, 2 ->
        # bs, num_heads, num_queries, num_points, 2 ->
        # bs*num_heads, num_queries, num_points, 2
        sampling_grid_l_ = sampling_grids[:, :, :,
                                          level].transpose(1, 2).flatten(0, 1)
        # bs*num_heads, embed_dims, num_queries, num_points
        sampling_value_l_ = F.grid_sample(
            value_l_,
            sampling_grid_l_,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False)
        sampling_value_list.append(sampling_value_l_)
    # (bs, num_queries, num_heads, num_levels, num_points) ->
    # (bs, num_heads, num_queries, num_levels, num_points) ->
    # (bs, num_heads, 1, num_queries, num_levels*num_points)
    attention_weights = attention_weights.transpose(1, 2).reshape(
        bs * num_heads, 1, num_queries, num_levels * num_points)
    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) *
              attention_weights).sum(-1).view(bs, num_heads * embed_dims,
                                              num_queries)
    return output.transpose(1, 2).contiguous()
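# The PyTorch fallback above reproduces the CUDA kernel with F.grid_sample:
# sampling locations in [0, 1] are mapped to grid_sample's [-1, 1] range via
# 2 * loc - 1, each level is sampled bilinearly, and the per-point samples
# are combined with the softmaxed attention weights. It is slower than the
# kernel but useful for CPU-only debugging and for checking the kernel's
# output numerically.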
@ATTENTION.register_module()
class MultiScaleDeformableAttentionFP32(BaseModule):
    """An attention module used in Deformable-DETR. `Deformable DETR:
    Deformable Transformers for End-to-End Object Detection.
    <https://arxiv.org/pdf/2010.04159.pdf>`_.

    Args:
        embed_dims (int): The embedding dimension of Attention.
            Default: 256.
        num_heads (int): Parallel attention heads. Default: 8.
        num_levels (int): The number of feature maps used in
            Attention. Default: 4.
        num_points (int): The number of sampling points for
            each query in each head. Default: 4.
        im2col_step (int): The step used in image_to_column.
            Default: 64.
        dropout (float): A Dropout layer on `inp_identity`.
            Default: 0.1.
        batch_first (bool): Key, Query and Value have shape
            (batch, n, embed_dim) or (n, batch, embed_dim).
            Default to False.
        norm_cfg (dict): Config dict for the normalization layer.
            Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=4,
                 im2col_step=64,
                 dropout=0.1,
                 batch_first=False,
                 norm_cfg=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        if embed_dims % num_heads != 0:
            raise ValueError(f'embed_dims must be divisible by num_heads, '
                             f'but got {embed_dims} and {num_heads}')
        dim_per_head = embed_dims // num_heads
        self.norm_cfg = norm_cfg
        self.dropout = nn.Dropout(dropout)
        self.batch_first = batch_first

        # you'd better set dim_per_head to a power of 2
        # which is more efficient in the CUDA implementation
        def _is_power_of_2(n):
            if (not isinstance(n, int)) or (n < 0):
                raise ValueError(
                    'invalid input for _is_power_of_2: {} (type: {})'.format(
                        n, type(n)))
            return (n & (n - 1) == 0) and n != 0

        if not _is_power_of_2(dim_per_head):
            warnings.warn(
                "You'd better set embed_dims in "
                'MultiScaleDeformAttention to make '
                'the dimension of each attention head a power of 2, '
                'which is more efficient in our CUDA implementation.')
        self.im2col_step = im2col_step
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_heads = num_heads
        self.num_points = num_points
        self.sampling_offsets = nn.Linear(
            embed_dims, num_heads * num_levels * num_points * 2)
        self.attention_weights = nn.Linear(embed_dims,
                                           num_heads * num_levels * num_points)
        self.value_proj = nn.Linear(embed_dims, embed_dims)
        self.output_proj = nn.Linear(embed_dims, embed_dims)
        self.init_weights()

    def init_weights(self):
        """Default initialization for Parameters of Module."""
        constant_init(self.sampling_offsets, 0.)
        thetas = torch.arange(
            self.num_heads,
            dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init /
                     grid_init.abs().max(-1, keepdim=True)[0]).view(
                         self.num_heads, 1, 1,
                         2).repeat(1, self.num_levels, self.num_points, 1)
        for i in range(self.num_points):
            grid_init[:, :, i, :] *= i + 1
        self.sampling_offsets.bias.data = grid_init.view(-1)
        constant_init(self.attention_weights, val=0., bias=0.)
        xavier_init(self.value_proj, distribution='uniform', bias=0.)
        xavier_init(self.output_proj, distribution='uniform', bias=0.)
        self._is_init = True
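    # init_weights spreads the initial sampling offsets around the reference
    # point: the offset projection's weights are zeroed and its bias is set so
    # head h starts by looking in direction 2*pi*h/num_heads (normalized to
    # the unit square), with point i pushed (i + 1) steps outward. The very
    # first forward pass therefore samples a fixed star pattern that training
    # then reshapes.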
    @deprecated_api_warning({'residual': 'identity'},
                            cls_name='MultiScaleDeformableAttention')
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
        """Forward Function of MultiScaleDeformAttention.

        Args:
            query (Tensor): Query of Transformer with shape
                (num_query, bs, embed_dims).
            key (Tensor): The key tensor with shape
                `(num_key, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_key, bs, embed_dims)`.
            identity (Tensor): The tensor used for addition, with the
                same shape as `query`. Default None. If None,
                `query` will be used.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`. Default
                None.
            reference_points (Tensor): The normalized reference points,
                with shape (bs, num_query, num_levels, 2), all elements
                in the range [0, 1], top-left (0, 0), bottom-right (1, 1),
                including the padding area; or
                (N, Length_{query}, num_levels, 4), where the additional
                two dimensions (w, h) form reference boxes.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_key].
            spatial_shapes (Tensor): Spatial shape of features in
                different levels. With shape (num_levels, 2),
                last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape ``(num_levels, )`` and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        if value is None:
            value = query
        if identity is None:
            identity = query
        if query_pos is not None:
            query = query + query_pos
        if not self.batch_first:
            # change to (bs, num_query, embed_dims)
            query = query.permute(1, 0, 2)
            value = value.permute(1, 0, 2)
        bs, num_query, _ = query.shape
        bs, num_value, _ = value.shape
        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
        value = self.value_proj(value)
        if key_padding_mask is not None:
            value = value.masked_fill(key_padding_mask[..., None], 0.0)
        value = value.view(bs, num_value, self.num_heads, -1)
        sampling_offsets = self.sampling_offsets(query).view(
            bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
        attention_weights = self.attention_weights(query).view(
            bs, num_query, self.num_heads, self.num_levels * self.num_points)
        attention_weights = attention_weights.softmax(-1)
        attention_weights = attention_weights.view(bs, num_query,
                                                   self.num_heads,
                                                   self.num_levels,
                                                   self.num_points)
        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                + sampling_offsets \
                / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                + sampling_offsets / self.num_points \
                * reference_points[:, :, None, :, None, 2:] \
                * 0.5
        else:
            raise ValueError(
                f'Last dim of reference_points must be'
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')
        if torch.cuda.is_available():
            output = MultiScaleDeformableAttnFunctionFp32.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
        else:
            # The pure-PyTorch fallback takes neither level_start_index nor
            # im2col_step; see multi_scale_deformable_attn_pytorch above.
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)
        output = self.output_proj(output)
        if not self.batch_first:
            # (num_query, bs, embed_dims)
            output = output.permute(1, 0, 2)
        return self.dropout(output) + identity
#!/usr/bin/env bash
CONFIG=$1
CHECKPOINT=$2
GPUS=$3
PORT=${PORT:-29500}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
    $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
#!/usr/bin/env bash
CONFIG=$1
GPUS=$2
PORT=${PORT:-29500}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
    $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
import sys
import os
sys.path.append(os.path.abspath('.'))
from src.datasets.evaluation.vector_eval import VectorEvaluate
import argparse


def parse_args():
    parser = argparse.ArgumentParser(
        description='Evaluate a submission file')
    parser.add_argument('submission',
        help='submission file in pickle or json format to be evaluated')
    parser.add_argument('gt',
        help='gt annotation file')
    args = parser.parse_args()
    return args


def main(args):
    evaluator = VectorEvaluate(args.gt, n_workers=0)
    results = evaluator.evaluate(args.submission)
    print(results)


if __name__ == '__main__':
    args = parse_args()
    main(args)
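# Example (hypothetical file names): run from the repo root to score an
# exported submission against the ground-truth annotations:
#   python tools/evaluate.py work_dirs/submission.pkl data/gt_annotations.json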
import os.path as osp
import pickle
import shutil
import tempfile
import time

import mmcv
import torch
import torch.distributed as dist
from mmcv.image import tensor2imgs
from mmcv.runner import get_dist_info
from mmdet.core import encode_mask_results


def single_gpu_test(model,
                    data_loader,
                    show=False,
                    out_dir=None,
                    show_score_thr=0.3):
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)
        batch_size = len(result)
        if show or out_dir:
            if batch_size == 1 and isinstance(data['img'][0], torch.Tensor):
                img_tensor = data['img'][0]
            else:
                img_tensor = data['img'][0].data[0]
            img_metas = data['img_metas'][0].data[0]
            imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
            assert len(imgs) == len(img_metas)
            for i, (img, img_meta) in enumerate(zip(imgs, img_metas)):
                h, w, _ = img_meta['img_shape']
                img_show = img[:h, :w, :]
                ori_h, ori_w = img_meta['ori_shape'][:-1]
                img_show = mmcv.imresize(img_show, (ori_w, ori_h))
                if out_dir:
                    out_file = osp.join(out_dir, img_meta['ori_filename'])
                else:
                    out_file = None
                model.module.show_result(
                    img_show,
                    result[i],
                    show=show,
                    out_file=out_file,
                    score_thr=show_score_thr)
        # encode mask results
        if isinstance(result[0], tuple):
            result = [(bbox_results, encode_mask_results(mask_results))
                      for bbox_results, mask_results in result]
        results.extend(result)
        for _ in range(batch_size):
            prog_bar.update()
    return results
def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
    """Test a model with multiple gpus.

    This method tests a model with multiple gpus and collects the results
    under two different modes: gpu and cpu. By setting 'gpu_collect=True' it
    encodes results to gpu tensors and uses gpu communication for result
    collection. In cpu mode it saves the results on different gpus to
    'tmpdir' and collects them on the rank 0 worker.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): Pytorch data loader.
        tmpdir (str): Path of directory to save the temporary results from
            different gpus under cpu mode.
        gpu_collect (bool): Option to use either gpu or cpu to collect
            results.

    Returns:
        list: The prediction results.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    rank, world_size = get_dist_info()
    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))
    time.sleep(2)  # This line can prevent deadlock problem in some cases.
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)
        results.extend(result)
        if rank == 0:
            batch_size = len(result)
            for _ in range(batch_size * world_size):
                prog_bar.update()
    # collect results from all ranks
    if gpu_collect:
        results = collect_results_gpu(results, len(dataset))
    else:
        results = collect_results_cpu(results, len(dataset), tmpdir)
    return results

def collect_results_cpu(result_part, size, tmpdir=None):
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
        MAX_LEN = 512
        # 32 is whitespace
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            mmcv.mkdir_or_exist('.dist_test')
            tmpdir = tempfile.mkdtemp(dir='.dist_test')
            tmpdir = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            part_file = osp.join(tmpdir, f'part_{i}.pkl')
            part_list.append(mmcv.load(part_file))
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)
        return ordered_results

def collect_results_gpu(result_part, size):
    rank, world_size = get_dist_info()
    # dump result part to tensor with pickle
    part_tensor = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
    # gather all result part tensor shape
    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
    shape_list = [shape_tensor.clone() for _ in range(world_size)]
    dist.all_gather(shape_list, shape_tensor)
    # padding result part tensor to max length
    shape_max = torch.tensor(shape_list).max()
    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
    part_send[:shape_tensor[0]] = part_tensor
    part_recv_list = [
        part_tensor.new_zeros(shape_max) for _ in range(world_size)
    ]
    # gather all result part
    dist.all_gather(part_recv_list, part_send)
    if rank == 0:
        part_list = []
        for recv, shape in zip(part_recv_list, shape_list):
            part_list.append(
                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        return ordered_results
import random
import warnings

import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
                         Fp16OptimizerHook, OptimizerHook, build_optimizer,
                         build_runner)
from mmcv.utils import build_from_cfg

from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
from mmdet.utils import get_root_logger

def set_random_seed(seed, deterministic=False):
    """Set random seed.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
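
# A quick sanity check of the helper (a sketch; on GPU, full run-to-run
# determinism additionally needs deterministic=True): seeding twice should
# reproduce the same sampling stream.
#
#     set_random_seed(42)
#     a = torch.rand(3)
#     set_random_seed(42)
#     b = torch.rand(3)
#     assert torch.equal(a, b)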

def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    if 'imgs_per_gpu' in cfg.data:
        logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
                       'Please use "samples_per_gpu" instead')
        if 'samples_per_gpu' in cfg.data:
            logger.warning(
                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
                f'={cfg.data.imgs_per_gpu} is used in this experiment')
        else:
            logger.warning(
                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
                f'{cfg.data.imgs_per_gpu} in this experiment')
        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu

    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed) for ds in dataset
    ]

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    if 'runner' not in cfg:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)
    else:
        if 'total_epochs' in cfg:
            assert cfg.total_epochs == cfg.runner.max_epochs
    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        # Support batch_size > 1 in validation
        val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
        if val_samples_per_gpu > 1:
            # Replace 'ImageToTensor' with 'DefaultFormatBundle'
            cfg.data.val.pipeline = replace_ImageToTensor(
                cfg.data.val.pipeline)
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=val_samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
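
# For reference, a minimal driver for train_detector; the config path and its
# contents are hypothetical placeholders for a standard MMDet-style config
# with model/data/optimizer/lr_config/checkpoint_config/log_config sections.
#
#     from mmcv import Config
#     from mmdet.datasets import build_dataset
#     from mmdet.models import build_detector
#
#     cfg = Config.fromfile('configs/example.py')  # hypothetical config
#     cfg.gpu_ids = range(1)
#     cfg.seed = 0
#     model = build_detector(cfg.model, train_cfg=cfg.get('train_cfg'),
#                            test_cfg=cfg.get('test_cfg'))
#     datasets = [build_dataset(cfg.data.train)]
#     train_detector(model, datasets, cfg, distributed=False, validate=False)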
import argparse
import os
import os.path as osp
import warnings

import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)

from mmdet3d.apis import single_gpu_test
from mmdet3d.datasets import build_dataloader, build_dataset
from mmdet3d.models import build_model
from mmdet_test import multi_gpu_test
from mmdet_train import set_random_seed
from mmdet.datasets import replace_ImageToTensor

def parse_args():
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', type=str, help='checkpoint file')
    parser.add_argument('--split', type=str, required=True,
                        help='which split to test on')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn; this will slightly increase '
        'the inference speed')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without performing evaluation. It is '
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--eval',
        action='store_true',
        help='whether to run evaluation.')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results.')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu-collect is not specified')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args

def main():
    args = parse_args()
    if args.split not in ['val', 'test']:
        raise ValueError('Please choose "val" or "test" split for testing')
    if (args.eval and args.format_only) or \
            (not args.eval and not args.format_only):
        raise ValueError('Please specify exactly one operation (eval/format) '
                         'with the argument "--eval" or "--format-only"')
    if args.eval and args.split == 'test':
        raise ValueError('Cannot evaluate on test set')

    cfg = Config.fromfile(args.config)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # import modules from plugin/xx, registry will be updated
    import sys
    sys.path.append(os.path.abspath('.'))
    if hasattr(cfg, 'plugin'):
        if cfg.plugin:
            import importlib
            if hasattr(cfg, 'plugin_dir'):
                def import_path(plugin_dir):
                    _module_dir = os.path.dirname(plugin_dir)
                    _module_dir = _module_dir.split('/')
                    _module_path = _module_dir[0]
                    for m in _module_dir[1:]:
                        _module_path = _module_path + '.' + m
                    print(f'importing {_module_path}/')
                    plg_lib = importlib.import_module(_module_path)

                plugin_dirs = cfg.plugin_dir
                if not isinstance(plugin_dirs, list):
                    plugin_dirs = [plugin_dirs, ]
                for plugin_dir in plugin_dirs:
                    import_path(plugin_dir)
            else:
                # import dir is the dirpath of the config file
                _module_dir = os.path.dirname(args.config)
                _module_dir = _module_dir.split('/')
                _module_path = _module_dir[0]
                for m in _module_dir[1:]:
                    _module_path = _module_path + '.' + m
                print(f'importing {_module_path}/')
                plg_lib = importlib.import_module(_module_path)

    cfg_data_dict = cfg.data.get(args.split)
    cfg.model.pretrained = None
    # in case the test dataset is concatenated
    cfg_data_dict.test_mode = True
    samples_per_gpu = cfg_data_dict.pop('samples_per_gpu', 1)
    if samples_per_gpu > 1:
        # Replace 'ImageToTensor' with 'DefaultFormatBundle'
        cfg_data_dict.pipeline = replace_ImageToTensor(
            cfg_data_dict.pipeline)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    cfg_data_dict.work_dir = cfg.work_dir
    print('work_dir: ', cfg.work_dir)

    # build the dataloader
    dataset = build_dataset(cfg_data_dict)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_model(cfg.model, test_cfg=cfg.get('test_cfg'))
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader)
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                 args.gpu_collect)

    rank, _ = get_dist_info()
    if rank == 0:
        if args.format_only:
            dataset.format_results(outputs, prefix=cfg.work_dir)
        elif args.eval:
            print('start evaluation!')
            print(dataset.evaluate(outputs))


if __name__ == '__main__':
    main()
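
# Example invocations of this test script (paths are placeholders); exactly
# one of --eval / --format-only must be given, and --eval is rejected for the
# test split:
#
#     python test.py configs/example.py work_dirs/example/latest.pth \
#         --split val --eval
#     python test.py configs/example.py work_dirs/example/latest.pth \
#         --split test --format-only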
from __future__ import division

import argparse
import copy
import os
import time
import warnings
from os import path as osp

import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist

from mmdet import __version__ as mmdet_version
from mmdet3d import __version__ as mmdet3d_version
from mmdet3d.apis import train_model
from mmdet3d.datasets import build_dataset
from mmdet3d.utils import collect_env, get_root_logger
from mmseg import __version__ as mmseg_version

# wrapper
from mmdet_train import set_random_seed
# from builder import build_model
from mmdet3d.models import build_model

def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file (deprecated), '
        'change to --cfg-options instead.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--autoscale-lr',
        action='store_true',
        help='automatically scale lr with the number of gpus')
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both specified, '
            '--options is deprecated in favor of --cfg-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options')
        args.cfg_options = args.options

    return args

def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # import modules, registry will be updated
    import sys
    sys.path.append(os.path.abspath('.'))
    if hasattr(cfg, 'plugin'):
        if cfg.plugin:
            import importlib
            if hasattr(cfg, 'plugin_dir'):
                def import_path(plugin_dir):
                    _module_dir = os.path.dirname(plugin_dir)
                    _module_dir = _module_dir.split('/')
                    _module_path = _module_dir[0]
                    for m in _module_dir[1:]:
                        _module_path = _module_path + '.' + m
                    print(f'importing {_module_path}/')
                    plg_lib = importlib.import_module(_module_path)

                plugin_dirs = cfg.plugin_dir
                if not isinstance(plugin_dirs, list):
                    plugin_dirs = [plugin_dirs, ]
                for plugin_dir in plugin_dirs:
                    import_path(plugin_dir)
            else:
                # import dir is the dirpath of the config file
                _module_dir = os.path.dirname(args.config)
                _module_dir = _module_dir.split('/')
                _module_path = _module_dir[0]
                for m in _module_dir[1:]:
                    _module_path = _module_path + '.' + m
                print(f'importing {_module_path}/')
                plg_lib = importlib.import_module(_module_path)

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    # specify logger name; if we still use 'mmdet', the output info will be
    # filtered and won't be saved in the log_file
    # TODO: ugly workaround to judge whether we are training a det or seg model
    if cfg.model.type in ['EncoderDecoder3D']:
        logger_name = 'mmseg'
    else:
        logger_name = 'mmdet'
    logger = get_root_logger(
        log_file=log_file, log_level=cfg.log_level, name=logger_name)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['exp_name'] = osp.basename(args.config)

    model = build_model(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    model.init_weights()
    logger.info(f'Model:\n{model}')

    cfg.data.train.work_dir = cfg.work_dir
    cfg.data.val.work_dir = cfg.work_dir
    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        # in case we use a dataset wrapper
        if 'dataset' in cfg.data.train:
            val_dataset.pipeline = cfg.data.train.dataset.pipeline
        else:
            val_dataset.pipeline = cfg.data.train.pipeline
        # set test_mode=False here in the deep-copied config, which does not
        # affect the AP/AR calculation later
        # refer to https://mmdetection3d.readthedocs.io/en/latest/tutorials/customize_runtime.html#customize-workflow  # noqa
        val_dataset.test_mode = False
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=mmdet_version,
            mmseg_version=mmseg_version,
            mmdet3d_version=mmdet3d_version,
            config=cfg.pretty_text,
            CLASSES=None,
            PALETTE=datasets[0].PALETTE  # for segmentors
            if hasattr(datasets[0], 'PALETTE') else None)
    # add an attribute for visualization convenience
    # model.CLASSES = datasets[0].CLASSES
    train_model(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    main()
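
# Example invocations of this training script (paths are placeholders):
#
#     python train.py configs/example.py --work-dir work_dirs/example
#
#     # distributed training with the pytorch launcher
#     python -m torch.distributed.launch --nproc_per_node=8 train.py \
#         configs/example.py --launcher pytorch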
import copy
import os
import os.path as osp

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from shapely.geometry import LineString

def remove_nan_values(uv):
    is_u_valid = np.logical_not(np.isnan(uv[:, 0]))
    is_v_valid = np.logical_not(np.isnan(uv[:, 1]))
    is_uv_valid = np.logical_and(is_u_valid, is_v_valid)

    uv_valid = uv[is_uv_valid]
    return uv_valid


def points_ego2img(pts_ego, extrinsics, intrinsics):
    pts_ego_4d = np.concatenate([pts_ego, np.ones([len(pts_ego), 1])], axis=-1)
    pts_cam_4d = extrinsics @ pts_ego_4d.T
    uv = (intrinsics @ pts_cam_4d[:3, :]).T
    uv = remove_nan_values(uv)
    depth = uv[:, 2]
    uv = uv[:, :2] / uv[:, 2].reshape(-1, 1)
    return uv, depth
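
# The projection convention assumed here: `extrinsics` maps ego to camera
# coordinates (4x4) and `intrinsics` is a 3x3 pinhole matrix; points behind
# the camera come back with non-positive depth. A small sketch with
# hypothetical camera parameters:
#
#     K = np.array([[1000., 0., 800.],
#                   [0., 1000., 450.],
#                   [0., 0., 1.]])  # hypothetical pinhole intrinsics
#     ego2cam = np.eye(4)  # hypothetical: camera frame coincides with ego
#     pts = np.array([[0., 0., 5.], [1., 1., 10.]])
#     uv, depth = points_ego2img(pts, ego2cam, K)
#     # uv[0] == (800., 450.), the optical center; depth == (5., 10.)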

def interp_fixed_dist(line, sample_dist):
    '''Interpolate a line at a fixed interval.

    Args:
        line (LineString): line
        sample_dist (float): sample interval

    Returns:
        points (array): interpolated points, shape (N, 2)
    '''
    distances = list(np.arange(sample_dist, line.length, sample_dist))
    # make sure to sample at least two points when sample_dist > line.length
    distances = [0, ] + distances + [line.length, ]
    sampled_points = np.array([list(line.interpolate(distance).coords)
                               for distance in distances]).squeeze()
    return sampled_points
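
# For example, resampling a straight 10 m segment every 2 m keeps both
# endpoints and yields six points:
#
#     line = LineString([(0, 0), (10, 0)])
#     pts = interp_fixed_dist(line, sample_dist=2.0)
#     # pts.shape == (6, 2): x = 0, 2, 4, 6, 8, 10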

def draw_polyline_ego_on_img(polyline_ego, img_bgr, extrinsics, intrinsics,
                             color_bgr, thickness):
    # if 2-dimensional, assume z=0
    if polyline_ego.shape[1] == 2:
        zeros = np.zeros((polyline_ego.shape[0], 1))
        polyline_ego = np.concatenate([polyline_ego, zeros], axis=1)

    polyline_ego = interp_fixed_dist(line=LineString(polyline_ego),
                                     sample_dist=0.2)
    uv, depth = points_ego2img(polyline_ego, extrinsics, intrinsics)

    h, w, c = img_bgr.shape
    is_valid_x = np.logical_and(0 <= uv[:, 0], uv[:, 0] < w - 1)
    is_valid_y = np.logical_and(0 <= uv[:, 1], uv[:, 1] < h - 1)
    is_valid_z = depth > 0
    is_valid_points = np.logical_and.reduce(
        [is_valid_x, is_valid_y, is_valid_z])
    if is_valid_points.sum() == 0:
        return

    # draw each maximal run of consecutive visible points as its own polyline
    tmp_list = []
    for i, valid in enumerate(is_valid_points):
        if valid:
            tmp_list.append(uv[i])
        else:
            if len(tmp_list) >= 2:
                tmp_vector = np.stack(tmp_list)
                tmp_vector = np.round(tmp_vector).astype(np.int32)
                draw_visible_polyline_cv2(
                    copy.deepcopy(tmp_vector),
                    valid_pts_bool=np.ones((len(tmp_vector), 1), dtype=bool),
                    image=img_bgr,
                    color=color_bgr,
                    thickness_px=thickness,
                )
            tmp_list = []
    if len(tmp_list) >= 2:
        tmp_vector = np.stack(tmp_list)
        tmp_vector = np.round(tmp_vector).astype(np.int32)
        draw_visible_polyline_cv2(
            copy.deepcopy(tmp_vector),
            valid_pts_bool=np.ones((len(tmp_vector), 1), dtype=bool),
            image=img_bgr,
            color=color_bgr,
            thickness_px=thickness,
        )
    # uv = np.round(uv[is_valid_points]).astype(np.int32)
    # draw_visible_polyline_cv2(
    #     copy.deepcopy(uv),
    #     valid_pts_bool=np.ones((len(uv), 1), dtype=bool),
    #     image=img_bgr,
    #     color=color_bgr,
    #     thickness_px=thickness,
    # )

def draw_visible_polyline_cv2(line, valid_pts_bool, image, color, thickness_px):
    """Draw a polyline onto an image using given line segments.

    Args:
        line: Array of shape (K, 2) representing the coordinates of line.
        valid_pts_bool: Array of shape (K,) representing which polyline
            coordinates are valid for rendering. For example, if the
            coordinate is occluded, a user might specify that it is invalid.
            Line segments touching an invalid vertex will not be rendered.
        image: Array of shape (H, W, 3), representing a 3-channel BGR image
        color: Tuple of shape (3,) with a BGR format color
        thickness_px: thickness (in pixels) to use when rendering the polyline.
    """
    line = np.round(line).astype(int)  # type: ignore
    for i in range(len(line) - 1):
        if (not valid_pts_bool[i]) or (not valid_pts_bool[i + 1]):
            continue
        x1 = line[i][0]
        y1 = line[i][1]
        x2 = line[i + 1][0]
        y2 = line[i + 1][1]
        # Use anti-aliasing (AA) for curves
        image = cv2.line(image, pt1=(x1, y1), pt2=(x2, y2), color=color,
                         thickness=thickness_px, lineType=cv2.LINE_AA)
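
# A small sketch on a blank canvas: with the last vertex marked invalid, only
# the first two segments are rendered.
#
#     canvas = np.zeros((100, 200, 3), dtype=np.uint8)
#     poly = np.array([[10, 10], [100, 50], [190, 10], [195, 90]])
#     valid = np.array([True, True, True, False])
#     draw_visible_polyline_cv2(poly, valid, canvas,
#                               color=(0, 0, 255), thickness_px=2)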

COLOR_MAPS_BGR = {
    # bgr colors
    'divider': (0, 0, 255),
    'boundary': (0, 255, 0),
    'ped_crossing': (255, 0, 0),
    'centerline': (51, 183, 255),
    'drivable_area': (171, 255, 255)
}

COLOR_MAPS_PLT = {
    'divider': 'r',
    'boundary': 'g',
    'ped_crossing': 'b',
    'centerline': 'orange',
    'drivable_area': 'y',
}

CAM_NAMES_AV2 = ['ring_front_center', 'ring_front_right', 'ring_front_left',
                 'ring_rear_right', 'ring_rear_left', 'ring_side_right',
                 'ring_side_left',
                 ]

class Renderer(object):
    """Render map elements on image views.

    Args:
        roi_size (tuple): bev range
    """

    def __init__(self, roi_size):
        self.roi_size = roi_size

    def render_bev_from_vectors(self, vectors, out_dir):
        '''Plot vectorized map elements on BEV.

        Args:
            vectors (dict): dict of vectorized map elements.
            out_dir (str): output directory
        '''
        car_img = Image.open('resources/images/car.png')
        map_path = os.path.join(out_dir, 'map.jpg')
        plt.figure(figsize=(self.roi_size[0], self.roi_size[1]))
        plt.xlim(-self.roi_size[0] / 2 - 1, self.roi_size[0] / 2 + 1)
        plt.ylim(-self.roi_size[1] / 2 - 1, self.roi_size[1] / 2 + 1)
        plt.axis('off')
        plt.imshow(car_img, extent=[-1.5, 1.5, -1.2, 1.2])
        for cat, vector_list in vectors.items():
            color = COLOR_MAPS_PLT[cat]
            for vector in vector_list:
                pts = np.array(vector)[:, :2]
                x = np.array([pt[0] for pt in pts])
                y = np.array([pt[1] for pt in pts])
                # plt.quiver(x[:-1], y[:-1], x[1:] - x[:-1], y[1:] - y[:-1],
                #            angles='xy', color=color, scale_units='xy',
                #            scale=1)
                plt.plot(x, y, color=color, linewidth=5, marker='o',
                         linestyle='-', markersize=20)
        plt.savefig(map_path, bbox_inches='tight', dpi=40)
        plt.close()

    def render_camera_views_from_vectors(self, vectors, imgs, extrinsics,
                                         intrinsics, thickness, out_dir):
        '''Project vectorized map elements to camera views.

        Args:
            vectors (dict): dict of vectorized map elements.
            imgs (tensor): images in bgr color.
            extrinsics (array): ego2img extrinsics, shape (4, 4)
            intrinsics (array): intrinsics, shape (3, 3)
            thickness (int): thickness of lines to draw on images.
            out_dir (str): output directory
        '''
        for i in range(len(imgs)):
            img = imgs[i]
            extrinsic = extrinsics[i]
            intrinsic = intrinsics[i]
            img_bgr = copy.deepcopy(img)
            for cat, vector_list in vectors.items():
                color = COLOR_MAPS_BGR[cat]
                for vector in vector_list:
                    img_bgr = np.ascontiguousarray(img_bgr)
                    vector_array = np.array(vector)
                    if vector_array.shape[1] > 3:
                        vector_array = vector_array[:, :3]
                    draw_polyline_ego_on_img(vector_array, img_bgr, extrinsic,
                                             intrinsic, color, thickness)
            out_path = osp.join(out_dir, CAM_NAMES_AV2[i]) + '.jpg'
            cv2.imwrite(out_path, img_bgr)
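
# Hypothetical usage of the BEV renderer, assuming the
# resources/images/car.png asset is present and the output directory exists:
#
#     renderer = Renderer(roi_size=(60, 30))
#     vectors = {'divider': [[(-5.0, 2.0), (5.0, 2.0)]]}
#     os.makedirs('vis_out', exist_ok=True)
#     renderer.render_bev_from_vectors(vectors, out_dir='vis_out')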