Commit f3b13cad authored by yeshenglong1's avatar yeshenglong1
Browse files

Update README.md

parent 0797920d
import copy
import torch
import torch.nn as nn
from mmcv.cnn import Linear, bias_init_with_prob, build_activation_layer
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmcv.runner import force_fp32
from mmdet.models import HEADS, build_head, build_loss
from mmdet.models.utils import build_transformer
from mmdet.models.utils.transformer import inverse_sigmoid
from .base_map_head import BaseMapHead
import numpy as np
from ..augmentation.sythesis_det import NoiseSythesis
@HEADS.register_module(force=True)
class DGHead(BaseMapHead):
    """Detection + generation head for vectorized map learning.

    Combines a map-element detector (``det_net``) that predicts element
    boxes/keypoints with a generative polyline head (``gen_net``) that
    decodes full vertex sequences conditioned on those detections.
    """

    def __init__(self,
                 det_net_cfg=dict(),
                 gen_net_cfg=dict(),
                 loss_vert=dict(),
                 loss_face=dict(),
                 max_num_vertices=90,
                 top_p_gen_model=0.9,
                 sync_cls_avg_factor=True,
                 augmentation=False,
                 augmentation_kwargs=None,
                 joint_training=False,
                 **kwargs):
        """
        Args:
            det_net_cfg (dict): config of the detection sub-head.
            gen_net_cfg (dict): config of the generative sub-head; must
                expose ``coord_dim`` and, when ``augmentation`` is on,
                ``canvas_size``.
            loss_vert / loss_face (dict): unused here; kept for config
                compatibility.
            max_num_vertices (int): cap on generated vertices per polyline.
            top_p_gen_model (float): nucleus-sampling threshold at inference.
            sync_cls_avg_factor (bool): sync the cls avg factor across GPUs.
            augmentation (bool): enable noise synthesis on GT boxes/lines.
            augmentation_kwargs (dict | None): kwargs for ``NoiseSythesis``.
            joint_training (bool): also feed (detached) detector outputs to
                the generator during training.
        """
        super().__init__()
        # Sub-heads
        self.det_net = build_head(det_net_cfg)
        self.gen_net = build_head(gen_net_cfg)
        self.coord_dim = self.gen_net.coord_dim
        # Loss params
        self.bg_cls_weight = 1.0
        self.sync_cls_avg_factor = sync_cls_avg_factor
        self.max_num_vertices = max_num_vertices
        self.top_p_gen_model = top_p_gen_model
        self.fp16_enabled = False
        self.augmentation = None
        if augmentation:
            # BUG FIX: previously ``augmentation_kwargs.update(...)`` mutated
            # the caller's dict and crashed when the default ``None`` was
            # left in place. Copy first and tolerate a missing dict.
            aug_kwargs = dict(augmentation_kwargs or {})
            aug_kwargs['canvas_size'] = gen_net_cfg.canvas_size
            self.augmentation = NoiseSythesis(**aug_kwargs)
        self.joint_training = joint_training
def forward(self, batch, img_metas=None, **kwargs):
    """Dispatch to the training or inference path.

    Args:
        batch: input batch dict.
        img_metas: unused; kept for API compatibility.

    Returns:
        Output of ``forward_train`` in train mode, of ``inference``
        in eval mode.
    """
    run = self.forward_train if self.training else self.inference
    return run(batch, **kwargs)
def forward_train(self, batch: dict, context: dict, only_det=False, **kwargs):
    '''Training forward pass (teacher forcing for the generator).

    Args:
        batch (dict): ground truth with 'det' and 'gen' sub-dicts.
        context (dict): BEV features under 'bev_embeddings'.
        only_det (bool): skip the generator branch entirely.

    Returns:
        tuple(dict, dict): raw outputs and the combined loss dict.
    '''
    # Detector first; its loss also yields the per-decoder-layer Hungarian
    # matching (pred indices <-> gt indices) reused below.
    bbox_dict = self.det_net(context=context)
    outs = dict(
        bbox=bbox_dict,
    )
    losses_dict, det_match_idxs, det_match_gt_idxs = \
        self.loss_det(batch, outs)
    if only_det: return outs, losses_dict
    if self.augmentation is not None:
        # Noisy copies of GT polylines / flattened boxes for augmentation.
        polylines, bbox_flat =\
            self.augmentation(batch['gen'],simple_aug=True)
        if bbox_flat is None:
            bbox_flat = batch['gen']['bbox_flat']
        gen_input = dict(
            lines_bs_idx=batch['gen']['lines_bs_idx'],
            lines_cls=batch['gen']['lines_cls'],
            bbox_flat=bbox_flat,
            polylines=polylines,
            polyline_masks=batch['gen']['polyline_masks']
        )
    else:
        gen_input = batch['gen']
    if self.joint_training:
        # for down stream polyline
        if 'lines' in bbox_dict[-1]:
            # for fix anchor
            pred_bbox = bbox_dict[-1]['lines'].detach()
        elif 'bboxs' in bbox_dict[-1]:
            # for rpv
            pred_bbox = bbox_dict[-1]['bboxs'].detach()
        else:
            raise NotImplementedError
        # changed to original gt order.
        det_match_idx = det_match_idxs[-1]
        det_match_gt_idx = det_match_gt_idxs[-1]
        _bboxs = []
        for i, (match_idx, bbox) in enumerate(zip(det_match_idx,pred_bbox)):
            _bboxs.append(bbox[match_idx])
            # Re-sort matched predictions back into ground-truth order so
            # they align with batch['gen'] targets.
            _bboxs[-1] = _bboxs[-1][torch.argsort(det_match_gt_idx[i])]
        _bboxs = torch.cat(_bboxs, dim=0)
        # quantize the data
        _bboxs = \
            torch.round(_bboxs).type(torch.int32)
        # gen_input['bbox_flat'] = _bboxs
        # Also keep a random 20% of the original GT-conditioned samples.
        remain_idx = torch.randperm(_bboxs.shape[0])[:int(_bboxs.shape[0]*0.2)]
        # for data efficient
        for k in gen_input.keys():
            if k == 'bbox_flat':
                # Predicted (detached, quantized) boxes replace GT boxes.
                gen_input[k] = torch.cat((_bboxs,gen_input[k][remain_idx]),dim=0)
            else:
                gen_input[k] = torch.cat((gen_input[k],gen_input[k][remain_idx]),dim=0)
    if isinstance(context['bev_embeddings'],tuple):
        context['bev_embeddings'] = context['bev_embeddings'][0]
    poly_dict = self.gen_net(gen_input, context=context)
    outs.update(dict(
        polylines=poly_dict,
    ))
    if self.joint_training:
        # Duplicate the sampled GT entries so loss targets line up with
        # the augmented generator inputs built above (mutates batch).
        for k in batch['gen'].keys():
            batch['gen'][k] = \
                torch.cat((batch['gen'][k],batch['gen'][k][remain_idx]),dim=0)
    gen_losses_dict = \
        self.loss_gen(batch, outs)
    losses_dict.update(gen_losses_dict)
    return outs, losses_dict
def loss_det(self, gt: dict, pred: dict):
    """Compute detection losses; keys are prefixed with ``det_``.

    Returns:
        tuple: (prefixed loss dict, per-layer matched pred indices,
        per-layer matched gt indices).
    """
    det_losses, match_idx, match_gt_idx = self.det_net.loss(
        gt['det'], pred['bbox'])
    prefixed = {'det_' + name: value for name, value in det_losses.items()}
    return prefixed, match_idx, match_gt_idx
def loss_gen(self, gt: dict, pred: dict):
    """Compute generator losses; keys are prefixed with ``gen_``."""
    gen_losses = self.gen_net.loss(gt['gen'], pred['polylines'])
    return {'gen_' + name: value for name, value in gen_losses.items()}
def loss(self, gt: dict, pred: dict):
    """Unused placeholder; real losses live in loss_det / loss_gen."""
    pass
@torch.no_grad()
def inference(self, batch: dict = None, context: dict = None, gt_condition=False, **kwargs):
    '''Run the detector, then decode polylines for the detected elements.

    Args:
        batch (dict | None): unused here; kept for API symmetry.
        context (dict | None): must contain 'bev_embeddings'.
        gt_condition (bool): condition the generator on ground truth.

    Returns:
        dict | None: detection + polyline outputs, or None when the
        detector finds no lines at all.
    '''
    # BUG FIX: the defaults used to be mutable ``{}`` literals; since
    # ``context`` is mutated below, a single shared dict leaked state
    # across calls. ``None`` sentinels preserve the old call signature.
    batch = {} if batch is None else batch
    context = {} if context is None else context
    outs = {}
    bbox_dict = self.det_net(context=context)
    bbox_dict = self.det_net.post_process(bbox_dict)
    outs.update(bbox_dict)
    if len(outs['lines_bs_idx']) == 0:
        return None
    if isinstance(context['bev_embeddings'], tuple):
        context['bev_embeddings'] = context['bev_embeddings'][0]
    poly_dict = self.gen_net(outs,
                             context=context,
                             # max_sample_length=self.max_num_vertices,
                             max_sample_length=64,
                             top_p=self.top_p_gen_model,
                             gt_condition=gt_condition)
    outs.update(poly_dict)
    return outs
def post_process(self, preds: dict, tokens, gts:dict=None, **kwargs):
    '''Convert batched predictions to one result dict per sample.

    Args:
        preds (dict): detector+generator outputs ('scores', 'labels',
            'lines_bs_idx', 'polylines', 'polyline_masks', ...).
        tokens: per-sample identifiers, one per batch element.
        gts (dict | None): optional ground truth for evaluation records.

    Outs:
        list[dict]: per sample: 'token', 'scores', 'labels', 'nline' and
        de-quantized 'vectors' normalized to the unit square.
    '''
    range_size = self.gen_net.canvas_size.cpu().numpy()
    coord_dim = self.gen_net.coord_dim
    gen_net_name = self.gen_net.name if hasattr(self.gen_net,'name') else 'gen'
    ret_list = []
    for batch_idx in range(len(tokens)):
        ret_dict_single = {}
        # bbox
        det_gt = None
        if gts is not None:
            # NOTE(review): det_gt / rec_groundtruth are currently only
            # consumed by the commented-out lines below.
            det_gt, rec_groundtruth = pack_groundtruth(
                batch_idx,gts,tokens,range_size,gen_net_name,coord_dim=coord_dim)
        bbox_res = {
            # 'bboxes': preds['bbox'][batch_idx].detach().cpu().numpy(),
            # 'det_gt': det_gt,
            'token': tokens[batch_idx],
            'scores': preds['scores'][batch_idx].detach().cpu().numpy(),
            'labels': preds['labels'][batch_idx].detach().cpu().numpy(),
        }
        ret_dict_single.update(bbox_res)
        # for gen results.
        batch2seq = np.nonzero(
            preds['lines_bs_idx'].cpu().numpy() == batch_idx)[0]
        ret_dict_single.update({
            'nline': len(batch2seq),
            'vectors': []
        })
        for i in batch2seq:
            pre = preds['polylines'][i].detach().cpu().numpy()
            pre_msk = preds['polyline_masks'][i].detach().cpu().numpy()
            # Drop the last valid token (presumably EOS — TODO confirm)
            # before reshaping the flat sequence into points.
            valid_idx = np.nonzero(pre_msk)[0][:-1]
            # From [200,1] to [199,0] to (1,0)
            line = (pre[valid_idx].reshape(-1, coord_dim) - 1) / (range_size-1)
            ret_dict_single['vectors'].append(line)
        # if gts is not None:
        #     ret_dict_single['groundTruth'] = rec_groundtruth
        ret_list.append(ret_dict_single)
    return ret_list
def pack_groundtruth(batch_idx,gts,tokens,range_size,gen_net_name='gen',coord_dim=2):
    """Extract detection and polyline ground truth for one sample.

    Args:
        batch_idx (int): sample index within the batch.
        gts (dict): ground truth with 'det' and 'gen' sub-dicts of tensors.
        tokens (list): per-sample identifiers.
        range_size (np.ndarray): canvas size used for de-quantization.
        gen_net_name (str): generator head name; 'gen_gmm' keeps the last
            valid token, any other head drops it.
        coord_dim (int): coordinate dimensionality of each vertex.

    Returns:
        tuple(dict, dict): detection GT ('labels', 'bboxes') and polyline
        GT record ('token', 'nline', 'labels', 'lines').
    """
    det_key = 'keypoints' if 'keypoints' in gts['det'] else 'bbox'
    det_gt = {
        'labels': gts['det']['class_label'][batch_idx].detach().cpu().numpy(),
        'bboxes': gts['det'][det_key][batch_idx].detach().cpu().numpy(),
    }
    gen = gts['gen']
    seq_ids = np.nonzero(gen['lines_bs_idx'].cpu().numpy() == batch_idx)[0]
    ret_groundtruth = {
        'token': tokens[batch_idx],
        'nline': len(seq_ids),
        'labels': gen['lines_cls'][seq_ids].detach().cpu().numpy(),
        'lines': [],
    }
    all_lines = gen['polylines'].detach().cpu().numpy()
    all_masks = gen['polyline_masks'].detach().cpu().numpy()
    keep_last = (gen_net_name == 'gen_gmm')
    for seq_id in seq_ids:
        valid = np.nonzero(all_masks[seq_id])[0]
        if not keep_last:
            valid = valid[:-1]
        # From [200,1] to [199,0] to (1,0)
        pts = (all_lines[seq_id][valid].reshape(-1, coord_dim) - 1) / (range_size - 1)
        ret_groundtruth['lines'].append(pts)
    return det_gt, ret_groundtruth
import copy
import torch
import torch.nn as nn
from mmcv.cnn import Linear, bias_init_with_prob, build_activation_layer
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmcv.runner import force_fp32
from mmdet.models import HEADS, build_head, build_loss
from mmdet.models.utils import build_transformer
from mmdet.models.utils.transformer import inverse_sigmoid
from .base_map_head import BaseMapHead
import numpy as np
from ..augmentation.sythesis_det import NoiseSythesis
@HEADS.register_module(force=True)
class DGHead(BaseMapHead):
    """Detection + generation head for vectorized map learning.

    Combines a map-element detector (``det_net``) that predicts element
    boxes/keypoints with a generative polyline head (``gen_net``) that
    decodes full vertex sequences conditioned on those detections.
    """

    def __init__(self,
                 det_net_cfg=dict(),
                 gen_net_cfg=dict(),
                 loss_vert=dict(),
                 loss_face=dict(),
                 max_num_vertices=90,
                 top_p_gen_model=0.9,
                 sync_cls_avg_factor=True,
                 augmentation=False,
                 augmentation_kwargs=None,
                 joint_training=False,
                 **kwargs):
        """
        Args:
            det_net_cfg (dict): config of the detection sub-head.
            gen_net_cfg (dict): config of the generative sub-head; must
                expose ``coord_dim`` and, when ``augmentation`` is on,
                ``canvas_size``.
            loss_vert / loss_face (dict): unused here; kept for config
                compatibility.
            max_num_vertices (int): cap on generated vertices per polyline.
            top_p_gen_model (float): nucleus-sampling threshold at inference.
            sync_cls_avg_factor (bool): sync the cls avg factor across GPUs.
            augmentation (bool): enable noise synthesis on GT boxes/lines.
            augmentation_kwargs (dict | None): kwargs for ``NoiseSythesis``.
            joint_training (bool): also feed (detached) detector outputs to
                the generator during training.
        """
        super().__init__()
        # Sub-heads
        self.det_net = build_head(det_net_cfg)
        self.gen_net = build_head(gen_net_cfg)
        self.coord_dim = self.gen_net.coord_dim
        # Loss params
        self.bg_cls_weight = 1.0
        self.sync_cls_avg_factor = sync_cls_avg_factor
        self.max_num_vertices = max_num_vertices
        self.top_p_gen_model = top_p_gen_model
        self.fp16_enabled = False
        self.augmentation = None
        if augmentation:
            # BUG FIX: previously ``augmentation_kwargs.update(...)`` mutated
            # the caller's dict and crashed when the default ``None`` was
            # left in place. Copy first and tolerate a missing dict.
            aug_kwargs = dict(augmentation_kwargs or {})
            aug_kwargs['canvas_size'] = gen_net_cfg.canvas_size
            self.augmentation = NoiseSythesis(**aug_kwargs)
        self.joint_training = joint_training
def forward(self, batch, img_metas=None, **kwargs):
    """Dispatch to the training or inference path.

    Args:
        batch: input batch dict.
        img_metas: unused; kept for API compatibility.

    Returns:
        Output of ``forward_train`` in train mode, of ``inference``
        in eval mode.
    """
    run = self.forward_train if self.training else self.inference
    return run(batch, **kwargs)
def forward_train(self, batch: dict, context: dict, only_det=False, **kwargs):
    '''Training forward pass (teacher forcing for the generator).

    Args:
        batch (dict): ground truth with 'det' and 'gen' sub-dicts.
        context (dict): BEV features under 'bev_embeddings'.
        only_det (bool): skip the generator branch entirely.

    Returns:
        tuple(dict, dict): raw outputs and the combined loss dict.
    '''
    # Detector first; its loss also yields the per-decoder-layer Hungarian
    # matching (pred indices <-> gt indices) reused below.
    bbox_dict = self.det_net(context=context)
    outs = dict(
        bbox=bbox_dict,
    )
    losses_dict, det_match_idxs, det_match_gt_idxs = \
        self.loss_det(batch, outs)
    if only_det: return outs, losses_dict
    if self.augmentation is not None:
        # Noisy copies of GT polylines / flattened boxes for augmentation.
        polylines, bbox_flat =\
            self.augmentation(batch['gen'],simple_aug=True)
        if bbox_flat is None:
            bbox_flat = batch['gen']['bbox_flat']
        gen_input = dict(
            lines_bs_idx=batch['gen']['lines_bs_idx'],
            lines_cls=batch['gen']['lines_cls'],
            bbox_flat=bbox_flat,
            polylines=polylines,
            polyline_masks=batch['gen']['polyline_masks']
        )
    else:
        gen_input = batch['gen']
    if self.joint_training:
        # for down stream polyline
        if 'lines' in bbox_dict[-1]:
            # for fix anchor
            pred_bbox = bbox_dict[-1]['lines'].detach()
        elif 'bboxs' in bbox_dict[-1]:
            # for rpv
            pred_bbox = bbox_dict[-1]['bboxs'].detach()
        else:
            raise NotImplementedError
        # changed to original gt order.
        det_match_idx = det_match_idxs[-1]
        det_match_gt_idx = det_match_gt_idxs[-1]
        _bboxs = []
        for i, (match_idx, bbox) in enumerate(zip(det_match_idx,pred_bbox)):
            _bboxs.append(bbox[match_idx])
            # Re-sort matched predictions back into ground-truth order so
            # they align with batch['gen'] targets.
            _bboxs[-1] = _bboxs[-1][torch.argsort(det_match_gt_idx[i])]
        _bboxs = torch.cat(_bboxs, dim=0)
        # quantize the data
        _bboxs = \
            torch.round(_bboxs).type(torch.int32)
        # gen_input['bbox_flat'] = _bboxs
        # Also keep a random 20% of the original GT-conditioned samples.
        remain_idx = torch.randperm(_bboxs.shape[0])[:int(_bboxs.shape[0]*0.2)]
        # for data efficient
        for k in gen_input.keys():
            if k == 'bbox_flat':
                # Predicted (detached, quantized) boxes replace GT boxes.
                gen_input[k] = torch.cat((_bboxs,gen_input[k][remain_idx]),dim=0)
            else:
                gen_input[k] = torch.cat((gen_input[k],gen_input[k][remain_idx]),dim=0)
    if isinstance(context['bev_embeddings'],tuple):
        context['bev_embeddings'] = context['bev_embeddings'][0]
    poly_dict = self.gen_net(gen_input, context=context)
    outs.update(dict(
        polylines=poly_dict,
    ))
    if self.joint_training:
        # Duplicate the sampled GT entries so loss targets line up with
        # the augmented generator inputs built above (mutates batch).
        for k in batch['gen'].keys():
            batch['gen'][k] = \
                torch.cat((batch['gen'][k],batch['gen'][k][remain_idx]),dim=0)
    gen_losses_dict = \
        self.loss_gen(batch, outs)
    losses_dict.update(gen_losses_dict)
    return outs, losses_dict
def loss_det(self, gt: dict, pred: dict):
    """Compute detection losses; keys are prefixed with ``det_``.

    Returns:
        tuple: (prefixed loss dict, per-layer matched pred indices,
        per-layer matched gt indices).
    """
    det_losses, match_idx, match_gt_idx = self.det_net.loss(
        gt['det'], pred['bbox'])
    prefixed = {'det_' + name: value for name, value in det_losses.items()}
    return prefixed, match_idx, match_gt_idx
def loss_gen(self, gt: dict, pred: dict):
    """Compute generator losses; keys are prefixed with ``gen_``."""
    gen_losses = self.gen_net.loss(gt['gen'], pred['polylines'])
    return {'gen_' + name: value for name, value in gen_losses.items()}
def loss(self, gt: dict, pred: dict):
    """Unused placeholder; real losses live in loss_det / loss_gen."""
    pass
@torch.no_grad()
def inference(self, batch: dict = None, context: dict = None, gt_condition=False, **kwargs):
    '''Run the detector, then decode polylines for the detected elements.

    Args:
        batch (dict | None): unused here; kept for API symmetry.
        context (dict | None): must contain 'bev_embeddings'.
        gt_condition (bool): condition the generator on ground truth.

    Returns:
        dict | None: detection + polyline outputs, or None when the
        detector finds no lines at all.
    '''
    # BUG FIX: the defaults used to be mutable ``{}`` literals; since
    # ``context`` is mutated below, a single shared dict leaked state
    # across calls. ``None`` sentinels preserve the old call signature.
    batch = {} if batch is None else batch
    context = {} if context is None else context
    outs = {}
    bbox_dict = self.det_net(context=context)
    bbox_dict = self.det_net.post_process(bbox_dict)
    outs.update(bbox_dict)
    if len(outs['lines_bs_idx']) == 0:
        return None
    if isinstance(context['bev_embeddings'], tuple):
        context['bev_embeddings'] = context['bev_embeddings'][0]
    poly_dict = self.gen_net(outs,
                             context=context,
                             # max_sample_length=self.max_num_vertices,
                             max_sample_length=64,
                             top_p=self.top_p_gen_model,
                             gt_condition=gt_condition)
    outs.update(poly_dict)
    return outs
def post_process(self, preds: dict, tokens, gts:dict=None, **kwargs):
    '''Convert batched predictions to one result dict per sample.

    Args:
        preds (dict): detector+generator outputs ('scores', 'labels',
            'lines_bs_idx', 'polylines', 'polyline_masks', ...).
        tokens: per-sample identifiers, one per batch element.
        gts (dict | None): optional ground truth for evaluation records.

    Outs:
        list[dict]: per sample: 'token', 'scores', 'labels', 'nline' and
        de-quantized 'vectors' normalized to the unit square.
    '''
    range_size = self.gen_net.canvas_size.cpu().numpy()
    coord_dim = self.gen_net.coord_dim
    gen_net_name = self.gen_net.name if hasattr(self.gen_net,'name') else 'gen'
    ret_list = []
    for batch_idx in range(len(tokens)):
        ret_dict_single = {}
        # bbox
        det_gt = None
        if gts is not None:
            # NOTE(review): det_gt / rec_groundtruth are currently only
            # consumed by the commented-out lines below.
            det_gt, rec_groundtruth = pack_groundtruth(
                batch_idx,gts,tokens,range_size,gen_net_name,coord_dim=coord_dim)
        bbox_res = {
            # 'bboxes': preds['bbox'][batch_idx].detach().cpu().numpy(),
            # 'det_gt': det_gt,
            'token': tokens[batch_idx],
            'scores': preds['scores'][batch_idx].detach().cpu().numpy(),
            'labels': preds['labels'][batch_idx].detach().cpu().numpy(),
        }
        ret_dict_single.update(bbox_res)
        # for gen results.
        batch2seq = np.nonzero(
            preds['lines_bs_idx'].cpu().numpy() == batch_idx)[0]
        ret_dict_single.update({
            'nline': len(batch2seq),
            'vectors': []
        })
        for i in batch2seq:
            pre = preds['polylines'][i].detach().cpu().numpy()
            pre_msk = preds['polyline_masks'][i].detach().cpu().numpy()
            # Drop the last valid token (presumably EOS — TODO confirm)
            # before reshaping the flat sequence into points.
            valid_idx = np.nonzero(pre_msk)[0][:-1]
            # From [200,1] to [199,0] to (1,0)
            line = (pre[valid_idx].reshape(-1, coord_dim) - 1) / (range_size-1)
            ret_dict_single['vectors'].append(line)
        # if gts is not None:
        #     ret_dict_single['groundTruth'] = rec_groundtruth
        ret_list.append(ret_dict_single)
    return ret_list
def pack_groundtruth(batch_idx,gts,tokens,range_size,gen_net_name='gen',coord_dim=2):
    """Extract detection and polyline ground truth for one sample.

    Args:
        batch_idx (int): sample index within the batch.
        gts (dict): ground truth with 'det' and 'gen' sub-dicts of tensors.
        tokens (list): per-sample identifiers.
        range_size (np.ndarray): canvas size used for de-quantization.
        gen_net_name (str): generator head name; 'gen_gmm' keeps the last
            valid token, any other head drops it.
        coord_dim (int): coordinate dimensionality of each vertex.

    Returns:
        tuple(dict, dict): detection GT ('labels', 'bboxes') and polyline
        GT record ('token', 'nline', 'labels', 'lines').
    """
    det_key = 'keypoints' if 'keypoints' in gts['det'] else 'bbox'
    det_gt = {
        'labels': gts['det']['class_label'][batch_idx].detach().cpu().numpy(),
        'bboxes': gts['det'][det_key][batch_idx].detach().cpu().numpy(),
    }
    gen = gts['gen']
    seq_ids = np.nonzero(gen['lines_bs_idx'].cpu().numpy() == batch_idx)[0]
    ret_groundtruth = {
        'token': tokens[batch_idx],
        'nline': len(seq_ids),
        'labels': gen['lines_cls'][seq_ids].detach().cpu().numpy(),
        'lines': [],
    }
    all_lines = gen['polylines'].detach().cpu().numpy()
    all_masks = gen['polyline_masks'].detach().cpu().numpy()
    keep_last = (gen_net_name == 'gen_gmm')
    for seq_id in seq_ids:
        valid = np.nonzero(all_masks[seq_id])[0]
        if not keep_last:
            valid = valid[:-1]
        # From [200,1] to [199,0] to (1,0)
        pts = (all_lines[seq_id][valid].reshape(-1, coord_dim) - 1) / (range_size - 1)
        ret_groundtruth['lines'].append(pts)
    return det_gt, ret_groundtruth
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, Linear
from mmcv.runner import force_fp32
from torch.distributions.categorical import Categorical
from mmdet.core import (multi_apply, build_assigner, build_sampler,
reduce_mean)
from mmdet.models import HEADS
from .detr_bbox import DETRBboxHead
from mmdet.models.utils.transformer import inverse_sigmoid
from mmdet.models import build_loss
from mmcv.cnn import Linear, build_activation_layer, bias_init_with_prob
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmdet.models.utils import build_transformer
@HEADS.register_module(force=True)
class MapElementDetector(nn.Module):
    """DETR-style detector predicting map-element boxes/keypoints from BEV
    features, with per-decoder-layer classification and regression heads."""

    def __init__(self,
                 canvas_size=(400, 200),
                 discrete_output=False,
                 separate_detect=False,
                 mode='xyxy',
                 bbox_size=None,
                 coord_dim=2,
                 kp_coord_dim=2,
                 num_classes=3,
                 in_channels=128,
                 num_query=100,
                 max_lines=50,
                 score_thre=0.2,
                 num_reg_fcs=2,
                 num_points=100,
                 iterative=False,
                 patch_size=None,
                 sync_cls_avg_factor=True,
                 transformer: dict = None,
                 positional_encoding: dict = None,
                 loss_cls: dict = None,
                 loss_reg: dict = None,
                 train_cfg: dict = None,):
        """
        Args:
            canvas_size (tuple): BEV canvas extent used to scale outputs.
            discrete_output (bool): regress a distribution over canvas
                positions instead of continuous coordinates.
            separate_detect (bool): class-agnostic detection mode.
            mode (str): 'sce' uses 3 bbox parameters, anything else 2.
            bbox_size (int | None): explicit override of the above.
            iterative (bool): iterative box refinement (per-layer branches).
            train_cfg (dict): must contain an 'assigner' entry.
                NOTE(review): the default ``None`` would crash on the first
                line below — a train_cfg is effectively required.
        """
        super().__init__()
        assigner = train_cfg['assigner']
        self.assigner = build_assigner(assigner)
        # DETR sampling=False, so use PseudoSampler
        sampler_cfg = dict(type='PseudoSampler')
        self.sampler = build_sampler(sampler_cfg, context=self)
        self.train_cfg = train_cfg
        self.max_lines = max_lines
        self.score_thre = score_thre
        self.num_query = num_query
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_points = num_points
        # branch
        # if loss_cls.use_sigmoid:
        if loss_cls['use_sigmoid']:
            self.cls_out_channels = num_classes
        else:
            # Extra channel for the background class under softmax.
            self.cls_out_channels = num_classes+1
        self.iterative = iterative
        self.num_reg_fcs = num_reg_fcs
        self._build_transformer(transformer, positional_encoding)
        # loss params
        self.loss_cls = build_loss(loss_cls)
        self.bg_cls_weight = 0.1
        if self.loss_cls.use_sigmoid:
            self.bg_cls_weight = 0.0
        self.sync_cls_avg_factor = sync_cls_avg_factor
        self.reg_loss = build_loss(loss_reg)
        self.separate_detect = separate_detect
        self.discrete_output = discrete_output
        self.bbox_size = 3 if mode=='sce' else 2
        if bbox_size is not None:
            self.bbox_size = bbox_size
        self.coord_dim = coord_dim # for xyz
        self.kp_coord_dim = kp_coord_dim
        self.register_buffer('canvas_size', torch.tensor(canvas_size))
        # add reg, cls head for each decoder layer
        # NOTE(review): init_weights() runs BEFORE _init_embedding(), so the
        # embeddings created there keep their default init; also
        # _init_embedding overwrites the query_embedding from _init_layers.
        self._init_layers()
        self._init_branch()
        self.init_weights()
        self._init_embedding()
def _init_layers(self):
    """Create the input projection and a provisional query embedding."""
    # 1x1 conv mapping backbone channels to the transformer width.
    self.input_proj = Conv2d(self.in_channels, self.embed_dims, kernel_size=1)
    # query_pos_embed & query_embed
    # NOTE(review): overwritten by _init_embedding() later in __init__.
    self.query_embedding = nn.Embedding(self.num_query, self.embed_dims)
def _init_embedding(self):
    """Build label, coordinate-grid, query and bbox-corner embeddings."""
    self.label_embed = nn.Embedding(self.num_classes, self.embed_dims)
    # Embeds normalized 2D grid coordinates into the feature space.
    self.img_coord_embed = nn.Linear(2, self.embed_dims)
    # query_pos_embed & query_embed packed together (hence the 2x width).
    self.query_embedding = nn.Embedding(self.num_query, self.embed_dims * 2)
    # One embedding per bbox parameter (xstart, ystart, xend, yend, ...).
    self.bbox_embedding = nn.Embedding(self.bbox_size, self.embed_dims * 2)
def _init_branch(self,):
    """Initialize classification branch and regression branch of head."""
    # Classification sees the concatenated features of all bbox corners.
    fc_cls = Linear(self.embed_dims*self.bbox_size, self.cls_out_channels)
    # fc_cls = Linear(self.embed_dims, self.cls_out_channels)
    reg_branch = []
    for _ in range(self.num_reg_fcs):
        reg_branch.append(Linear(self.embed_dims, self.embed_dims))
        reg_branch.append(nn.LayerNorm(self.embed_dims))
        reg_branch.append(nn.ReLU())
    if self.discrete_output:
        # Predict a logit per discrete canvas position.
        reg_branch.append(nn.Linear(
            self.embed_dims, max(self.canvas_size), bias=True,))
    else:
        reg_branch.append(nn.Linear(
            self.embed_dims, self.coord_dim, bias=True,))
    reg_branch = nn.Sequential(*reg_branch)
    # add sigmoid or not
    def _get_clones(module, N):
        # Deep copies -> independent weights per decoder layer.
        return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
    num_pred = self.transformer.decoder.num_layers
    if self.iterative:
        fc_cls = _get_clones(fc_cls, num_pred)
        reg_branch = _get_clones(reg_branch, num_pred)
    else:
        # Non-iterative: the SAME module is shared across all layers.
        reg_branch = nn.ModuleList(
            [reg_branch for _ in range(num_pred)])
        fc_cls = nn.ModuleList(
            [fc_cls for _ in range(num_pred)])
    self.pre_branches = nn.ModuleDict([
        ('cls', fc_cls),
        ('reg', reg_branch), ])
def init_weights(self):
    """Initialize weights of the DeformDETR head."""
    # Xavier init for the input projection's weight matrices only.
    for p in self.input_proj.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    self.transformer.init_weights()
    # init prediction branch
    for k, v in self.pre_branches.items():
        for param in v.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)
    # focal loss init
    if self.loss_cls.use_sigmoid:
        # Bias init so initial foreground probability is ~0.01.
        bias_init = bias_init_with_prob(0.01)
        # for last layer
        if isinstance(self.pre_branches['cls'], nn.ModuleList):
            for m in self.pre_branches['cls']:
                nn.init.constant_(m.bias, bias_init)
        else:
            m = self.pre_branches['cls']
            nn.init.constant_(m.bias, bias_init)
def _build_transformer(self, transformer, positional_encoding):
    """Instantiate the transformer, activation and positional encoding."""
    # transformer
    act_cfg = transformer.get('act_cfg', dict(type='ReLU', inplace=True))
    self.act_cfg = act_cfg
    self.activate = build_activation_layer(act_cfg)
    self.positional_encoding = build_positional_encoding(positional_encoding)
    self.transformer = build_transformer(transformer)
    # Downstream layer widths all follow the transformer's embed dims.
    self.embed_dims = self.transformer.embed_dims
def _prepare_context(self, context):
    """Prepare class label and vertex context."""
    global_context_embedding = None
    image_embeddings = context['bev_embeddings']
    image_embeddings = self.input_proj(
        image_embeddings)  # only change feature size
    # Pass images through encoder
    device = image_embeddings.device
    # Add 2D coordinate grid embedding
    B, C, H, W = image_embeddings.shape
    Ws = torch.linspace(-1., 1., W)
    Hs = torch.linspace(-1., 1., H)
    image_coords = torch.stack(
        torch.meshgrid(Hs, Ws), dim=-1).to(device)
    image_coord_embeddings = self.img_coord_embed(image_coords)
    # In-place add of the (H, W, C) -> (1, C, H, W) coordinate embedding.
    image_embeddings += image_coord_embeddings[None].permute(0, 3, 1, 2)
    # Reshape spatial grid to sequence
    # NOTE(review): this reshape keeps (B, C, H, W) layout, i.e. it is a
    # no-op; the actual flattening presumably happens downstream — confirm.
    sequential_context_embeddings = image_embeddings.reshape(
        B, C, H, W)
    return (global_context_embedding, sequential_context_embeddings)
def forward(self, context, img_metas=None, multi_scale=False):
    '''
    Args:
        context (dict): holds 'bev_embeddings' feature map [B, C, H, W].
        img_metas: unused.
        multi_scale: unused.
    Outs:
        outputs (list[dict]): one dict per decoder layer with
            'lines' [bs, num_query, bbox_size*coord_dim],
            'scores' [bs, num_query, cls_out_channels],
            'embeddings' [bs, num_query, bbox_size, h].
    '''
    (global_context_embedding, sequential_context_embeddings) =\
        self._prepare_context(context)
    x = sequential_context_embeddings
    B, C, H, W = x.shape
    # Build one query per (object query, bbox corner) pair: broadcast the
    # query embedding over corners, add per-corner bbox embeddings, then
    # flatten to [B, num_query*bbox_size, 2C].
    query_embedding = self.query_embedding.weight[None,:,None].repeat(B, 1, self.bbox_size, 1)
    bbox_embed = self.bbox_embedding.weight
    query_embedding = query_embedding + bbox_embed[None,None]
    query_embedding = query_embedding.view(B, -1, C*2)
    # No padding in the BEV map -> all-zero (all-valid) mask.
    img_masks = x.new_zeros((B, H, W))
    pos_embed = self.positional_encoding(img_masks)
    # outs_dec: [nb_dec, bs, num_query, embed_dim]
    # NOTE(review): self.reg_branches is never defined on this class
    # (branches live in self.pre_branches), so iterative=True would raise
    # AttributeError here — confirm intended usage.
    hs, init_reference, inter_references = self.transformer(
        [x,],
        [img_masks.type(torch.bool)],
        query_embedding,
        [pos_embed],
        reg_branches= self.reg_branches if self.iterative else None, # noqa:E501
        cls_branches= None, # noqa:E501
    )
    outs_dec = hs.permute(0, 2, 1, 3)
    outputs = []
    for i, (query_feat) in enumerate(outs_dec):
        # Layer 0 decodes against the initial references; later layers use
        # the previous layer's refined references.
        if i == 0:
            reference = init_reference
        else:
            reference = inter_references[i - 1]
        outputs.append(self.get_prediction(i,query_feat,reference))
    return outputs
def get_prediction(self, level, query_feat, reference):
    """Decode one decoder layer's features into scores and line coords.

    Args:
        level (int): decoder layer index (selects the branch head).
        query_feat (Tensor): [bs, num_query*bbox_size, h].
        reference (Tensor): reference points in sigmoid space.
    """
    bs, num_query, h = query_feat.shape
    query_feat = query_feat.view(bs, -1, self.bbox_size,h)
    # Classification consumes all corner features of a query concatenated.
    ocls = self.pre_branches['cls'][level](query_feat.flatten(-2))
    # ocls = ocls.mean(-2)
    reference = inverse_sigmoid(reference)
    reference = reference.view(bs, -1, self.bbox_size,self.coord_dim)
    tmp = self.pre_branches['reg'][level](query_feat)
    # Regression predicts offsets relative to the logit-space references
    # (only the keypoint coordinate dims).
    tmp[...,:self.kp_coord_dim] = tmp[...,:self.kp_coord_dim] + reference[...,:self.kp_coord_dim]
    lines = tmp.sigmoid() # bs, num_query, self.bbox_size,2
    # Scale normalized coordinates to canvas pixel units.
    lines = lines * self.canvas_size[:self.coord_dim]
    lines = lines.flatten(-2)
    return dict(
        lines=lines, # [bs, num_query, bboxsize*2]
        scores=ocls, # [bs, num_query, num_class]
        embeddings= query_feat, # [bs, num_query, bbox_size, h]
    )
@force_fp32(apply_to=('score_pred', 'lines_pred', 'gt_lines'))
def _get_target_single(self,
                       score_pred,
                       lines_pred,
                       gt_labels,
                       gt_lines,
                       gt_bboxes_ignore=None):
    """
    Compute regression and classification targets for one image.
    Outputs from a single decoder layer of a single feature level are used.
    Args:
        score_pred (Tensor): Box score logits from a single decoder layer
            for one image. Shape [num_query, cls_out_channels].
        lines_pred (Tensor): shape [num_query, num_points, 2].
        gt_labels (torch.LongTensor): shape [num_gt, ].
        gt_lines (Tensor): shape [num_gt, num_points, 2].
    Returns:
        tuple[Tensor]: (labels, label_weights, lines_target,
        lines_weights, pos_inds, neg_inds, pos_gt_inds) for one image.
    """
    num_pred_lines = len(lines_pred)
    # Hungarian assignment between predictions and ground truth.
    assign_result = self.assigner.assign(preds=dict(lines=lines_pred, scores=score_pred,),
                                         gts=dict(lines=gt_lines,
                                                  labels=gt_labels, ),
                                         gt_bboxes_ignore=gt_bboxes_ignore)
    sampling_result = self.sampler.sample(
        assign_result, lines_pred, gt_lines)
    pos_inds = sampling_result.pos_inds
    neg_inds = sampling_result.neg_inds
    pos_gt_inds = sampling_result.pos_assigned_gt_inds
    # label targets 0: foreground, 1: background
    if self.separate_detect:
        labels = gt_lines.new_full((num_pred_lines, ), 1, dtype=torch.long)
    else:
        # Default every query to the background class index.
        labels = gt_lines.new_full(
            (num_pred_lines, ), self.num_classes, dtype=torch.long)
    labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
    label_weights = gt_lines.new_ones(num_pred_lines)
    # bbox targets since lines_pred's last dimension is the vocabulary
    # and ground truth dose not have this dimension.
    if self.discrete_output:
        lines_target = torch.zeros_like(lines_pred[..., 0]).long()
        lines_weights = torch.zeros_like(lines_pred[..., 0])
    else:
        lines_target = torch.zeros_like(lines_pred)
        lines_weights = torch.zeros_like(lines_pred)
    lines_target[pos_inds] = sampling_result.pos_gt_bboxes.type(
        lines_target.dtype)
    lines_weights[pos_inds] = 1.0
    # Normalize weights per query so each matched line contributes equally
    # regardless of its length (guard against division by zero).
    n = lines_weights.sum(-1, keepdim=True)
    lines_weights = lines_weights / n.masked_fill(n == 0, 1)
    return (labels, label_weights, lines_target, lines_weights,
            pos_inds, neg_inds, pos_gt_inds)
# @force_fp32(apply_to=('preds', 'gts'))
def get_targets(self, preds, gts, gt_bboxes_ignore_list=None):
    """
    Compute regression and classification targets for a batch of images
    (single decoder layer).
    Args:
        preds (dict): 'scores' list/tensor per image, 'lines' predictions
            (a distribution object with ``.logits`` when discrete_output).
        gts (dict): 'class_label' and 'bbox' lists per image (or
            'bbox'/'bbox_mask' in separate-detect mode).
        gt_bboxes_ignore_list: must be None.
    Returns:
        tuple: (new_gts dict of stacked target lists, num_total_pos,
        num_total_neg, pos_inds_list, pos_gt_inds_list).
    """
    assert gt_bboxes_ignore_list is None, \
        'Only supports for gt_bboxes_ignore setting to None.'
    # format the inputs
    if self.separate_detect:
        # Class-agnostic: mask out padded boxes, all valid boxes -> class 0.
        bbox = [b[m] for b, m in zip(gts['bbox'], gts['bbox_mask'])]
        class_label = torch.zeros_like(gts['bbox_mask']).long()
        class_label = [b[m] for b, m in zip(class_label, gts['bbox_mask'])]
    else:
        class_label = gts['class_label']
        bbox = gts['bbox']
    if self.discrete_output:
        lines_pred = preds['lines'].logits
    else:
        lines_pred = preds['lines']
    bbox = [b.float() for b in bbox]
    (labels_list, label_weights_list,
     lines_targets_list, lines_weights_list,
     pos_inds_list, neg_inds_list,pos_gt_inds_list) = multi_apply(
        self._get_target_single,
        preds['scores'], lines_pred,
        class_label, bbox,
        gt_bboxes_ignore=gt_bboxes_ignore_list)
    num_total_pos = sum((inds.numel() for inds in pos_inds_list))
    num_total_neg = sum((inds.numel() for inds in neg_inds_list))
    new_gts = dict(
        labels=labels_list,
        label_weights=label_weights_list,
        bboxs=lines_targets_list,
        bboxs_weights=lines_weights_list,
    )
    return new_gts, num_total_pos, num_total_neg, pos_inds_list, pos_gt_inds_list
# @force_fp32(apply_to=('preds', 'gts'))
def loss_single(self,
                preds: dict,
                gts: dict,
                gt_bboxes_ignore_list=None,
                reduction='none'):
    """
    Loss for the outputs of a single decoder layer.
    Args:
        preds (dict): 'scores' [bs, num_query, cls_out_channels] and
            'lines' predictions for this layer.
        gts (dict): ground truth as consumed by ``get_targets``.
        gt_bboxes_ignore_list: must be None.
        reduction (str): unused here; kept for API compatibility.
    Returns:
        tuple: (dict with 'cls' and 'reg' losses, matched pred indices,
        matched gt indices).
    """
    # Get target for each sample
    new_gts, num_total_pos, num_total_neg, pos_inds_list, pos_gt_inds_list =\
        self.get_targets(preds, gts, gt_bboxes_ignore_list)
    # Batched all data
    for k, v in new_gts.items():
        new_gts[k] = torch.stack(v, dim=0)
    # construct weighted avg_factor to match with the official DETR repo
    cls_avg_factor = num_total_pos * 1.0 + \
        num_total_neg * self.bg_cls_weight
    if self.sync_cls_avg_factor:
        cls_avg_factor = reduce_mean(
            preds['scores'].new_tensor([cls_avg_factor]))
    cls_avg_factor = max(cls_avg_factor, 1)
    # Classification loss
    if self.separate_detect:
        loss_cls = self.bce_loss(
            preds['scores'], new_gts['labels'], new_gts['label_weights'], cls_avg_factor)
    else:
        # since the inputs needs the second dim is the class dim, we permute the prediction.
        cls_scores = preds['scores'].reshape(-1, self.cls_out_channels)
        cls_labels = new_gts['labels'].reshape(-1)
        cls_weights = new_gts['label_weights'].reshape(-1)
        loss_cls = self.loss_cls(
            cls_scores, cls_labels, cls_weights, avg_factor=cls_avg_factor)
    # Compute the average number of gt boxes accross all gpus, for
    # normalization purposes
    num_total_pos = loss_cls.new_tensor([num_total_pos])
    num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()
    # position NLL loss
    if self.discrete_output:
        # Negative log-likelihood over the discrete coordinate vocabulary.
        loss_reg = -(preds['lines'].log_prob(new_gts['bboxs']) *
                     new_gts['bboxs_weights']).sum()/(num_total_pos)
    else:
        loss_reg = self.reg_loss(
            preds['lines'], new_gts['bboxs'], new_gts['bboxs_weights'], avg_factor=num_total_pos)
    loss_dict = dict(
        cls=loss_cls,
        reg=loss_reg,
    )
    return loss_dict, pos_inds_list, pos_gt_inds_list
@force_fp32(apply_to=('gt_lines_list', 'preds_dicts'))
def loss(self,
         gts: dict,
         preds_dicts: dict,
         gt_bboxes_ignore=None,
         reduction='mean'):
    """Aggregate per-decoder-layer losses into one flat dict.

    Args:
        gts (dict): ground truth as consumed by ``loss_single``.
        preds_dicts (list[dict]): one prediction dict per decoder layer.
        gt_bboxes_ignore: must be None.
        reduction (str): forwarded to ``loss_single``.

    Returns:
        tuple: (loss dict — last layer keys bare, earlier layers prefixed
        with ``d<idx>.`` —, matched pred indices, matched gt indices).
    """
    assert gt_bboxes_ignore is None, \
        f'{self.__class__.__name__} only supports ' \
        f'for gt_bboxes_ignore setting to None.'
    # One loss_single call per decoder layer.
    losses, pos_inds_lists, pos_gt_inds_lists = multi_apply(
        self.loss_single,
        preds_dicts,
        gts=gts,
        gt_bboxes_ignore_list=gt_bboxes_ignore,
        reduction=reduction)
    # Last decoder layer keeps bare keys; earlier layers get a prefix.
    loss_dict = dict(losses[-1])
    for layer_idx, layer_losses in enumerate(losses[:-1]):
        for name, value in layer_losses.items():
            loss_dict[f'd{layer_idx}.{name}'] = value
    return loss_dict, pos_inds_lists, pos_gt_inds_lists
def post_process(self, preds_dicts: list, **kwargs):
    '''Select the top-scoring queries of the last decoder layer and pack
    them for the downstream polyline generator.

    Args:
        preds_dicts (list[dict]): per-decoder-layer outputs; each with
            'scores' [bs, num_query, cls_out_channels] and
            'lines' [bs, num_query, bbox parameters].

    Outs:
        result_dict (dict): per-sample lists 'bbox', 'scores', 'labels'
        plus flattened tensors 'bbox_flat' (quantized int32),
        'lines_cls' and 'lines_bs_idx'.
    '''
    preds = preds_dicts[-1]
    batched_cls_scores = preds['scores']
    batched_lines_preds = preds['lines']
    batch_size = batched_cls_scores.size(0)
    device = batched_cls_scores.device
    result_dict = {
        'bbox': [],
        'scores': [],
        'labels': [],
        'bbox_flat': [],
        'lines_cls': [],
        'lines_bs_idx': [],
    }
    for i in range(batch_size):
        cls_scores = batched_cls_scores[i]
        det_preds = batched_lines_preds[i]
        max_num = self.max_lines
        if self.loss_cls.use_sigmoid:
            # Sigmoid scores: rank over the flattened (query, class) grid,
            # then recover the query and class indices.
            cls_scores = cls_scores.sigmoid()
            scores, valid_idx = cls_scores.view(-1).topk(max_num)
            det_labels = valid_idx % self.num_classes
            valid_idx = valid_idx // self.num_classes
            det_preds = det_preds[valid_idx]
        else:
            # Softmax scores: drop the background column before ranking.
            scores, det_labels = F.softmax(cls_scores, dim=-1)[..., :-1].max(-1)
            scores, valid_idx = scores.topk(max_num)
            det_preds = det_preds[valid_idx]
            det_labels = det_labels[valid_idx]
        nline = len(valid_idx)
        result_dict['bbox'].append(det_preds)
        result_dict['scores'].append(scores)
        result_dict['labels'].append(det_labels)
        result_dict['lines_bs_idx'].extend([i]*nline)
    # for down stream polyline
    _bboxs = torch.cat(result_dict['bbox'], dim=0)
    # quantize the data
    result_dict['bbox_flat'] = torch.round(_bboxs).type(torch.int32)
    result_dict['lines_cls'] = torch.cat(
        result_dict['labels'], dim=0).long()
    result_dict['lines_bs_idx'] = torch.tensor(
        result_dict['lines_bs_idx'], device=device).long()
    # BUG FIX: the function previously fell off the end and implicitly
    # returned None, which crashes callers such as DGHead.inference that
    # do ``outs.update(self.det_net.post_process(bbox_dict))``.
    return result_dict
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, Linear
from mmcv.runner import force_fp32
from torch.distributions.categorical import Categorical
from mmdet.core import (multi_apply, build_assigner, build_sampler,
reduce_mean)
from mmdet.models import HEADS
from .detr_bbox import DETRBboxHead
from mmdet.models.utils.transformer import inverse_sigmoid
from mmdet.models import build_loss
from mmcv.cnn import Linear, build_activation_layer, bias_init_with_prob
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmdet.models.utils import build_transformer
@HEADS.register_module(force=True)
class MapElementDetector(nn.Module):
    def __init__(self,
                 canvas_size=(400, 200),
                 discrete_output=False,
                 separate_detect=False,
                 mode='xyxy',
                 bbox_size=None,
                 coord_dim=2,
                 kp_coord_dim=2,
                 num_classes=3,
                 in_channels=128,
                 num_query=100,
                 max_lines=50,
                 score_thre=0.2,
                 num_reg_fcs=2,
                 num_points=100,
                 iterative=False,
                 patch_size=None,
                 sync_cls_avg_factor=True,
                 transformer: dict = None,
                 positional_encoding: dict = None,
                 loss_cls: dict = None,
                 loss_reg: dict = None,
                 train_cfg: dict = None,):
        """DETR-style head that detects map elements as key-point boxes.

        Args:
            canvas_size (tuple): (H, W) of the BEV canvas; registered as a
                buffer so it moves with the module's device.
            discrete_output (bool): if True the regression branch predicts a
                distribution over canvas bins instead of continuous coords.
            separate_detect (bool): detection-only mode using binary fg/bg
                labels (see ``_get_target_single``).
            mode (str): key-point layout; 'sce' uses 3 key points, else 2.
            bbox_size (int, optional): explicit key-point count, overriding
                the value derived from ``mode``.
            coord_dim (int): 2 for xy, 3 for xyz coordinates.
            kp_coord_dim (int): leading coords anchored to reference points
                during regression (see ``get_prediction``).
            num_classes (int): number of foreground classes.
            in_channels (int): channels of the input BEV feature map.
            num_query (int): number of detection queries.
            max_lines (int): top-k detections kept by ``post_process``.
            score_thre (float): stored; not used within this class.
            num_reg_fcs (int): hidden FC layers of the regression branch.
            num_points (int): stored per-line point count.
            iterative (bool): iterative refinement with per-layer heads.
            patch_size: accepted for config compatibility; unused here.
            sync_cls_avg_factor (bool): sync cls normalizer across GPUs.
            transformer (dict): transformer config.
            positional_encoding (dict): positional-encoding config.
            loss_cls (dict): classification loss config (needs 'use_sigmoid').
            loss_reg (dict): regression loss config.
            train_cfg (dict): must contain 'assigner'.
                NOTE(review): a None ``train_cfg`` raises here — confirm the
                head is never constructed without it.
        """
        super().__init__()
        assigner = train_cfg['assigner']
        self.assigner = build_assigner(assigner)
        # DETR sampling=False, so use PseudoSampler
        sampler_cfg = dict(type='PseudoSampler')
        self.sampler = build_sampler(sampler_cfg, context=self)
        self.train_cfg = train_cfg
        self.max_lines = max_lines
        self.score_thre = score_thre
        self.num_query = num_query
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_points = num_points
        # branch
        # if loss_cls.use_sigmoid:
        if loss_cls['use_sigmoid']:
            self.cls_out_channels = num_classes
        else:
            # Extra channel for the explicit background class.
            self.cls_out_channels = num_classes+1
        self.iterative = iterative
        self.num_reg_fcs = num_reg_fcs
        self._build_transformer(transformer, positional_encoding)
        # loss params
        self.loss_cls = build_loss(loss_cls)
        self.bg_cls_weight = 0.1
        if self.loss_cls.use_sigmoid:
            self.bg_cls_weight = 0.0
        self.sync_cls_avg_factor = sync_cls_avg_factor
        self.reg_loss = build_loss(loss_reg)
        self.separate_detect = separate_detect
        self.discrete_output = discrete_output
        self.bbox_size = 3 if mode == 'sce' else 2
        if bbox_size is not None:
            self.bbox_size = bbox_size
        self.coord_dim = coord_dim  # for xyz
        self.kp_coord_dim = kp_coord_dim
        self.register_buffer('canvas_size', torch.tensor(canvas_size))
        # add reg, cls head for each decoder layer
        # NOTE(review): ``_init_layers`` builds a query embedding that
        # ``_init_embedding`` replaces with a wider one, and ``_init_embedding``
        # runs after ``init_weights`` — confirm this ordering is intended.
        self._init_layers()
        self._init_branch()
        self.init_weights()
        self._init_embedding()
def _init_layers(self):
"""Initialize some layer."""
self.input_proj = Conv2d(
self.in_channels, self.embed_dims, kernel_size=1)
# query_pos_embed & query_embed
self.query_embedding = nn.Embedding(self.num_query,
self.embed_dims)
def _init_embedding(self):
self.label_embed = nn.Embedding(
self.num_classes, self.embed_dims)
self.img_coord_embed = nn.Linear(2, self.embed_dims)
# query_pos_embed & query_embed
self.query_embedding = nn.Embedding(self.num_query,
self.embed_dims*2)
# for bbox parameter xstart, ystart, xend, yend
self.bbox_embedding = nn.Embedding( self.bbox_size,
self.embed_dims*2)
def _init_branch(self,):
"""Initialize classification branch and regression branch of head."""
fc_cls = Linear(self.embed_dims*self.bbox_size, self.cls_out_channels)
# fc_cls = Linear(self.embed_dims, self.cls_out_channels)
reg_branch = []
for _ in range(self.num_reg_fcs):
reg_branch.append(Linear(self.embed_dims, self.embed_dims))
reg_branch.append(nn.LayerNorm(self.embed_dims))
reg_branch.append(nn.ReLU())
if self.discrete_output:
reg_branch.append(nn.Linear(
self.embed_dims, max(self.canvas_size), bias=True,))
else:
reg_branch.append(nn.Linear(
self.embed_dims, self.coord_dim, bias=True,))
reg_branch = nn.Sequential(*reg_branch)
# add sigmoid or not
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
num_pred = self.transformer.decoder.num_layers
if self.iterative:
fc_cls = _get_clones(fc_cls, num_pred)
reg_branch = _get_clones(reg_branch, num_pred)
else:
reg_branch = nn.ModuleList(
[reg_branch for _ in range(num_pred)])
fc_cls = nn.ModuleList(
[fc_cls for _ in range(num_pred)])
self.pre_branches = nn.ModuleDict([
('cls', fc_cls),
('reg', reg_branch), ])
def init_weights(self):
"""Initialize weights of the DeformDETR head."""
for p in self.input_proj.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
self.transformer.init_weights()
# init prediction branch
for k, v in self.pre_branches.items():
for param in v.parameters():
if param.dim() > 1:
nn.init.xavier_uniform_(param)
# focal loss init
if self.loss_cls.use_sigmoid:
bias_init = bias_init_with_prob(0.01)
# for last layer
if isinstance(self.pre_branches['cls'], nn.ModuleList):
for m in self.pre_branches['cls']:
nn.init.constant_(m.bias, bias_init)
else:
m = self.pre_branches['cls']
nn.init.constant_(m.bias, bias_init)
def _build_transformer(self, transformer, positional_encoding):
# transformer
self.act_cfg = transformer.get('act_cfg',
dict(type='ReLU', inplace=True))
self.activate = build_activation_layer(self.act_cfg)
self.positional_encoding = build_positional_encoding(
positional_encoding)
self.transformer = build_transformer(transformer)
self.embed_dims = self.transformer.embed_dims
    def _prepare_context(self, context):
        """Prepare class label and vertex context.

        Args:
            context (dict): must contain 'bev_embeddings' [B, C_in, H, W];
                projected to the transformer width and tagged with a learned
                2-D coordinate embedding.

        Returns:
            tuple: (global_context_embedding, sequential_context_embeddings)
                where the former is always None for this head and the latter
                is [B, embed_dims, H, W].
        """
        global_context_embedding = None
        image_embeddings = context['bev_embeddings']
        image_embeddings = self.input_proj(
            image_embeddings)  # only change feature size
        # Pass images through encoder
        device = image_embeddings.device
        # Add 2D coordinate grid embedding
        B, C, H, W = image_embeddings.shape
        Ws = torch.linspace(-1., 1., W)
        Hs = torch.linspace(-1., 1., H)
        # NOTE(review): meshgrid relies on the default 'ij' indexing; newer
        # PyTorch warns unless indexing= is passed explicitly.
        image_coords = torch.stack(
            torch.meshgrid(Hs, Ws), dim=-1).to(device)
        image_coord_embeddings = self.img_coord_embed(image_coords)
        image_embeddings += image_coord_embeddings[None].permute(0, 3, 1, 2)
        # Reshape spatial grid to sequence
        # NOTE(review): reshape to (B, C, H, W) is a no-op here; flattening to
        # a sequence presumably happens inside the transformer — confirm.
        sequential_context_embeddings = image_embeddings.reshape(
            B, C, H, W)
        return (global_context_embedding, sequential_context_embeddings)
def forward(self, context, img_metas=None, multi_scale=False):
'''
Args:
bev_feature (List[Tensor]): shape [B, C, H, W]
feature in bev view
img_metas
Outs:
preds_dict (Dict):
all_cls_scores (Tensor): Classification score of all
decoder layers, has shape
[nb_dec, bs, num_query, cls_out_channels].
all_lines_preds (Tensor):
[nb_dec, bs, num_query, num_points, 2].
'''
(global_context_embedding, sequential_context_embeddings) =\
self._prepare_context(context)
x = sequential_context_embeddings
B, C, H, W = x.shape
query_embedding = self.query_embedding.weight[None,:,None].repeat(B, 1, self.bbox_size, 1)
bbox_embed = self.bbox_embedding.weight
query_embedding = query_embedding + bbox_embed[None,None]
query_embedding = query_embedding.view(B, -1, C*2)
img_masks = x.new_zeros((B, H, W))
pos_embed = self.positional_encoding(img_masks)
# outs_dec: [nb_dec, bs, num_query, embed_dim]
hs, init_reference, inter_references = self.transformer(
[x,],
[img_masks.type(torch.bool)],
query_embedding,
[pos_embed],
reg_branches= self.reg_branches if self.iterative else None, # noqa:E501
cls_branches= None, # noqa:E501
)
outs_dec = hs.permute(0, 2, 1, 3)
outputs = []
for i, (query_feat) in enumerate(outs_dec):
if i == 0:
reference = init_reference
else:
reference = inter_references[i - 1]
outputs.append(self.get_prediction(i,query_feat,reference))
return outputs
    def get_prediction(self, level, query_feat, reference):
        """Decode one decoder layer's queries into class scores and lines.

        Args:
            level (int): decoder-layer index selecting the head pair.
            query_feat (Tensor): [bs, num_query*bbox_size, h] layer output.
            reference (Tensor): sigmoid-space reference points.

        Returns:
            dict: 'lines' [bs, num_query, bbox_size*coord_dim] in canvas
                units, 'scores' [bs, num_query, cls_out_channels] and the
                regrouped query 'embeddings'.
        """
        bs, num_query, h = query_feat.shape
        # Regroup flattened sub-queries back into (query, key point).
        query_feat = query_feat.view(bs, -1, self.bbox_size, h)
        # Classification sees all key points of a query concatenated.
        ocls = self.pre_branches['cls'][level](query_feat.flatten(-2))
        # ocls = ocls.mean(-2)
        # Regress offsets in logit space relative to the reference points.
        reference = inverse_sigmoid(reference)
        reference = reference.view(bs, -1, self.bbox_size, self.coord_dim)
        tmp = self.pre_branches['reg'][level](query_feat)
        # Only the first kp_coord_dim coords are anchored to the reference.
        tmp[..., :self.kp_coord_dim] = tmp[..., :self.kp_coord_dim] + reference[..., :self.kp_coord_dim]
        lines = tmp.sigmoid()  # bs, num_query, self.bbox_size, 2
        # Scale normalized coords into absolute canvas coordinates.
        lines = lines * self.canvas_size[:self.coord_dim]
        lines = lines.flatten(-2)
        return dict(
            lines=lines,  # [bs, num_query, bboxsize*2]
            scores=ocls,  # [bs, num_query, num_class]
            embeddings=query_feat,  # [bs, num_query, bbox_size, h]
        )
    @force_fp32(apply_to=('score_pred', 'lines_pred', 'gt_lines'))
    def _get_target_single(self,
                           score_pred,
                           lines_pred,
                           gt_labels,
                           gt_lines,
                           gt_bboxes_ignore=None):
        """
        Compute regression and classification targets for one image.
        Outputs from a single decoder layer of a single feature level are used.
        Args:
            score_pred (Tensor): Box score logits from a single decoder layer
                for one image. Shape [num_query, cls_out_channels].
            lines_pred (Tensor):
                shape [num_query, num_points, 2].
            gt_lines (Tensor):
                shape [num_gt, num_points, 2].
            gt_labels (torch.LongTensor)
                shape [num_gt, ]
        Returns:
            tuple[Tensor]: a tuple containing the following for one image.
                - labels (LongTensor): Labels of each image.
                    shape [num_query, 1]
                - label_weights (Tensor]): Label weights of each image.
                    shape [num_query, 1]
                - lines_target (Tensor): Lines targets of each image.
                    shape [num_query, num_points, 2]
                - lines_weights (Tensor): Lines weights of each image.
                    shape [num_query, num_points, 2]
                - pos_inds (Tensor): Sampled positive indices for each image.
                - neg_inds (Tensor): Sampled negative indices for each image.
                - pos_gt_inds (Tensor): Matched gt index per positive query.
        """
        num_pred_lines = len(lines_pred)
        # assigner and sampler: Hungarian-style matching between predictions
        # and ground truth.
        assign_result = self.assigner.assign(preds=dict(lines=lines_pred, scores=score_pred,),
                                             gts=dict(lines=gt_lines,
                                                      labels=gt_labels, ),
                                             gt_bboxes_ignore=gt_bboxes_ignore)
        sampling_result = self.sampler.sample(
            assign_result, lines_pred, gt_lines)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        pos_gt_inds = sampling_result.pos_assigned_gt_inds
        # label targets 0: foreground, 1: background
        if self.separate_detect:
            # Binary fg/bg labels when classification is handled elsewhere.
            labels = gt_lines.new_full((num_pred_lines, ), 1, dtype=torch.long)
        else:
            # Default every query to the background class index.
            labels = gt_lines.new_full(
                (num_pred_lines, ), self.num_classes, dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_lines.new_ones(num_pred_lines)
        # bbox targets since lines_pred's last dimension is the vocabulary
        # and ground truth dose not have this dimension.
        if self.discrete_output:
            lines_target = torch.zeros_like(lines_pred[..., 0]).long()
            lines_weights = torch.zeros_like(lines_pred[..., 0])
        else:
            lines_target = torch.zeros_like(lines_pred)
            lines_weights = torch.zeros_like(lines_pred)
        lines_target[pos_inds] = sampling_result.pos_gt_bboxes.type(
            lines_target.dtype)
        lines_weights[pos_inds] = 1.0
        # Normalize weights per query so each matched line contributes equally
        # regardless of coordinate count; guard against division by zero.
        n = lines_weights.sum(-1, keepdim=True)
        lines_weights = lines_weights / n.masked_fill(n == 0, 1)
        return (labels, label_weights, lines_target, lines_weights,
                pos_inds, neg_inds, pos_gt_inds)
    # @force_fp32(apply_to=('preds', 'gts'))
    def get_targets(self, preds, gts, gt_bboxes_ignore_list=None):
        """
        Compute regression and classification targets for a batch of images.
        Outputs from a single decoder layer of a single feature level are used.
        Args:
            preds (dict): 'scores' and 'lines' with one entry per image.
            gts (dict): 'class_label'/'bbox' lists per image, or
                'bbox'/'bbox_mask' when ``separate_detect`` is set.
            gt_bboxes_ignore_list (list[Tensor], optional): must be None.
        Returns:
            tuple:
                - new_gts (dict): per-image lists of 'labels',
                  'label_weights', 'bboxs' and 'bboxs_weights'.
                - num_total_pos (int): positives across the batch.
                - num_total_neg (int): negatives across the batch.
                - pos_inds_list (list[Tensor]): positive query indices.
                - pos_gt_inds_list (list[Tensor]): matched gt indices.
        """
        assert gt_bboxes_ignore_list is None, \
            'Only supports for gt_bboxes_ignore setting to None.'
        # format the inputs
        if self.separate_detect:
            # Keep only valid boxes; all share class 0 because the real
            # classification happens in a separate branch.
            bbox = [b[m] for b, m in zip(gts['bbox'], gts['bbox_mask'])]
            class_label = torch.zeros_like(gts['bbox_mask']).long()
            class_label = [b[m] for b, m in zip(class_label, gts['bbox_mask'])]
        else:
            class_label = gts['class_label']
            bbox = gts['bbox']
        if self.discrete_output:
            # Distribution predictions: match against the raw logits.
            lines_pred = preds['lines'].logits
        else:
            lines_pred = preds['lines']
        bbox = [b.float() for b in bbox]
        (labels_list, label_weights_list,
         lines_targets_list, lines_weights_list,
         pos_inds_list, neg_inds_list, pos_gt_inds_list) = multi_apply(
            self._get_target_single,
            preds['scores'], lines_pred,
            class_label, bbox,
            gt_bboxes_ignore=gt_bboxes_ignore_list)
        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
        num_total_neg = sum((inds.numel() for inds in neg_inds_list))
        new_gts = dict(
            labels=labels_list,
            label_weights=label_weights_list,
            bboxs=lines_targets_list,
            bboxs_weights=lines_weights_list,
        )
        return new_gts, num_total_pos, num_total_neg, pos_inds_list, pos_gt_inds_list
    # @force_fp32(apply_to=('preds', 'gts'))
    def loss_single(self,
                    preds: dict,
                    gts: dict,
                    gt_bboxes_ignore_list=None,
                    reduction='none'):
        """
        Loss function for outputs from a single decoder layer.
        Args:
            preds (dict): 'scores' [bs, num_query, cls_out_channels] and
                'lines' predictions of one decoder layer.
            gts (dict): per-image ground truth (see ``get_targets``).
            gt_bboxes_ignore_list (list[Tensor], optional): must be None.
            reduction (str): unused here; kept for interface compatibility.
        Returns:
            tuple: (dict with 'cls' and 'reg' loss terms, pos_inds_list,
                pos_gt_inds_list).
        """
        # Get target for each sample
        new_gts, num_total_pos, num_total_neg, pos_inds_list, pos_gt_inds_list =\
            self.get_targets(preds, gts, gt_bboxes_ignore_list)
        # Batched all data
        for k, v in new_gts.items():
            new_gts[k] = torch.stack(v, dim=0)
        # construct weighted avg_factor to match with the official DETR repo
        cls_avg_factor = num_total_pos * 1.0 + \
            num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            # Average the normalizer across GPUs for consistency.
            cls_avg_factor = reduce_mean(
                preds['scores'].new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)
        # Classification loss
        if self.separate_detect:
            # NOTE(review): ``self.bce_loss`` is never defined in this class;
            # this branch would raise AttributeError — confirm before enabling
            # ``separate_detect``.
            loss_cls = self.bce_loss(
                preds['scores'], new_gts['labels'], new_gts['label_weights'], cls_avg_factor)
        else:
            # since the inputs needs the second dim is the class dim, we permute the prediction.
            cls_scores = preds['scores'].reshape(-1, self.cls_out_channels)
            cls_labels = new_gts['labels'].reshape(-1)
            cls_weights = new_gts['label_weights'].reshape(-1)
            loss_cls = self.loss_cls(
                cls_scores, cls_labels, cls_weights, avg_factor=cls_avg_factor)
        # Compute the average number of gt boxes accross all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()
        # position NLL loss
        if self.discrete_output:
            # Negative log-likelihood of the discrete coordinate bins.
            loss_reg = -(preds['lines'].log_prob(new_gts['bboxs']) *
                         new_gts['bboxs_weights']).sum()/(num_total_pos)
        else:
            loss_reg = self.reg_loss(
                preds['lines'], new_gts['bboxs'], new_gts['bboxs_weights'], avg_factor=num_total_pos)
        loss_dict = dict(
            cls=loss_cls,
            reg=loss_reg,
        )
        return loss_dict, pos_inds_list, pos_gt_inds_list
@force_fp32(apply_to=('gt_lines_list', 'preds_dicts'))
def loss(self,
gts: dict,
preds_dicts: dict,
gt_bboxes_ignore=None,
reduction='mean'):
"""
Loss Function.
Args:
gt_lines_list (list[Tensor]): Ground truth lines for each image
with shape (num_gts, num_points, 2)
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
preds_dicts:
all_cls_scores (Tensor): Classification score of all
decoder layers, has shape
[nb_dec, bs, num_query, cls_out_channels].
all_lines_preds (Tensor):
[nb_dec, bs, num_query, num_points, 2].
gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
which can be ignored for each image. Default None.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
assert gt_bboxes_ignore is None, \
f'{self.__class__.__name__} only supports ' \
f'for gt_bboxes_ignore setting to None.'
# Since there might have multi layer
losses, pos_inds_lists, pos_gt_inds_lists = multi_apply(
self.loss_single,
preds_dicts,
gts=gts,
gt_bboxes_ignore_list=gt_bboxes_ignore,
reduction=reduction)
# Format the losses
loss_dict = dict()
# loss from the last decoder layer
for k, v in losses[-1].items():
loss_dict[k] = v
# Loss from other decoder layers
num_dec_layer = 0
for loss in losses[:-1]:
for k, v in loss.items():
loss_dict[f'd{num_dec_layer}.{k}'] = v
num_dec_layer += 1
return loss_dict, pos_inds_lists, pos_gt_inds_lists
def post_process(self, preds_dicts: list, **kwargs):
'''
Args:
preds_dicts:
scores (Tensor): Classification score of all
decoder layers, has shape
[nb_dec, bs, num_query, cls_out_channels].
lines (Tensor):
[nb_dec, bs, num_query, bbox parameters(4)].
Outs:
ret_list (List[Dict]) with length as bs
list of result dict for each sample in the batch
XXX
'''
preds = preds_dicts[-1]
batched_cls_scores = preds['scores']
batched_lines_preds = preds['lines']
batch_size = batched_cls_scores.size(0)
device = batched_cls_scores.device
result_dict = {
'bbox': [],
'scores': [],
'labels': [],
'bbox_flat': [],
'lines_cls': [],
'lines_bs_idx': [],
}
for i in range(batch_size):
cls_scores = batched_cls_scores[i]
det_preds = batched_lines_preds[i]
max_num = self.max_lines
if self.loss_cls.use_sigmoid:
cls_scores = cls_scores.sigmoid()
scores, valid_idx = cls_scores.view(-1).topk(max_num)
det_labels = valid_idx % self.num_classes
valid_idx = valid_idx // self.num_classes
det_preds = det_preds[valid_idx]
else:
scores, det_labels = F.softmax(cls_scores, dim=-1)[..., :-1].max(-1)
scores, valid_idx = scores.topk(max_num)
det_preds = det_preds[valid_idx]
det_labels = det_labels[valid_idx]
nline = len(valid_idx)
result_dict['bbox'].append(det_preds)
result_dict['scores'].append(scores)
result_dict['labels'].append(det_labels)
result_dict['lines_bs_idx'].extend([i]*nline)
# for down stream polyline
_bboxs = torch.cat(result_dict['bbox'], dim=0)
# quantize the data
result_dict['bbox_flat'] = torch.round(_bboxs).type(torch.int32)
result_dict['lines_cls'] = torch.cat(
result_dict['labels'], dim=0).long()
result_dict['lines_bs_idx'] = torch.tensor(
result_dict['lines_bs_idx'], device=device).long()
return result_dict
\ No newline at end of file
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from mmdet.models import HEADS
from .detgen_utils.causal_trans import (CausalTransformerDecoder,
CausalTransformerDecoderLayer)
from .detgen_utils.utils import (dequantize_verts, generate_square_subsequent_mask,
quantize_verts, top_k_logits, top_p_logits)
from mmcv.runner import force_fp32, auto_fp16
@HEADS.register_module(force=True)
class PolylineGenerator(nn.Module):
"""
Autoregressive generative model of n-gon meshes.
Operates on sets of input vertices as well as flattened face sequences with
new face and stopping tokens:
[f_0^0, f_0^1, f_0^2, NEW, f_1^0, f_1^1, ..., STOP]
Input vertices are encoded using a Transformer encoder.
Input face sequences are embedded and tagged with learned position indicators,
as well as their corresponding vertex embeddings. A transformer decoder
outputs a pointer which is compared to each vertex embedding to obtain a
distribution over vertex indices.
"""
    def __init__(self,
                 in_channels,
                 encoder_config,
                 decoder_config,
                 class_conditional=True,
                 num_classes=55,
                 decoder_cross_attention=True,
                 use_discrete_vertex_embeddings=True,
                 condition_points_num=3,
                 coord_dim=2,
                 canvas_size=(400, 200),
                 max_seq_length=500,
                 name='gen_model'):
        """Initializes the polyline generation model.
        Args:
            in_channels: Channels of the conditioning BEV feature map.
            encoder_config: Dictionary with TransformerEncoder config.
            decoder_config: Dictionary with TransformerDecoder config.
                NOTE(review): 'layer_config' is popped from this dict below,
                so the caller's config is mutated — confirm that is intended.
            class_conditional: If True, then condition on learned class embeddings.
            num_classes: Number of classes to condition on.
            decoder_cross_attention: If True, the use cross attention from decoder
                querys into encoder outputs.
            use_discrete_vertex_embeddings: If True, use discrete vertex embeddings.
            condition_points_num: Number of conditioning key-point tokens.
            coord_dim: 2 for xy coordinates, 3 for xyz.
            canvas_size: (H, W) of the discretized canvas.
            max_seq_length: Maximum face sequence length. Used for learned position
                embeddings.
            name: Name of variable scope
        """
        super(PolylineGenerator, self).__init__()
        self.embedding_dim = decoder_config['layer_config']['d_model']
        self.class_conditional = class_conditional
        self.num_classes = num_classes
        self.max_seq_length = max_seq_length
        self.decoder_cross_attention = decoder_cross_attention
        self.use_discrete_vertex_embeddings = use_discrete_vertex_embeddings
        self.condition_points_num = condition_points_num
        self.fp16_enabled = False
        self.coord_dim = coord_dim  # if we use xyz else 2 when we use xy
        # NOTE(review): this expression evaluates to 2 for every coord_dim.
        self.kp_coord_dim = coord_dim if coord_dim == 2 else 2  # XXX
        # Buffer so the canvas size moves with the module across devices.
        self.register_buffer('canvas_size', torch.tensor(canvas_size))
        # initialize the model
        self._project_to_logits = nn.Linear(
            self.embedding_dim,
            max(canvas_size) + 1,  # + 1 for stopping token. use_bias=True,
        )
        self.input_proj = nn.Conv2d(
            in_channels, self.embedding_dim, kernel_size=1)
        decoder_layer = CausalTransformerDecoderLayer(
            **decoder_config.pop('layer_config'))
        self.decoder = CausalTransformerDecoder(
            decoder_layer, **decoder_config)
        self._init_embedding()
        self.init_weights()
def _init_embedding(self):
if self.class_conditional:
self.label_embed = nn.Embedding(
self.num_classes, self.embedding_dim)
self.coord_embed = nn.Embedding(self.coord_dim, self.embedding_dim)
self.pos_embeddings = nn.Embedding(
self.max_seq_length, self.embedding_dim)
# to indicate the role of the position is the start of the line or the end of it.
self.bbox_context_embed = \
nn.Embedding(self.condition_points_num, self.embedding_dim)
self.img_coord_embed = nn.Linear(2, self.embedding_dim)
# initialize the verteices embedding
if self.use_discrete_vertex_embeddings:
self.vertex_embed = nn.Embedding(
max(self.canvas_size) + 1, self.embedding_dim)
else:
self.vertex_embed = nn.Linear(1, self.embedding_dim)
def init_weights(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def _embed_kps(self, bbox):
bbox_len = bbox.shape[-1]
# Bbox_context
bbox_embedding = self.bbox_context_embed(
(torch.arange(bbox_len, device=bbox.device) / self.kp_coord_dim).floor().long())
# Coord indicators (x, y)
coord_embeddings = self.coord_embed(
torch.arange(bbox_len, device=bbox.device) % self.kp_coord_dim)
# Discrete vertex value embeddings
vert_embeddings = self.vertex_embed(bbox)
return vert_embeddings + (bbox_embedding+coord_embeddings)[None]
    def _prepare_context(self, batch, context):
        """Prepare class label and vertex context.

        Builds the per-line conditioning prefix (class token + key-point
        embeddings) and the per-line BEV feature sequence for cross attention.

        Args:
            batch (dict): 'lines_cls', 'bbox_flat' and 'lines_bs_idx' from
                the detection stage.
            context (dict): 'bev_embeddings' with shape [B, C, H, W].

        Returns:
            tuple: (global_context_embedding, sequential_context_embeddings)
                with shapes [N, 1+bbox_len, h] and [N, H*W, h].
                NOTE(review): when not class_conditional the bbox embeddings
                are computed but discarded and None is returned — confirm.
        """
        global_context_embedding = None
        if self.class_conditional:
            global_context_embedding = self.label_embed(batch['lines_cls'])
        bbox_embeddings = self._embed_kps(batch['bbox_flat'])
        if global_context_embedding is not None:
            # Class token first, then the key-point tokens.
            global_context_embedding = torch.cat(
                [global_context_embedding[:, None], bbox_embeddings], dim=1)
        # Pass images through encoder: one BEV copy per detected line.
        image_embeddings = assign_bev(
            context['bev_embeddings'], batch['lines_bs_idx'])
        image_embeddings = self.input_proj(image_embeddings)
        device = image_embeddings.device
        # Add 2D coordinate grid embedding
        H, W = image_embeddings.shape[2:]
        Ws = torch.linspace(-1., 1., W)
        Hs = torch.linspace(-1., 1., H)
        # NOTE(review): meshgrid relies on the default 'ij' indexing; newer
        # PyTorch warns unless indexing= is passed explicitly.
        image_coords = torch.stack(
            torch.meshgrid(Hs, Ws), dim=-1).to(device)
        image_coord_embeddings = self.img_coord_embed(image_coords)
        image_embeddings += image_coord_embeddings[None].permute(0, 3, 1, 2)
        # Reshape spatial grid to sequence
        B = image_embeddings.shape[0]
        sequential_context_embeddings = image_embeddings.reshape(
            B, self.embedding_dim, -1).permute(0, 2, 1)
        return (global_context_embedding,
                sequential_context_embeddings)
def _embed_inputs(self, seqs, condition_embedding=None):
"""Embeds face sequences and adds within and between face positions.
Args:
seq: B, seqlen=vlen*3,
condition_embedding: B, [c,xs,ys,xe,ye](5), h
Returns:
embeddings: B, seqlen, h
"""
B, seq_len = seqs.shape[:2]
# Position embeddings
pos_embeddings = self.pos_embeddings(
(torch.arange(seq_len, device=seqs.device) / self.coord_dim).floor().long()) # seq_len, h
# Coord indicators (x, y, z(optional))
coord_embeddings = self.coord_embed(
torch.arange(seq_len, device=seqs.device) % self.coord_dim)
# Discrete vertex value embeddings
vert_embeddings = self.vertex_embed(seqs)
# Aggregate embeddings
embeddings = vert_embeddings + \
(coord_embeddings+pos_embeddings)[None]
embeddings = torch.cat([condition_embedding, embeddings], dim=1)
return embeddings
def forward(self, batch: dict, **kwargs):
"""
Pass batch through face model and get log probabilities.
Args:
batch: Dictionary containing:
'vertices_dequantized': Tensor of shape [batch_size, num_vertices, 3].
'faces': int32 tensor of shape [batch_size, seq_length] with flattened
faces.
'vertices_mask': float32 tensor with shape
[batch_size, num_vertices] that masks padded elements in 'vertices'.
"""
if self.training:
return self.forward_train(batch, **kwargs)
else:
return self.inference(batch, **kwargs)
    def sperate_forward(self, batch, context, **kwargs):
        """Run teacher forcing in two length-bucketed chunks.

        Polylines are sorted by length and split into a short and a long
        chunk so the short chunk is padded only to its own (smaller) maximum
        length, saving compute; chunk outputs are re-padded to a common
        length and restored to the input order.

        NOTE(review): the method name keeps the original 'sperate' spelling
        because external callers depend on it.
        """
        polyline_length = batch['polyline_masks'].sum(-1)
        c1, c2, revert_idx, size = get_chunk_idx(polyline_length)
        sizes = [size, polyline_length.max()]
        polyline_logits = []
        for c_idx, size in zip([c1, c2], sizes):
            new_batch = assign_batch(batch, c_idx, size)
            _poly_logits = self._forward_train(new_batch, context, **kwargs)
            polyline_logits.append(_poly_logits)
        # maybe imporve the speed
        for i, (_poly_logits, size) in enumerate(zip(polyline_logits, sizes)):
            # Pad the short chunk's logits up to the long chunk's length.
            if size < sizes[1]:
                _poly_logits = F.pad(_poly_logits, (0, 0, 0, sizes[1]-size), "constant", 0)
            polyline_logits[i] = _poly_logits
        polyline_logits = torch.cat(polyline_logits, 0)
        # Undo the length sort so logits align with the input batch order.
        polyline_logits = polyline_logits[revert_idx]
        cat_dist = Categorical(logits=polyline_logits)
        return {'polylines': cat_dist}
def forward_train(self, batch: dict, context: dict, **kwargs):
"""
Returns:
pred_dist: Categorical predictive distribution with batch shape
[batch_size, seq_length].
"""
# we use the gt vertices
if False:
polyline_logits = self._forward_train(batch, context, **kwargs)
cat_dist = Categorical(logits=polyline_logits)
return {'polylines':cat_dist}
else:
return self.sperate_forward(batch, context, **kwargs)
def _forward_train(self, batch: dict, context: dict, **kwargs):
"""
Returns:
pred_dist: Categorical predictive distribution with batch shape
[batch_size, seq_length].
"""
# we use the gt vertices
global_context, seq_context = self._prepare_context(
batch, context)
logits = self.body(
# Last element not used for preds
batch['polylines'][:, :-1],
global_context_embedding=global_context,
sequential_context_embeddings=seq_context,
return_logits=True,
is_training=self.training)
return logits
    @force_fp32(apply_to=('global_context_embedding', 'sequential_context_embeddings', 'cache'))
    def body(self,
             seqs,
             global_context_embedding=None,
             sequential_context_embeddings=None,
             temperature=1.,
             top_k=0,
             top_p=1.,
             cache=None,
             return_logits=False,
             is_training=True):
        """
        Outputs categorical dist for vertex indices.
        Body of the face model.

        Args:
            seqs: [B, seq_len] coordinate tokens (teacher-forced or sampled).
            global_context_embedding: [B, cond_len, h] conditioning prefix.
            sequential_context_embeddings: [B, HW, h] BEV memory, or None.
            temperature / top_k / top_p: sampling-time logit adjustments
                (no-ops at their defaults).
            cache: incremental-decoding cache, threaded through the decoder.
            return_logits: if True return raw logits instead of a dist.
            is_training: apply an explicit causal mask (teacher forcing).
        """
        # Embed inputs
        condition_len = global_context_embedding.shape[1]
        decoder_inputs = self._embed_inputs(
            seqs, global_context_embedding)
        # Pass through Transformer decoder
        # since our memory efficient decoder only support seq first setting.
        decoder_inputs = decoder_inputs.transpose(0, 1)
        if sequential_context_embeddings is not None:
            sequential_context_embeddings = sequential_context_embeddings.transpose(
                0, 1)
        causal_msk = None
        if is_training:
            causal_msk = generate_square_subsequent_mask(
                decoder_inputs.shape[0], condition_len=condition_len, device=decoder_inputs.device)
        decoder_outputs, cache = self.decoder(
            tgt=decoder_inputs,
            cache=cache,
            memory=sequential_context_embeddings,
            causal_mask=causal_msk,
        )
        decoder_outputs = decoder_outputs.transpose(0, 1)
        # since we only need the predict seq; the last condition position
        # predicts the first coordinate token, hence the -1.
        decoder_outputs = decoder_outputs[:, condition_len-1:]
        # Get logits and optionally process for sampling
        logits = self._project_to_logits(decoder_outputs)
        # y mask: y tokens may only take bins valid for the canvas height;
        # invalid bins are pushed to -1e9 so softmax zeroes them out.
        _vert_mask = torch.arange(logits.shape[-1], device=logits.device)
        vertices_mask_y = (_vert_mask < self.canvas_size[1]+1)
        vertices_mask_y[0] = False  # y position doesn't have stop sign
        logits[:, 1::self.coord_dim] = logits[:, 1::self.coord_dim] * \
            vertices_mask_y - ~vertices_mask_y*1e9
        if self.coord_dim > 2:
            # z mask, same scheme as the y mask.
            _vert_mask = torch.arange(logits.shape[-1], device=logits.device)
            vertices_mask_z = (_vert_mask < self.canvas_size[2]+1)
            vertices_mask_z[0] = False  # z position doesn't have stop sign
            logits[:, 2::self.coord_dim] = logits[:, 2::self.coord_dim] * \
                vertices_mask_z - ~vertices_mask_z*1e9
        # Temperature / top-k / nucleus filtering.
        logits = logits/temperature
        logits = top_k_logits(logits, top_k)
        logits = top_p_logits(logits, top_p)
        if return_logits:
            return logits
        cat_dist = Categorical(logits=logits)
        return cat_dist, cache
@force_fp32(apply_to=('pred'))
def loss(self, gt: dict, pred: dict):
weight = gt['polyline_weights']
mask = gt['polyline_masks']
loss = -torch.sum(
pred['polylines'].log_prob(gt['polylines']) * mask * weight)/weight.sum()
return {'seq': loss}
    def inference(self,
                  batch: dict,
                  context: dict,
                  max_sample_length=None,
                  temperature=1.,
                  top_k=0,
                  top_p=1.,
                  only_return_complete=False,
                  gt_condition=False,
                  **kwargs):
        """Sample from face model using caching.
        Args:
            batch: conditioning inputs from the detection stage
                ('lines_cls', 'bbox_flat', 'lines_bs_idx').
            context: Dictionary of context, including 'bev_embeddings'.
                See _prepare_context for details.
            max_sample_length: Maximum length of sampled vertex sequences. Sequences
                that do not complete are truncated.
            temperature: Scalar softmax temperature > 0.
            top_k: Number of tokens to keep for top-k sampling.
            top_p: Proportion of probability mass to keep for top-p sampling.
            only_return_complete: If True, only return completed samples. Otherwise
                return all samples along with completed indicator.
            gt_condition: unused here; kept for interface compatibility.
        Returns:
            outputs: Output dictionary with fields:
                'completed': Boolean tensor of shape [num_samples]. If True then
                    corresponding sample completed within max_sample_length.
                'polylines': Tensor of sampled tokens, zero-padded.
                'polyline_masks': Bool mask of valid positions per sample.
        """
        # prepare the conditional variable
        global_context, seq_context = self._prepare_context(
            batch, context)
        device = global_context.device
        batch_size = global_context.shape[0]
        # While loop sampling with caching
        samples = torch.empty(
            [batch_size, 0], dtype=torch.int32, device=device)
        max_sample_length = max_sample_length or self.max_seq_length
        # coord_dim tokens per vertex, +1 for the stop token.
        seq_len = max_sample_length*self.coord_dim+1
        cache = None
        decoded_tokens = \
            torch.zeros((batch_size, seq_len),
                        device=device, dtype=torch.long)
        # Rows still being decoded; finished rows are pruned each step.
        remain_idx = torch.arange(batch_size, device=device)
        for i in range(seq_len):
            # While-loop body for autoregression calculation.
            pred_dist, cache = self.body(
                samples,
                global_context_embedding=global_context,
                sequential_context_embeddings=seq_context,
                cache=cache,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                is_training=False)
            samples = pred_dist.sample()
            # Scatter the newest token back into the full-size buffer.
            decoded_tokens[remain_idx, i] = samples[:, -1]
            # Stopping conditions for autoregressive calculation:
            # stop once every sequence contains a 0 (stop) token.
            if not (decoded_tokens[:, :i+1] != 0).all(-1).any():
                break
            # update state, check the new position is zero.
            valid_idx = (samples[:, -1] != 0).nonzero(as_tuple=True)[0]
            remain_idx = remain_idx[valid_idx]
            cache = cache[:, :, valid_idx]
            global_context = global_context[valid_idx]
            seq_context = seq_context[valid_idx]
            samples = samples[valid_idx]
        # decoded_tokens = torch.cat(decoded_tokens,dim=1)
        decoded_tokens = decoded_tokens[:, :i+1]
        outputs = self.post_process(decoded_tokens, seq_len,
                                    device, only_return_complete)
        return outputs
def post_process(self, polyline,
max_seq_len=None,
device=None,
only_return_complete=True):
'''
format the predictions
find the mask
'''
# Record completed samples
complete_samples = (polyline == 0).any(-1)
# Find number of faces
sample_seq_length = polyline.shape[-1]
_polyline_mask = torch.arange(sample_seq_length)[None].to(device)
# Get largest stopping point for incomplete samples.
valid_polyline_len = torch.full_like(polyline[:,0], sample_seq_length)
zero_inds = (polyline == 0).type(torch.int32).argmax(-1)
# Real length
valid_polyline_len[complete_samples] = zero_inds[complete_samples] + 1
polyline_mask = _polyline_mask < valid_polyline_len[:, None]
# Mask faces beyond stopping token with zeros
polyline = polyline*polyline_mask
# Pad to maximum size with zeros
pad_size = max_seq_len - sample_seq_length
polyline = F.pad(polyline, (0, pad_size))
# polyline_mask = F.pad(polyline_mask, (0, pad_size))
# XXX
# if only_return_complete:
# polyline = polyline[complete_samples]
# valid_polyline_len = valid_polyline_len[complete_samples]
# context = tf.nest.map_structure(
# lambda x: tf.boolean_mask(x, complete_samples), context)
# complete_samples = complete_samples[complete_samples]
# outputs
outputs = {
'completed': complete_samples,
'polylines': polyline,
'polyline_masks': polyline_mask,
}
return outputs
def find_best_sperate_plan(idx, array):
    """Padding saved by cutting the length-sorted ``array`` after ``idx``.

    The saving is the rectangle (max length - length at idx) * idx, i.e.
    how much padding the ``idx`` shorter sequences avoid if they are
    batched separately from the longest ones.
    """
    height = array[-1] - array[idx]
    width = idx
    return height * width
def get_chunk_idx(polyline_length):
    """Split sample indices into a short chunk and a long chunk by length.

    Sorts the lengths ascending, evaluates every cut position with
    ``find_best_sperate_plan`` and keeps the one that saves the most
    padding. Returns (short-chunk indices, long-chunk indices, the
    permutation that restores the original order, max length of the
    short chunk).
    """
    sorted_len, order = torch.sort(polyline_length)
    costs = torch.stack(
        [find_best_sperate_plan(i, sorted_len) for i in range(len(sorted_len))])
    cut = costs.argmax()
    short_chunk = order[:cut + 1]
    long_chunk = order[cut + 1:]
    revert_idx = torch.argsort(order)
    return short_chunk, long_chunk, revert_idx, sorted_len[cut]
def assign_bev(feat, idx):
    """Gather the per-sample BEV feature maps for the given batch indices."""
    return feat[idx]
def assign_batch(batch, idx, size):
    """Select samples ``idx`` from every tensor in ``batch``.

    Sequence-like entries (more than one dim) are additionally trimmed to
    the first ``size`` positions along the sequence axis.
    """
    trimmed = {}
    for key, value in batch.items():
        picked = value[idx]
        trimmed[key] = picked[:, :size] if picked.ndim > 1 else picked
    return trimmed
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from mmdet.models import HEADS
from .detgen_utils.causal_trans import (CausalTransformerDecoder,
CausalTransformerDecoderLayer)
from .detgen_utils.utils import (dequantize_verts, generate_square_subsequent_mask,
quantize_verts, top_k_logits, top_p_logits)
from mmcv.runner import force_fp32, auto_fp16
@HEADS.register_module(force=True)
class PolylineGenerator(nn.Module):
    """
    Autoregressive generative model of map polylines (adapted from an
    n-gon mesh "face model" — some docs below still use mesh vocabulary).
    Operates on flattened, quantized coordinate sequences with a stopping
    token (0):
    [f_0^0, f_0^1, f_0^2, NEW, f_1^0, f_1^1, ..., STOP]
    Conditioning keypoints (detected boxes) and an optional class label are
    embedded as a global-context prefix; BEV features provide the
    sequential context for decoder cross-attention. A causal Transformer
    decoder outputs, at every step, a categorical distribution over the
    quantized coordinate vocabulary.
    """
    def __init__(self,
                 in_channels,
                 encoder_config,
                 decoder_config,
                 class_conditional=True,
                 num_classes=55,
                 decoder_cross_attention=True,
                 use_discrete_vertex_embeddings=True,
                 condition_points_num=3,
                 coord_dim=2,
                 canvas_size=(400, 200),
                 max_seq_length=500,
                 name='gen_model'):
        """Initializes the polyline generator.
        Args:
            in_channels: Channels of the input BEV feature map.
            encoder_config: Dictionary with TransformerEncoder config.
                NOTE(review): not referenced in this class body — verify
                whether it is still needed.
            decoder_config: Dictionary with TransformerDecoder config; must
                contain a 'layer_config' sub-dict with 'd_model'.
            class_conditional: If True, then condition on learned class embeddings.
            num_classes: Number of classes to condition on.
            decoder_cross_attention: If True, the use cross attention from decoder
                querys into encoder outputs.
            use_discrete_vertex_embeddings: If True, use discrete vertex embeddings.
            condition_points_num: Number of conditioning key-point tokens.
            coord_dim: 2 for (x, y) tokens, 3 for (x, y, z).
            canvas_size: Quantization grid extents; the token vocabulary is
                max(canvas_size) + 1 (index 0 is the stop token).
            max_seq_length: Maximum face sequence length. Used for learned position
                embeddings.
            name: Name of variable scope
        """
        super(PolylineGenerator, self).__init__()
        self.embedding_dim = decoder_config['layer_config']['d_model']
        self.class_conditional = class_conditional
        self.num_classes = num_classes
        self.max_seq_length = max_seq_length
        self.decoder_cross_attention = decoder_cross_attention
        self.use_discrete_vertex_embeddings = use_discrete_vertex_embeddings
        self.condition_points_num = condition_points_num
        self.fp16_enabled = False
        self.coord_dim = coord_dim # if we use xyz else 2 when we use xy
        # Key points appear to always be 2-D (x, y) regardless of coord_dim.
        self.kp_coord_dim = coord_dim if coord_dim==2 else 2 # XXX — confirm for 3-D
        self.register_buffer('canvas_size', torch.tensor(canvas_size))
        # initialize the model
        self._project_to_logits = nn.Linear(
            self.embedding_dim,
            max(canvas_size) + 1, # + 1 for stopping token. use_bias=True,
        )
        self.input_proj = nn.Conv2d(
            in_channels, self.embedding_dim, kernel_size=1)
        decoder_layer = CausalTransformerDecoderLayer(
            **decoder_config.pop('layer_config'))
        self.decoder = CausalTransformerDecoder(
            decoder_layer, **decoder_config)
        self._init_embedding()
        self.init_weights()
    def _init_embedding(self):
        """Create all learned embedding tables (class label, coordinate
        role, position, conditioning key points, image grid, vertex value)."""
        if self.class_conditional:
            self.label_embed = nn.Embedding(
                self.num_classes, self.embedding_dim)
        self.coord_embed = nn.Embedding(self.coord_dim, self.embedding_dim)
        self.pos_embeddings = nn.Embedding(
            self.max_seq_length, self.embedding_dim)
        # to indicate the role of the position is the start of the line or the end of it.
        self.bbox_context_embed = \
            nn.Embedding(self.condition_points_num, self.embedding_dim)
        self.img_coord_embed = nn.Linear(2, self.embedding_dim)
        # initialize the verteices embedding
        if self.use_discrete_vertex_embeddings:
            self.vertex_embed = nn.Embedding(
                max(self.canvas_size) + 1, self.embedding_dim)
        else:
            self.vertex_embed = nn.Linear(1, self.embedding_dim)
    def init_weights(self):
        """Xavier-initialize every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    def _embed_kps(self, bbox):
        """Embed the flattened conditioning key points.
        Args:
            bbox: [B, num_kp * kp_coord_dim] quantized key-point tokens.
        Returns:
            [B, bbox_len, h] = value embedding + key-point-role embedding
            + (x/y) coordinate-role embedding.
        """
        bbox_len = bbox.shape[-1]
        # Bbox_context
        bbox_embedding = self.bbox_context_embed(
            (torch.arange(bbox_len, device=bbox.device) / self.kp_coord_dim).floor().long())
        # Coord indicators (x, y)
        coord_embeddings = self.coord_embed(
            torch.arange(bbox_len, device=bbox.device) % self.kp_coord_dim)
        # Discrete vertex value embeddings
        vert_embeddings = self.vertex_embed(bbox)
        return vert_embeddings + (bbox_embedding+coord_embeddings)[None]
    def _prepare_context(self, batch, context):
        """Prepare class label and vertex context.
        Returns:
            global_context_embedding: [B, 1 + kp_len, h] prefix tokens
                (class label followed by key-point embeddings).
            sequential_context_embeddings: [B, H*W, h] flattened BEV
                features for decoder cross-attention.
        """
        global_context_embedding = None
        if self.class_conditional:
            global_context_embedding = self.label_embed(batch['lines_cls'])
        bbox_embeddings = self._embed_kps(batch['bbox_flat'])
        if global_context_embedding is not None:
            global_context_embedding = torch.cat(
                [global_context_embedding[:, None], bbox_embeddings], dim=1)
        # Pass images through encoder
        # Each line instance picks the BEV map of the batch sample it
        # belongs to via 'lines_bs_idx'.
        image_embeddings = assign_bev(
            context['bev_embeddings'], batch['lines_bs_idx'])
        image_embeddings = self.input_proj(image_embeddings)
        device = image_embeddings.device
        # Add 2D coordinate grid embedding
        H, W = image_embeddings.shape[2:]
        Ws = torch.linspace(-1., 1., W)
        Hs = torch.linspace(-1., 1., H)
        image_coords = torch.stack(
            torch.meshgrid(Hs, Ws), dim=-1).to(device)
        image_coord_embeddings = self.img_coord_embed(image_coords)
        image_embeddings += image_coord_embeddings[None].permute(0, 3, 1, 2)
        # Reshape spatial grid to sequence
        B = image_embeddings.shape[0]
        sequential_context_embeddings = image_embeddings.reshape(
            B, self.embedding_dim, -1).permute(0, 2, 1)
        return (global_context_embedding,
                sequential_context_embeddings)
    def _embed_inputs(self, seqs, condition_embedding=None):
        """Embeds face sequences and adds within and between face positions.
        Args:
            seq: B, seqlen=vlen*3,
            condition_embedding: B, [c,xs,ys,xe,ye](5), h
        Returns:
            embeddings: B, condition_len + seqlen, h  (condition tokens are
            prepended to the embedded sequence)
        """
        B, seq_len = seqs.shape[:2]
        # Position embeddings
        pos_embeddings = self.pos_embeddings(
            (torch.arange(seq_len, device=seqs.device) / self.coord_dim).floor().long())  # seq_len, h
        # Coord indicators (x, y, z(optional))
        coord_embeddings = self.coord_embed(
            torch.arange(seq_len, device=seqs.device) % self.coord_dim)
        # Discrete vertex value embeddings
        vert_embeddings = self.vertex_embed(seqs)
        # Aggregate embeddings
        embeddings = vert_embeddings + \
            (coord_embeddings+pos_embeddings)[None]
        embeddings = torch.cat([condition_embedding, embeddings], dim=1)
        return embeddings
    def forward(self, batch: dict, **kwargs):
        """
        Pass batch through face model and get log probabilities.
        Dispatches to :meth:`forward_train` in training mode and to
        :meth:`inference` (autoregressive sampling) otherwise.
        Args:
            batch: Dictionary containing:
            'vertices_dequantized': Tensor of shape [batch_size, num_vertices, 3].
            'faces': int32 tensor of shape [batch_size, seq_length] with flattened
            faces.
            'vertices_mask': float32 tensor with shape
            [batch_size, num_vertices] that masks padded elements in 'vertices'.
        """
        if self.training:
            return self.forward_train(batch, **kwargs)
        else:
            return self.inference(batch, **kwargs)
    def sperate_forward(self, batch, context, **kwargs):
        """Run training forward in two length-sorted chunks.
        Splitting short and long polylines into separate forward passes
        cuts wasted padding computation; logits are re-padded to a common
        length and restored to the original sample order afterwards.
        """
        polyline_length = batch['polyline_masks'].sum(-1)
        c1, c2, revert_idx, size = get_chunk_idx(polyline_length)
        sizes = [size, polyline_length.max()]
        polyline_logits = []
        for c_idx, size in zip([c1,c2], sizes):
            new_batch = assign_batch(batch,c_idx, size)
            _poly_logits = self._forward_train(new_batch,context,**kwargs)
            polyline_logits.append(_poly_logits)
        # maybe imporve the speed
        # Pad the short chunk's logits up to the long chunk's length so the
        # two can be concatenated.
        for i, (_poly_logits, size) in enumerate(zip(polyline_logits, sizes)):
            if size < sizes[1]:
                _poly_logits = F.pad(_poly_logits, (0,0,0,sizes[1]-size), "constant", 0)
            polyline_logits[i] = _poly_logits
        polyline_logits = torch.cat(polyline_logits,0)
        polyline_logits = polyline_logits[revert_idx]
        cat_dist = Categorical(logits=polyline_logits)
        return {'polylines':cat_dist}
    def forward_train(self, batch: dict, context: dict, **kwargs):
        """
        Returns:
            pred_dist: Categorical predictive distribution with batch shape
            [batch_size, seq_length].
        """
        # we use the gt vertices
        # NOTE(review): dead branch kept from development — the chunked
        # sperate_forward path is always taken.
        if False:
            polyline_logits = self._forward_train(batch, context, **kwargs)
            cat_dist = Categorical(logits=polyline_logits)
            return {'polylines':cat_dist}
        else:
            return self.sperate_forward(batch, context, **kwargs)
    def _forward_train(self, batch: dict, context: dict, **kwargs):
        """
        Teacher-forced forward over ground-truth token sequences.
        Returns:
            logits: [batch_size, seq_length, vocab] raw (masked) logits.
        """
        # we use the gt vertices
        global_context, seq_context = self._prepare_context(
            batch, context)
        logits = self.body(
            # Last element not used for preds
            batch['polylines'][:, :-1],
            global_context_embedding=global_context,
            sequential_context_embeddings=seq_context,
            return_logits=True,
            is_training=self.training)
        return logits
    # NOTE(review): apply_to=('pred',...) is written as a plain string here;
    # confirm mmcv matches the intended argument names.
    @force_fp32(apply_to=('global_context_embedding','sequential_context_embeddings','cache'))
    def body(self,
             seqs,
             global_context_embedding=None,
             sequential_context_embeddings=None,
             temperature=1.,
             top_k=0,
             top_p=1.,
             cache=None,
             return_logits=False,
             is_training=True):
        """
        Outputs categorical dist for vertex indices.
        Body of the face model: embeds tokens (with the conditioning
        prefix), runs the causal decoder, projects to vocab logits, masks
        out-of-canvas values per coordinate axis, and applies temperature /
        top-k / top-p filtering.
        """
        # Embed inputs
        condition_len = global_context_embedding.shape[1]
        decoder_inputs = self._embed_inputs(
            seqs, global_context_embedding)
        # Pass through Transformer decoder
        # since our memory efficient decoder only support seq first setting.
        decoder_inputs = decoder_inputs.transpose(0, 1)
        if sequential_context_embeddings is not None:
            sequential_context_embeddings = sequential_context_embeddings.transpose(
                0, 1)
        causal_msk = None
        if is_training:
            causal_msk = generate_square_subsequent_mask(
                decoder_inputs.shape[0], condition_len=condition_len, device=decoder_inputs.device)
        decoder_outputs, cache = self.decoder(
            tgt=decoder_inputs,
            cache=cache,
            memory=sequential_context_embeddings,
            causal_mask=causal_msk,
        )
        decoder_outputs = decoder_outputs.transpose(0, 1)
        # since we only need the predict seq
        # condition_len-1: the last condition token predicts the first
        # sequence token.
        decoder_outputs = decoder_outputs[:, condition_len-1:]
        # Get logits and optionally process for sampling
        logits = self._project_to_logits(decoder_outputs)
        # y mask
        # The vocab is sized by max(canvas_size); y positions may only take
        # values < canvas_size[1]+1, so push invalid logits to -1e9.
        _vert_mask = torch.arange(logits.shape[-1], device=logits.device)
        vertices_mask_y = (_vert_mask < self.canvas_size[1]+1)
        vertices_mask_y[0] = False # y position doesn't have stop sign
        logits[:, 1::self.coord_dim] = logits[:, 1::self.coord_dim] * \
            vertices_mask_y - ~vertices_mask_y*1e9
        if self.coord_dim > 2:
            # z mask
            _vert_mask = torch.arange(logits.shape[-1], device=logits.device)
            vertices_mask_z = (_vert_mask < self.canvas_size[2]+1)
            vertices_mask_z[0] = False # z position doesn't have stop sign
            logits[:, 2::self.coord_dim] = logits[:, 2::self.coord_dim] * \
                vertices_mask_z - ~vertices_mask_z*1e9
        logits = logits/temperature
        logits = top_k_logits(logits, top_k)
        logits = top_p_logits(logits, top_p)
        if return_logits:
            return logits
        cat_dist = Categorical(logits=logits)
        return cat_dist, cache
    @force_fp32(apply_to=('pred'))
    def loss(self, gt: dict, pred: dict):
        """Masked, weighted negative log-likelihood of the gt token
        sequences under the predicted categorical distribution, normalized
        by the total weight."""
        weight = gt['polyline_weights']
        mask = gt['polyline_masks']
        loss = -torch.sum(
            pred['polylines'].log_prob(gt['polylines']) * mask * weight)/weight.sum()
        return {'seq': loss}
    def inference(self,
                  batch: dict,
                  context: dict,
                  max_sample_length=None,
                  temperature=1.,
                  top_k=0,
                  top_p=1.,
                  only_return_complete=False,
                  gt_condition=False,
                  **kwargs):
        """Sample from face model using caching.
        Autoregressively samples one token per step; sequences that emit
        the stop token (0) are dropped from the active batch (cache and
        contexts are sliced down) until none remain or the length budget
        is exhausted.
        Args:
            context: Dictionary of context, including 'vertices' and 'vertices_mask'.
                See _prepare_context for details.
            max_sample_length: Maximum length of sampled vertex sequences. Sequences
                that do not complete are truncated.
            temperature: Scalar softmax temperature > 0.
            top_k: Number of tokens to keep for top-k sampling.
            top_p: Proportion of probability mass to keep for top-p sampling.
            only_return_complete: If True, only return completed samples. Otherwise
                return all samples along with completed indicator.
        Returns:
            outputs: Output dictionary with fields:
            'completed': Boolean tensor of shape [num_samples]. If True then
            corresponding sample completed within max_sample_length.
            'polylines': padded token samples of shape [num_samples, seq_len].
            'polyline_masks': validity mask for each sample.
        """
        # prepare the conditional variable
        global_context, seq_context = self._prepare_context(
            batch, context)
        device = global_context.device
        batch_size = global_context.shape[0]
        # While loop sampling with caching
        samples = torch.empty(
            [batch_size, 0], dtype=torch.int32, device=device)
        max_sample_length = max_sample_length or self.max_seq_length
        # coord_dim tokens per vertex plus one stop token.
        seq_len = max_sample_length*self.coord_dim+1
        cache = None
        decoded_tokens = \
            torch.zeros((batch_size,seq_len),
                        device=device,dtype=torch.long)
        # remain_idx maps positions in the shrinking active batch back to
        # rows of decoded_tokens.
        remain_idx = torch.arange(batch_size, device=device)
        for i in range(seq_len):
            # While-loop body for autoregression calculation.
            pred_dist, cache = self.body(
                samples,
                global_context_embedding=global_context,
                sequential_context_embeddings=seq_context,
                cache=cache,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                is_training=False)
            samples = pred_dist.sample()
            decoded_tokens[remain_idx,i] = samples[:,-1]
            # Stopping conditions for autoregressive calculation.
            if not (decoded_tokens[:,:i+1] != 0).all(-1).any():
                break
            # update state, check the new position is zero.
            valid_idx = (samples[:,-1] != 0).nonzero(as_tuple=True)[0]
            remain_idx = remain_idx[valid_idx]
            cache = cache[:,:,valid_idx]
            global_context = global_context[valid_idx]
            seq_context = seq_context[valid_idx]
            samples = samples[valid_idx]
        # decoded_tokens = torch.cat(decoded_tokens,dim=1)
        decoded_tokens = decoded_tokens[:,:i+1]
        outputs = self.post_process(decoded_tokens, seq_len,
                                    device, only_return_complete)
        return outputs
    def post_process(self, polyline,
                     max_seq_len=None,
                     device=None,
                     only_return_complete=True):
        '''
        format the predictions
        find the mask
        '''
        # Record completed samples
        complete_samples = (polyline == 0).any(-1)
        # Find number of faces
        sample_seq_length = polyline.shape[-1]
        _polyline_mask = torch.arange(sample_seq_length)[None].to(device)
        # Get largest stopping point for incomplete samples.
        valid_polyline_len = torch.full_like(polyline[:,0], sample_seq_length)
        zero_inds = (polyline == 0).type(torch.int32).argmax(-1)
        # Real length
        # Valid length = index of first stop token + 1.
        valid_polyline_len[complete_samples] = zero_inds[complete_samples] + 1
        polyline_mask = _polyline_mask < valid_polyline_len[:, None]
        # Mask faces beyond stopping token with zeros
        polyline = polyline*polyline_mask
        # Pad to maximum size with zeros
        pad_size = max_seq_len - sample_seq_length
        polyline = F.pad(polyline, (0, pad_size))
        # polyline_mask = F.pad(polyline_mask, (0, pad_size))
        # XXX
        # if only_return_complete:
        #     polyline = polyline[complete_samples]
        #     valid_polyline_len = valid_polyline_len[complete_samples]
        #     context = tf.nest.map_structure(
        #         lambda x: tf.boolean_mask(x, complete_samples), context)
        #     complete_samples = complete_samples[complete_samples]
        # outputs
        outputs = {
            'completed': complete_samples,
            'polylines': polyline,
            'polyline_masks': polyline_mask,
        }
        return outputs
def find_best_sperate_plan(idx, array):
    """Padding saved by cutting the length-sorted ``array`` after ``idx``.

    The saving is the rectangle (max length - length at idx) * idx, i.e.
    how much padding the ``idx`` shorter sequences avoid if they are
    batched separately from the longest ones.
    """
    height = array[-1] - array[idx]
    width = idx
    return height * width
def get_chunk_idx(polyline_length):
    """Split sample indices into a short chunk and a long chunk by length.

    Sorts the lengths ascending, evaluates every cut position with
    ``find_best_sperate_plan`` and keeps the one that saves the most
    padding. Returns (short-chunk indices, long-chunk indices, the
    permutation that restores the original order, max length of the
    short chunk).
    """
    sorted_len, order = torch.sort(polyline_length)
    costs = torch.stack(
        [find_best_sperate_plan(i, sorted_len) for i in range(len(sorted_len))])
    cut = costs.argmax()
    short_chunk = order[:cut + 1]
    long_chunk = order[cut + 1:]
    revert_idx = torch.argsort(order)
    return short_chunk, long_chunk, revert_idx, sorted_len[cut]
def assign_bev(feat, idx):
    """Gather the per-sample BEV feature maps for the given batch indices."""
    return feat[idx]
def assign_batch(batch, idx, size):
    """Select samples ``idx`` from every tensor in ``batch``.

    Sequence-like entries (more than one dim) are additionally trimmed to
    the first ``size`` positions along the sequence axis.
    """
    trimmed = {}
    for key, value in batch.items():
        picked = value[idx]
        trimmed[key] = picked[:, :size] if picked.ndim > 1 else picked
    return trimmed
import torch
from torch import nn as nn
from torch.nn import functional as F
from mmdet.models.losses import l1_loss
from mmdet.models.losses.utils import weighted_loss
import mmcv
from mmdet.models.builder import LOSSES
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def smooth_l1_loss(pred, target, beta=1.0):
    """Smooth L1 loss.

    Args:
        pred (torch.Tensor): The prediction.
        target (torch.Tensor): The learning target of the prediction.
        beta (float, optional): The threshold in the piecewise function.
            Defaults to 1.0.

    Returns:
        torch.Tensor: Elementwise loss — quadratic below ``beta``,
        linear above it.
    """
    assert beta > 0
    # Empty target: return a zero that still participates in the graph.
    if target.numel() == 0:
        return pred.sum() * 0
    assert pred.size() == target.size()
    abs_err = (pred - target).abs()
    quadratic = 0.5 * abs_err * abs_err / beta
    linear = abs_err - 0.5 * beta
    return torch.where(abs_err < beta, quadratic, linear)
@LOSSES.register_module()
class LinesLoss(nn.Module):
    """Smooth-L1 regression loss for polyline coordinates."""

    def __init__(self, reduction='mean', loss_weight=1.0, beta=0.5):
        """
        Args:
            reduction (str, optional): 'none', 'mean' or 'sum'.
            loss_weight (float, optional): Global scale of the loss.
            beta (float, optional): Smooth-L1 threshold.
        """
        super(LinesLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.beta = beta

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute the weighted smooth-L1 loss.

        Args:
            pred (torch.Tensor): Predictions, shape [bs, ...].
            target (torch.Tensor): Targets, same shape as ``pred``.
            weight (torch.Tensor, optional): Per-element weight; useful
                when not all predictions are valid.
            avg_factor (int, optional): Normalizer used for averaging.
            reduction_override (str, optional): Overrides ``self.reduction``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override if reduction_override else self.reduction
        raw = smooth_l1_loss(pred, target, weight, reduction=reduction,
                             avg_factor=avg_factor, beta=self.beta)
        return raw * self.loss_weight
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def bce(pred, label, class_weight=None):
    """Binary cross-entropy on logits.

    Args:
        pred: Logits of shape [B, nquery, npts].
        label: Binary targets of the same shape.
        class_weight: Optional positive-class weight.
    """
    if label.numel() == 0:
        return pred.sum() * 0
    assert pred.size() == label.size()
    return F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')
@LOSSES.register_module()
class MasksLoss(nn.Module):
    """Binary cross-entropy loss for predicted point/mask validity."""

    def __init__(self, reduction='mean', loss_weight=1.0):
        """
        Args:
            reduction (str, optional): 'none', 'mean' or 'sum'.
            loss_weight (float, optional): Global scale of the loss.
        """
        super(MasksLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute the weighted BCE loss; see ``bce`` for expected shapes."""
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override if reduction_override else self.reduction
        raw = bce(pred, target, weight, reduction=reduction,
                  avg_factor=avg_factor)
        return raw * self.loss_weight
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def ce(pred, label, class_weight=None):
    """Cross-entropy.

    Args:
        pred: Logits of shape [B*nquery, npts].
        label: Class indices of shape [B*nquery].
        class_weight: Optional per-class weights.
    """
    if label.numel() == 0:
        return pred.sum() * 0
    return F.cross_entropy(
        pred, label, weight=class_weight, reduction='none')
@LOSSES.register_module()
class LenLoss(nn.Module):
    """Cross-entropy loss over predicted polyline lengths."""

    def __init__(self, reduction='mean', loss_weight=1.0):
        """
        Args:
            reduction (str, optional): 'none', 'mean' or 'sum'.
            loss_weight (float, optional): Global scale of the loss.
        """
        super(LenLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute the weighted cross-entropy loss.

        Args:
            pred (torch.Tensor): Length logits, shape [B*nquery, npts].
            target (torch.Tensor): Length class indices, shape [B*nquery].
            weight (torch.Tensor, optional): Per-element weight.
            avg_factor (int, optional): Normalizer used for averaging.
            reduction_override (str, optional): Overrides ``self.reduction``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = ce(pred, target, weight, reduction=reduction,
                  avg_factor=avg_factor)
        # BUG FIX: the loss was computed but never returned (forward fell
        # off the end and yielded None). The sibling implementation of this
        # class scales by loss_weight and returns; match that behavior.
        return loss * self.loss_weight
import torch
from torch import nn as nn
from torch.nn import functional as F
from mmdet.models.losses import l1_loss
from mmdet.models.losses.utils import weighted_loss
import mmcv
from mmdet.models.builder import LOSSES
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def smooth_l1_loss(pred, target, beta=1.0):
    """Smooth L1 loss.

    Args:
        pred (torch.Tensor): The prediction.
        target (torch.Tensor): The learning target of the prediction.
        beta (float, optional): The threshold in the piecewise function.
            Defaults to 1.0.

    Returns:
        torch.Tensor: Elementwise loss — quadratic below ``beta``,
        linear above it.
    """
    assert beta > 0
    # Empty target: return a zero that still participates in the graph.
    if target.numel() == 0:
        return pred.sum() * 0
    assert pred.size() == target.size()
    abs_err = (pred - target).abs()
    quadratic = 0.5 * abs_err * abs_err / beta
    linear = abs_err - 0.5 * beta
    return torch.where(abs_err < beta, quadratic, linear)
@LOSSES.register_module()
class LinesLoss(nn.Module):
    """Smooth-L1 regression loss for polyline coordinates."""

    def __init__(self, reduction='mean', loss_weight=1.0, beta=0.5):
        """
        Args:
            reduction (str, optional): 'none', 'mean' or 'sum'.
            loss_weight (float, optional): Global scale of the loss.
            beta (float, optional): Smooth-L1 threshold.
        """
        super(LinesLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.beta = beta

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute the weighted smooth-L1 loss.

        Args:
            pred (torch.Tensor): Predictions, shape [bs, ...].
            target (torch.Tensor): Targets, same shape as ``pred``.
            weight (torch.Tensor, optional): Per-element weight; useful
                when not all predictions are valid.
            avg_factor (int, optional): Normalizer used for averaging.
            reduction_override (str, optional): Overrides ``self.reduction``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override if reduction_override else self.reduction
        raw = smooth_l1_loss(pred, target, weight, reduction=reduction,
                             avg_factor=avg_factor, beta=self.beta)
        return raw * self.loss_weight
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def bce(pred, label, class_weight=None):
    """Binary cross-entropy on logits.

    Args:
        pred: Logits of shape [B, nquery, npts].
        label: Binary targets of the same shape.
        class_weight: Optional positive-class weight.
    """
    if label.numel() == 0:
        return pred.sum() * 0
    assert pred.size() == label.size()
    return F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')
@LOSSES.register_module()
class MasksLoss(nn.Module):
    """Binary cross-entropy loss for predicted point/mask validity."""

    def __init__(self, reduction='mean', loss_weight=1.0):
        """
        Args:
            reduction (str, optional): 'none', 'mean' or 'sum'.
            loss_weight (float, optional): Global scale of the loss.
        """
        super(MasksLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute the weighted BCE loss; see ``bce`` for expected shapes."""
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override if reduction_override else self.reduction
        raw = bce(pred, target, weight, reduction=reduction,
                  avg_factor=avg_factor)
        return raw * self.loss_weight
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def ce(pred, label, class_weight=None):
    """Cross-entropy.

    Args:
        pred: Logits of shape [B*nquery, npts].
        label: Class indices of shape [B*nquery].
        class_weight: Optional per-class weights.
    """
    if label.numel() == 0:
        return pred.sum() * 0
    return F.cross_entropy(
        pred, label, weight=class_weight, reduction='none')
@LOSSES.register_module()
class LenLoss(nn.Module):
    """Cross-entropy loss over predicted polyline lengths."""

    def __init__(self, reduction='mean', loss_weight=1.0):
        """
        Args:
            reduction (str, optional): 'none', 'mean' or 'sum'.
            loss_weight (float, optional): Global scale of the loss.
        """
        super(LenLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute the weighted cross-entropy loss; see ``ce`` for shapes."""
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override if reduction_override else self.reduction
        raw = ce(pred, target, weight, reduction=reduction,
                 avg_factor=avg_factor)
        return raw * self.loss_weight
\ No newline at end of file
from abc import ABCMeta, abstractmethod
import torch.nn as nn
from mmcv.runner import auto_fp16
from mmcv.utils import print_log
from mmdet.utils import get_root_logger
from mmdet3d.models.builder import DETECTORS
MAPPERS = DETECTORS
class BaseMapper(nn.Module, metaclass=ABCMeta):
    """Base class for mappers.

    Defines the train/test dispatch in :meth:`forward` and the
    ``train_step``/``val_step`` interface expected by mmcv runners;
    subclasses implement the actual feature extraction and heads.
    """

    def __init__(self):
        super(BaseMapper, self).__init__()
        # Disabled by default; mmcv fp16 hooks flip this when enabled.
        self.fp16_enabled = False

    @property
    def with_neck(self):
        """bool: whether the detector has a neck"""
        return hasattr(self, 'neck') and self.neck is not None

    # TODO: these properties need to be carefully handled
    # for both single stage & two stage detectors
    @property
    def with_shared_head(self):
        """bool: whether the detector has a shared head in the RoI Head"""
        return hasattr(self, 'roi_head') and self.roi_head.with_shared_head

    @property
    def with_bbox(self):
        """bool: whether the detector has a bbox head"""
        return ((hasattr(self, 'roi_head') and self.roi_head.with_bbox)
                or (hasattr(self, 'bbox_head') and self.bbox_head is not None))

    @property
    def with_mask(self):
        """bool: whether the detector has a mask head"""
        return ((hasattr(self, 'roi_head') and self.roi_head.with_mask)
                or (hasattr(self, 'mask_head') and self.mask_head is not None))

    #@abstractmethod
    def extract_feat(self, imgs):
        """Extract features from images."""
        pass

    def forward_train(self, *args, **kwargs):
        """Training forward; subclasses implement the actual computation."""
        pass

    #@abstractmethod
    def simple_test(self, img, img_metas, **kwargs):
        """Test without augmentation; subclasses implement it."""
        pass

    #@abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Test function with test time augmentation."""
        pass

    def init_weights(self, pretrained=None):
        """Initialize the weights in detector.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if pretrained is not None:
            logger = get_root_logger()
            print_log(f'load model from: {pretrained}', logger=logger)

    def forward_test(self, *args, **kwargs):
        """Dispatch to simple or augmented test.

        NOTE(review): the condition is hard-coded to the simple path and
        ``simple_test``/``aug_test`` are invoked without arguments here;
        subclasses are expected to override this method with a real
        implementation.
        """
        if True:
            self.simple_test()
        else:
            self.aug_test()

    # @auto_fp16(apply_to=('img', ))
    def forward(self, *args, return_loss=True, **kwargs):
        """Calls either :func:`forward_train` or :func:`forward_test` depending
        on whether ``return_loss`` is ``True``.
        Note this setting will change the expected inputs. When
        ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
        and List[dict]), and when ``return_loss=False``, img and img_meta
        should be double nested (i.e. List[Tensor], List[List[dict]]), with
        the outer list indicating test time augmentations.
        """
        if return_loss:
            return self.forward_train(*args, **kwargs)
        else:
            # 'rescale' is injected by the test pipeline but not consumed
            # by forward_test.
            kwargs.pop('rescale')
            return self.forward_test(*args, **kwargs)

    def train_step(self, data_dict, optimizer):
        """The iteration step during training.
        This method defines an iteration step during training, except for the
        back propagation and optimizer updating, which are done in an optimizer
        hook. Note that in some complicated cases or models, the whole process
        including back propagation and optimizer updating is also defined in
        this method, such as GAN.
        Args:
            data_dict (dict): The output of dataloader.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                runner is passed to ``train_step()``. This argument is unused
                and reserved.
        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \
                ``num_samples``.
                - ``loss`` is a tensor for back propagation, which can be a \
                weighted sum of multiple losses.
                - ``log_vars`` contains all the variables to be sent to the
                logger.
                - ``num_samples`` indicates the batch size (when the model is \
                DDP, it means the batch size on each GPU), which is used for \
                averaging the logs.
        """
        loss, log_vars, num_samples = self(**data_dict)
        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=num_samples)
        return outputs

    def val_step(self, data, optimizer):
        """The iteration step during validation.
        This method shares the same signature as :func:`train_step`, but used
        during val epochs. Note that the evaluation after training epochs is
        not implemented with this method, but an evaluation hook.
        """
        loss, log_vars, num_samples = self(**data)
        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=num_samples)
        return outputs

    def show_result(self,
                    **kwargs):
        """Visualization hook; subclasses draw and return an image."""
        img = None
        # BUG FIX: previously fell off the end without returning the image
        # (the sibling copy of this class returns it).
        return img
from abc import ABCMeta, abstractmethod
import torch.nn as nn
from mmcv.runner import auto_fp16
from mmcv.utils import print_log
from mmdet.utils import get_root_logger
from mmdet3d.models.builder import DETECTORS
MAPPERS = DETECTORS
class BaseMapper(nn.Module, metaclass=ABCMeta):
"""Base class for mappers."""
def __init__(self):
super(BaseMapper, self).__init__()
self.fp16_enabled = False
@property
def with_neck(self):
"""bool: whether the detector has a neck"""
return hasattr(self, 'neck') and self.neck is not None
# TODO: these properties need to be carefully handled
# for both single stage & two stage detectors
@property
def with_shared_head(self):
"""bool: whether the detector has a shared head in the RoI Head"""
return hasattr(self, 'roi_head') and self.roi_head.with_shared_head
@property
def with_bbox(self):
"""bool: whether the detector has a bbox head"""
return ((hasattr(self, 'roi_head') and self.roi_head.with_bbox)
or (hasattr(self, 'bbox_head') and self.bbox_head is not None))
@property
def with_mask(self):
"""bool: whether the detector has a mask head"""
return ((hasattr(self, 'roi_head') and self.roi_head.with_mask)
or (hasattr(self, 'mask_head') and self.mask_head is not None))
#@abstractmethod
def extract_feat(self, imgs):
"""Extract features from images."""
pass
def forward_train(self, *args, **kwargs):
pass
#@abstractmethod
def simple_test(self, img, img_metas, **kwargs):
pass
#@abstractmethod
def aug_test(self, imgs, img_metas, **kwargs):
"""Test function with test time augmentation."""
pass
def init_weights(self, pretrained=None):
"""Initialize the weights in detector.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if pretrained is not None:
logger = get_root_logger()
print_log(f'load model from: {pretrained}', logger=logger)
def forward_test(self, *args, **kwargs):
"""
Args:
"""
if True:
self.simple_test()
else:
self.aug_test()
# @auto_fp16(apply_to=('img', ))
def forward(self, *args, return_loss=True, **kwargs):
"""Calls either :func:`forward_train` or :func:`forward_test` depending
on whether ``return_loss`` is ``True``.
Note this setting will change the expected inputs. When
``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
and List[dict]), and when ``resturn_loss=False``, img and img_meta
should be double nested (i.e. List[Tensor], List[List[dict]]), with
the outer list indicating test time augmentations.
"""
if return_loss:
return self.forward_train(*args, **kwargs)
else:
kwargs.pop('rescale')
return self.forward_test(*args, **kwargs)
def train_step(self, data_dict, optimizer):
    """The iteration step during training.

    Runs one forward pass via ``self(**data_dict)`` and packages the result
    for the runner; back-propagation and optimizer stepping are done in an
    optimizer hook outside this method (complex models such as GANs may
    instead do everything here).

    Args:
        data_dict (dict): The output of the dataloader, unpacked as keyword
            arguments for the forward pass.
        optimizer (:obj:`torch.optim.Optimizer` | dict): The runner's
            optimizer; accepted for interface compatibility and unused.

    Returns:
        dict: Contains ``loss`` (tensor for backprop, possibly a weighted
        sum of several losses), ``log_vars`` (values for the logger) and
        ``num_samples`` (per-GPU batch size, used for log averaging).
    """
    loss, log_vars, num_samples = self(**data_dict)
    return dict(loss=loss, log_vars=log_vars, num_samples=num_samples)
def val_step(self, data, optimizer):
    """The iteration step during validation.

    Same contract as :meth:`train_step`, used during val epochs. The
    post-training evaluation is implemented by an evaluation hook, not by
    this method.
    """
    loss, log_vars, num_samples = self(**data)
    return dict(loss=loss, log_vars=log_vars, num_samples=num_samples)
def show_result(self,
                **kwargs):
    """Placeholder visualization hook; always returns None."""
    return None
\ No newline at end of file
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torchvision.models.resnet import resnet18, resnet50
from mmdet3d.models.builder import (build_backbone, build_head,
build_neck)
from .base_mapper import BaseMapper, MAPPERS
@MAPPERS.register_module()
class VectorMapNet(BaseMapper):
    """BEV backbone + neck + detection/generation head for vectorized maps."""

    def __init__(self,
                 backbone_cfg=dict(),
                 head_cfg=dict(
                     vert_net_cfg=dict(),
                     face_net_cfg=dict(),
                 ),
                 neck_input_channels=128,
                 neck_cfg=None,
                 with_auxiliary_head=False,
                 only_det=False,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 model_name=None, **kwargs):
        """Build backbone, neck and head from their configs.

        Args:
            backbone_cfg (dict): config for the BEV backbone.
            head_cfg (dict): config for the map head.
            neck_input_channels (int): channels of the BEV features fed to
                the neck. Defaults to 128.
            neck_cfg (dict | None): optional multi-scale neck config; when
                None a small ResNet-18-based single-scale neck is used.
            only_det (bool): run only the detection part of the head.
            model_name (str | None): free-form model identifier.
        """
        super(VectorMapNet, self).__init__()
        # Attributes
        self.model_name = model_name
        # Cache of the last non-empty training batch; replayed when a batch
        # contains no valid vectors (see forward_train).
        self.last_epoch = None
        self.only_det = only_det
        self.backbone = build_backbone(backbone_cfg)
        if neck_cfg is not None:
            # Multi-scale neck: a backbone whose stem conv is swapped to
            # accept the BEV channel count, followed by a projection neck.
            self.neck_neck = build_backbone(neck_cfg.backbone)
            self.neck_neck.conv1 = nn.Conv2d(
                neck_input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.neck_project = build_neck(neck_cfg.neck)
            self.neck = self.multiscale_neck
        else:
            # Single-scale neck built around the first stage of a ResNet-18.
            trunk = resnet18(pretrained=False, zero_init_residual=True)
            self.neck = nn.Sequential(
                nn.Conv2d(neck_input_channels, 64, kernel_size=(7, 7), stride=(
                    2, 2), padding=(3, 3), bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1,
                             dilation=1, ceil_mode=False),
                trunk.layer1,
                nn.Conv2d(64, 128, kernel_size=1, bias=False),
            )
        # BEV grid size, mirrored from the backbone when it exposes one.
        if hasattr(self.backbone, 'bev_w'):
            self.bev_w = self.backbone.bev_w
            self.bev_h = self.backbone.bev_h
        self.head = build_head(head_cfg)

    def multiscale_neck(self, bev_embedding):
        """Apply the two-stage (backbone + projection) neck to BEV features."""
        multi_feat = self.neck_neck(bev_embedding)
        multi_feat = self.neck_project(multi_feat)
        return multi_feat

    def forward_train(self, img, polys, points=None, img_metas=None, **kwargs):
        '''
        Args:
            img: torch.Tensor of shape [B, N, 3, H, W]
                N: number of cams
            vectors: list[list[Tuple(lines, length, label)]]
                - lines: np.array of shape [num_points, 2].
                - length: int
                - label: int
                len(vectors) = batch_size
                len(vectors[_b]) = num of lines in sample _b
            img_metas:
                img_metas['lidar2img']: [B, N, 4, 4]
        Out:
            loss, log_vars, num_sample
        '''
        # prepare labels and images
        batch, img, img_metas, valid_idx, points = self.batch_data(
            polys, img, img_metas, img.device, points)
        # Corner case (hard-coded to prevent failure): if every sample in
        # this batch is empty, replay the previous non-empty batch instead
        # of crashing downstream.
        if self.last_epoch is None:
            self.last_epoch = [batch, img, img_metas, valid_idx, points]
        if len(valid_idx) == 0:
            batch, img, img_metas, valid_idx, points = self.last_epoch
        else:
            del self.last_epoch
            self.last_epoch = [batch, img, img_metas, valid_idx, points]
        # Backbone
        _bev_feats = self.backbone(img, img_metas=img_metas, points=points)
        # one (H, W) entry per batch element
        img_shape = \
            [_bev_feats.shape[2:] for i in range(_bev_feats.shape[0])]
        # Neck
        bev_feats = self.neck(_bev_feats)
        preds_dict, losses_dict = \
            self.head(batch,
                      context={
                          'bev_embeddings': bev_feats,
                          'batch_input_shape': _bev_feats.shape[2:],
                          'img_shape': img_shape,
                          'raw_bev_embeddings': _bev_feats},
                      only_det=self.only_det)
        # format outputs: total loss is the plain (unweighted) sum of all
        # head losses
        loss = 0
        for name, var in losses_dict.items():
            loss = loss + var
        # update the log
        log_vars = {k: v.item() for k, v in losses_dict.items()}
        log_vars.update({'total': loss.item()})
        num_sample = img.size(0)
        return loss, log_vars, num_sample

    @torch.no_grad()
    def forward_test(self, img, polys=None, points=None, img_metas=None, **kwargs):
        '''
        inference pipeline
        '''
        # prepare labels and images; sample tokens identify the results
        token = []
        for img_meta in img_metas:
            token.append(img_meta['token'])
        _bev_feats = self.backbone(img, img_metas, points=points)
        img_shape = [_bev_feats.shape[2:] for i in range(_bev_feats.shape[0])]
        # Neck
        bev_feats = self.neck(_bev_feats)
        context = {'bev_embeddings': bev_feats,
                   'batch_input_shape': _bev_feats.shape[2:],
                   'img_shape': img_shape,  # XXX
                   'raw_bev_embeddings': _bev_feats}
        # Inference conditions on predicted detections, not ground truth.
        preds_dict = self.head(batch={},
                               context=context,
                               condition_on_det=True,
                               gt_condition=False,
                               only_det=self.only_det)
        # Hard Code: head may return None when nothing was detected.
        if preds_dict is None:
            return [None]
        results_list = self.head.post_process(preds_dict, token, only_det=self.only_det)
        return results_list

    def batch_data(self, polys, imgs, img_metas, device, points=None):
        """Filter empty samples and format det/gen supervision targets.

        Returns (batch, imgs, img_metas, valid_idx, points); the first
        three are None when no sample has any vectors.
        """
        # filter none vector's case
        valid_idx = [i for i in range(len(polys)) if len(polys[i])]
        imgs = imgs[valid_idx]
        img_metas = [img_metas[i] for i in valid_idx]
        polys = [polys[i] for i in valid_idx]
        if points is not None:
            points = [points[i] for i in valid_idx]
            points = self.batch_points(points)
        if len(valid_idx) == 0:
            return None, None, None, valid_idx, None
        batch = {}
        batch['det'] = format_det(polys, device)
        batch['gen'] = format_gen(polys, device)
        return batch, imgs, img_metas, valid_idx, points

    def batch_points(self, points):
        """Pad a list of per-sample point tensors and build a validity mask.

        Returns:
            tuple: (padded points of shape (B, P_max, ...), bool mask of
            shape (B, P_max) that is True at valid entries).
        """
        pad_points = pad_sequence(points, batch_first=True)
        # assumes each per-sample tensor has at least 2 dims — TODO confirm
        points_mask = torch.zeros_like(pad_points[:, :, 0]).bool()
        for i in range(len(points)):
            valid_num = points[i].shape[0]
            points_mask[i][:valid_num] = True
        return (pad_points, points_mask)
def format_det(polys, device):
    """Collect per-sample detection targets into lists of device tensors.

    Returns a dict with 'class_label' and 'bbox' lists (one tensor per
    sample); 'batch_idx' is present but left empty, exactly as before.
    """
    batch = {
        'class_label': [],
        'batch_idx': [],
        'bbox': [],
    }
    for sample in polys:
        batch['class_label'].append(torch.from_numpy(sample['det_label']).to(device))
        batch['bbox'].append(torch.from_numpy(sample['keypoint']).to(device))
    return batch
def format_gen(polys, device):
    """Flatten per-sample polyline-generation targets into one padded batch.

    Args:
        polys (list[dict]): per-sample dicts with keys 'gen_label',
            'qkeypoint', 'polylines', 'polyline_masks', 'polyline_weights';
            values are numpy arrays or lists of numpy arrays and are
            converted to tensors on `device` IN PLACE.
        device: target torch device.

    Returns:
        dict: flattened batch with keys 'lines_bs_idx', 'lines_cls',
        'bbox_flat' and padded 'polylines', 'polyline_masks',
        'polyline_weights'.

    Bug fix: this copy built `batch` but fell off the end without returning
    it (the duplicate definition later in the file ends with
    ``return batch``), so callers received None.
    """
    line_cls = []
    polylines, polyline_masks, polyline_weights = [], [], []
    bbox, line_cls, line_bs_idx = [], [], []
    for batch_idx, poly in enumerate(polys):
        # convert every field to tensors on the target device (in place)
        for k in poly.keys():
            if isinstance(poly[k], np.ndarray):
                poly[k] = torch.from_numpy(poly[k]).to(device)
            else:
                poly[k] = [torch.from_numpy(v).to(device) for v in poly[k]]
        line_cls += poly['gen_label']
        line_bs_idx += [batch_idx] * len(poly['gen_label'])
        # condition
        bbox += poly['qkeypoint']
        # out
        polylines += poly['polylines']
        polyline_masks += poly['polyline_masks']
        polyline_weights += poly['polyline_weights']
    batch = {}
    batch['lines_bs_idx'] = torch.tensor(
        line_bs_idx, dtype=torch.long, device=device)
    batch['lines_cls'] = torch.tensor(
        line_cls, dtype=torch.long, device=device)
    batch['bbox_flat'] = torch.stack(bbox, 0)
    # padding of the variable-length per-line sequences
    batch['polylines'] = pad_sequence(polylines, batch_first=True)
    batch['polyline_masks'] = pad_sequence(polyline_masks, batch_first=True)
    batch['polyline_weights'] = pad_sequence(polyline_weights, batch_first=True)
    return batch
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torchvision.models.resnet import resnet18, resnet50
from mmdet3d.models.builder import (build_backbone, build_head,
build_neck)
from .base_mapper import BaseMapper, MAPPERS
@MAPPERS.register_module()
class VectorMapNet(BaseMapper):
    """BEV backbone + neck + detection/generation head for vectorized maps.

    NOTE(review): duplicate definition — an identical VectorMapNet appears
    earlier in this concatenated file; this later registration wins.
    """

    def __init__(self,
                 backbone_cfg=dict(),
                 head_cfg=dict(
                     vert_net_cfg=dict(),
                     face_net_cfg=dict(),
                 ),
                 neck_input_channels=128,
                 neck_cfg=None,
                 with_auxiliary_head=False,
                 only_det=False,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 model_name=None, **kwargs):
        """Build backbone, neck and head from their configs.

        Args:
            backbone_cfg (dict): config for the BEV backbone.
            head_cfg (dict): config for the map head.
            neck_input_channels (int): channels of the BEV features fed to
                the neck. Defaults to 128.
            neck_cfg (dict | None): optional multi-scale neck config; when
                None a small ResNet-18-based single-scale neck is used.
            only_det (bool): run only the detection part of the head.
            model_name (str | None): free-form model identifier.
        """
        super(VectorMapNet, self).__init__()
        # Attributes
        self.model_name = model_name
        # Cache of the last non-empty training batch; replayed when a batch
        # contains no valid vectors (see forward_train).
        self.last_epoch = None
        self.only_det = only_det
        self.backbone = build_backbone(backbone_cfg)
        if neck_cfg is not None:
            # Multi-scale neck: a backbone whose stem conv is swapped to
            # accept the BEV channel count, followed by a projection neck.
            self.neck_neck = build_backbone(neck_cfg.backbone)
            self.neck_neck.conv1 = nn.Conv2d(
                neck_input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.neck_project = build_neck(neck_cfg.neck)
            self.neck = self.multiscale_neck
        else:
            # Single-scale neck built around the first stage of a ResNet-18.
            trunk = resnet18(pretrained=False, zero_init_residual=True)
            self.neck = nn.Sequential(
                nn.Conv2d(neck_input_channels, 64, kernel_size=(7, 7), stride=(
                    2, 2), padding=(3, 3), bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1,
                             dilation=1, ceil_mode=False),
                trunk.layer1,
                nn.Conv2d(64, 128, kernel_size=1, bias=False),
            )
        # BEV grid size, mirrored from the backbone when it exposes one.
        if hasattr(self.backbone, 'bev_w'):
            self.bev_w = self.backbone.bev_w
            self.bev_h = self.backbone.bev_h
        self.head = build_head(head_cfg)

    def multiscale_neck(self, bev_embedding):
        """Apply the two-stage (backbone + projection) neck to BEV features."""
        multi_feat = self.neck_neck(bev_embedding)
        multi_feat = self.neck_project(multi_feat)
        return multi_feat

    def forward_train(self, img, polys, points=None, img_metas=None, **kwargs):
        '''
        Args:
            img: torch.Tensor of shape [B, N, 3, H, W]
                N: number of cams
            vectors: list[list[Tuple(lines, length, label)]]
                - lines: np.array of shape [num_points, 2].
                - length: int
                - label: int
                len(vectors) = batch_size
                len(vectors[_b]) = num of lines in sample _b
            img_metas:
                img_metas['lidar2img']: [B, N, 4, 4]
        Out:
            loss, log_vars, num_sample
        '''
        # prepare labels and images
        batch, img, img_metas, valid_idx, points = self.batch_data(
            polys, img, img_metas, img.device, points)
        # Corner case (hard-coded to prevent failure): if every sample in
        # this batch is empty, replay the previous non-empty batch instead
        # of crashing downstream.
        if self.last_epoch is None:
            self.last_epoch = [batch, img, img_metas, valid_idx, points]
        if len(valid_idx) == 0:
            batch, img, img_metas, valid_idx, points = self.last_epoch
        else:
            del self.last_epoch
            self.last_epoch = [batch, img, img_metas, valid_idx, points]
        # Backbone
        _bev_feats = self.backbone(img, img_metas=img_metas, points=points)
        # one (H, W) entry per batch element
        img_shape = \
            [_bev_feats.shape[2:] for i in range(_bev_feats.shape[0])]
        # Neck
        bev_feats = self.neck(_bev_feats)
        preds_dict, losses_dict = \
            self.head(batch,
                      context={
                          'bev_embeddings': bev_feats,
                          'batch_input_shape': _bev_feats.shape[2:],
                          'img_shape': img_shape,
                          'raw_bev_embeddings': _bev_feats},
                      only_det=self.only_det)
        # format outputs: total loss is the plain (unweighted) sum of all
        # head losses
        loss = 0
        for name, var in losses_dict.items():
            loss = loss + var
        # update the log
        log_vars = {k: v.item() for k, v in losses_dict.items()}
        log_vars.update({'total': loss.item()})
        num_sample = img.size(0)
        return loss, log_vars, num_sample

    @torch.no_grad()
    def forward_test(self, img, polys=None, points=None, img_metas=None, **kwargs):
        '''
        inference pipeline
        '''
        # prepare labels and images; sample tokens identify the results
        token = []
        for img_meta in img_metas:
            token.append(img_meta['token'])
        _bev_feats = self.backbone(img, img_metas, points=points)
        img_shape = [_bev_feats.shape[2:] for i in range(_bev_feats.shape[0])]
        # Neck
        bev_feats = self.neck(_bev_feats)
        context = {'bev_embeddings': bev_feats,
                   'batch_input_shape': _bev_feats.shape[2:],
                   'img_shape': img_shape,  # XXX
                   'raw_bev_embeddings': _bev_feats}
        # Inference conditions on predicted detections, not ground truth.
        preds_dict = self.head(batch={},
                               context=context,
                               condition_on_det=True,
                               gt_condition=False,
                               only_det=self.only_det)
        # Hard Code: head may return None when nothing was detected.
        if preds_dict is None:
            return [None]
        results_list = self.head.post_process(preds_dict, token, only_det=self.only_det)
        return results_list

    def batch_data(self, polys, imgs, img_metas, device, points=None):
        """Filter empty samples and format det/gen supervision targets.

        Returns (batch, imgs, img_metas, valid_idx, points); the first
        three are None when no sample has any vectors.
        """
        # filter none vector's case
        valid_idx = [i for i in range(len(polys)) if len(polys[i])]
        imgs = imgs[valid_idx]
        img_metas = [img_metas[i] for i in valid_idx]
        polys = [polys[i] for i in valid_idx]
        if points is not None:
            points = [points[i] for i in valid_idx]
            points = self.batch_points(points)
        if len(valid_idx) == 0:
            return None, None, None, valid_idx, None
        batch = {}
        batch['det'] = format_det(polys, device)
        batch['gen'] = format_gen(polys, device)
        return batch, imgs, img_metas, valid_idx, points

    def batch_points(self, points):
        """Pad a list of per-sample point tensors and build a validity mask.

        Returns:
            tuple: (padded points of shape (B, P_max, ...), bool mask of
            shape (B, P_max) that is True at valid entries).
        """
        pad_points = pad_sequence(points, batch_first=True)
        # assumes each per-sample tensor has at least 2 dims — TODO confirm
        points_mask = torch.zeros_like(pad_points[:, :, 0]).bool()
        for i in range(len(points)):
            valid_num = points[i].shape[0]
            points_mask[i][:valid_num] = True
        return (pad_points, points_mask)
def format_det(polys, device):
    """Gather per-sample detection labels and keypoints as device tensors."""
    class_labels, bboxes = [], []
    for sample in polys:
        class_labels.append(torch.from_numpy(sample['det_label']).to(device))
        bboxes.append(torch.from_numpy(sample['keypoint']).to(device))
    # 'batch_idx' stays empty, preserving the original dict layout.
    return {
        'class_label': class_labels,
        'batch_idx': [],
        'bbox': bboxes,
    }
def format_gen(polys, device):
    """Flatten per-sample polyline-generation targets into a padded batch.

    Converts each sample's numpy arrays to tensors on `device` (mutating
    the input dicts in place), concatenates per-line fields across samples,
    and pads the variable-length polyline sequences.
    """
    all_cls, all_bs_idx = [], []
    all_bbox = []
    all_lines, all_masks, all_weights = [], [], []
    for sample_idx, sample in enumerate(polys):
        # convert every field to tensors on the target device (in place)
        for key in sample.keys():
            value = sample[key]
            if isinstance(value, np.ndarray):
                sample[key] = torch.from_numpy(value).to(device)
            else:
                sample[key] = [torch.from_numpy(item).to(device) for item in value]
        all_cls += sample['gen_label']
        all_bs_idx += [sample_idx] * len(sample['gen_label'])
        # conditioning keypoints for the generator
        all_bbox += sample['qkeypoint']
        # generation targets
        all_lines += sample['polylines']
        all_masks += sample['polyline_masks']
        all_weights += sample['polyline_weights']
    return {
        'lines_bs_idx': torch.tensor(all_bs_idx, dtype=torch.long, device=device),
        'lines_cls': torch.tensor(all_cls, dtype=torch.long, device=device),
        'bbox_flat': torch.stack(all_bbox, 0),
        'polylines': pad_sequence(all_lines, batch_first=True),
        'polyline_masks': pad_sequence(all_masks, batch_first=True),
        'polyline_weights': pad_sequence(all_weights, batch_first=True),
    }
\ No newline at end of file
from .deformable_transformer import DeformableDetrTransformer_, DeformableDetrTransformerDecoder_
from .deformable_transformer import DeformableDetrTransformer_, DeformableDetrTransformerDecoder_
from .base_transformer import PlaceHolderEncoder
\ No newline at end of file
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init, constant_init
from mmcv.cnn.bricks.registry import (ATTENTION,
TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import (MultiScaleDeformableAttention,
TransformerLayerSequence,
build_transformer_layer_sequence)
from mmcv.runner.base_module import BaseModule
from mmdet.models.utils.builder import TRANSFORMER
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class PlaceHolderEncoder(nn.Module):
    """Identity stand-in for a transformer encoder: returns the query as-is.

    Bug fix: this copy of the class was truncated and ``forward`` had no
    body (the duplicate definition later in the file returns ``query``);
    restore the return so features pass through unchanged.
    """

    def __init__(self, *args, embed_dims=None, **kwargs):
        super(PlaceHolderEncoder, self).__init__()
        # kept so callers can query the configured embedding width
        self.embed_dims = embed_dims

    def forward(self, *args, query=None, **kwargs):
        """Return `query` unchanged; all other inputs are ignored."""
        return query
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init, constant_init
from mmcv.cnn.bricks.registry import (ATTENTION,
TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import (MultiScaleDeformableAttention,
TransformerLayerSequence,
build_transformer_layer_sequence)
from mmcv.runner.base_module import BaseModule
from mmdet.models.utils.builder import TRANSFORMER
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class PlaceHolderEncoder(nn.Module):
    """A no-op encoder module that passes the query through unchanged."""

    def __init__(self, *args, embed_dims=None, **kwargs):
        super(PlaceHolderEncoder, self).__init__()
        # recorded only so callers can read the configured width
        self.embed_dims = embed_dims

    def forward(self, *args, query=None, **kwargs):
        # Identity: every input except `query` is ignored.
        return query
\ No newline at end of file
# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_norm_layer, xavier_init
from mmcv.cnn.bricks.registry import (TRANSFORMER_LAYER,
TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import (BaseTransformerLayer,
TransformerLayerSequence,
build_transformer_layer_sequence)
from mmcv.runner.base_module import BaseModule
from torch.nn.init import normal_
from mmdet.models.utils.builder import TRANSFORMER
from mmdet.models.utils.transformer import Transformer
try:
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
except ImportError:
warnings.warn(
'`MultiScaleDeformableAttention` in MMCV has been moved to '
'`mmcv.ops.multi_scale_deform_attn`, please update your MMCV')
from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention
from .fp16_dattn import MultiScaleDeformableAttentionFp16
def inverse_sigmoid(x, eps=1e-5):
    """Numerically-stable logit (inverse sigmoid) of `x`.

    Args:
        x (Tensor): input; values are first clamped into [0, 1].
        eps (float): lower clamp keeping both log arguments away from
            zero. Defaults to 1e-5.

    Returns:
        Tensor: log(x / (1 - x)), same shape as the input.
    """
    clamped = x.clamp(min=0, max=1)
    numerator = clamped.clamp(min=eps)
    denominator = (1 - clamped).clamp(min=eps)
    return torch.log(numerator / denominator)
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DeformableDetrTransformerDecoder_(TransformerLayerSequence):
    """Implements the decoder in DETR transformer.

    Iteratively refines the reference points across decoder layers.

    Args:
        return_intermediate (bool): Whether to return intermediate outputs.
        coord_dim (int): Dimensionality of the regressed coordinates
            (2 for xy, 3 for xyz).
        kp_coord_dim (int): Dimensionality of the keypoint coordinates fed
            to deformable attention (typically 2).
    """

    def __init__(self, *args,
                 return_intermediate=False, coord_dim=2, kp_coord_dim=2, **kwargs):
        super(DeformableDetrTransformerDecoder_, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate
        self.coord_dim = coord_dim
        self.kp_coord_dim = kp_coord_dim

    def forward(self,
                query,
                *args,
                reference_points=None,
                valid_ratios=None,
                reg_branches=None,
                **kwargs):
        """Forward function for `TransformerDecoder`.

        Args:
            query (Tensor): Input query with shape
                `(num_query, bs, embed_dims)`.
            reference_points (Tensor): The reference points of offset,
                shape (bs, num_query, coord_dim); updated IN PLACE when
                `reg_branches` is given.
            valid_ratios (Tensor): The ratios of valid points on the
                feature map, has shape (bs, num_levels, 2).
            reg_branches (obj:`nn.ModuleList`): Used for refining the
                regression results. Only passed when with_box_refine is
                True, otherwise None.

        Returns:
            tuple: (output, reference_points); when `return_intermediate`
            is True both are stacked over decoder layers.
        """
        output = query
        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            # Scale the first kp_coord_dim reference coords by per-level
            # valid ratios -> (bs, num_query, num_levels, kp_coord_dim).
            reference_points_input = \
                reference_points[:, :, None, :self.kp_coord_dim] * \
                valid_ratios[:, None, :, :self.kp_coord_dim]
            output = layer(
                output,
                *args,
                reference_points=reference_points_input[..., :self.kp_coord_dim],
                **kwargs)
            # (num_query, bs, c) -> (bs, num_query, c) for the reg branch
            output = output.permute(1, 0, 2)
            if reg_branches is not None:
                # Iterative refinement: predicted offsets are added in
                # logit space, then squashed back through sigmoid.
                tmp = reg_branches[lid](output)
                new_reference_points = tmp
                new_reference_points[..., :self.kp_coord_dim] = tmp[
                    ..., :self.kp_coord_dim] + inverse_sigmoid(reference_points)
                new_reference_points = new_reference_points.sigmoid()
                # 3-D refs with 2-D keypoints: the z channel is taken
                # directly from the branch prediction.
                if reference_points.shape[-1] == 3 and self.kp_coord_dim == 2:
                    reference_points[..., -1] = tmp[..., -1].sigmoid().detach()
                # NOTE(review): reference_points is updated IN PLACE and
                # detached, so refinement is not backpropagated through.
                reference_points[..., :self.coord_dim] = new_reference_points.detach()
            # back to (num_query, bs, c)
            output = output.permute(1, 0, 2)
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)
        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)
        return output, reference_points
@TRANSFORMER.register_module()
class DeformableDetrTransformer_(Transformer):
"""Implements the DeformableDETR transformer.
Args:
as_two_stage (bool): Generate query from encoder features.
Default: False.
num_feature_levels (int): Number of feature maps from FPN:
Default: 4.
two_stage_num_proposals (int): Number of proposals when set
`as_two_stage` as True. Default: 300.
"""
def __init__(self,
             as_two_stage=False,
             num_feature_levels=1,
             two_stage_num_proposals=300,
             coord_dim=2,
             **kwargs):
    """Initialize the deformable-DETR transformer.

    Args:
        as_two_stage (bool): Generate query from encoder features.
            Defaults to False.
        num_feature_levels (int): Number of feature levels. Defaults to 1.
        two_stage_num_proposals (int): Number of proposals when
            `as_two_stage` is True. Defaults to 300.
        coord_dim (int): Coordinate dimensionality of the initial
            reference points (2 or 3). Defaults to 2.
    """
    super(DeformableDetrTransformer_, self).__init__(**kwargs)
    self.as_two_stage = as_two_stage
    self.num_feature_levels = num_feature_levels
    self.two_stage_num_proposals = two_stage_num_proposals
    # the base Transformer has already built self.encoder from its cfg
    self.embed_dims = self.encoder.embed_dims
    self.coord_dim = coord_dim
    self.init_layers()
def init_layers(self):
    """Initialize layers of the DeformableDetrTransformer."""
    # learnable per-level embedding added to the positional encoding
    self.level_embeds = nn.Parameter(
        torch.Tensor(self.num_feature_levels, self.embed_dims))
    if self.as_two_stage:
        self.enc_output = nn.Linear(self.embed_dims, self.embed_dims)
        self.enc_output_norm = nn.LayerNorm(self.embed_dims)
        self.pos_trans = nn.Linear(self.embed_dims * 2,
                                   self.embed_dims * 2)
        self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2)
    else:
        # maps each query embedding to an initial reference point
        self.reference_points_embed = nn.Linear(self.embed_dims, self.coord_dim)
def init_weights(self):
    """Initialize the transformer weights."""
    for p in self.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    # deformable-attention modules (fp32 and fp16 variants) have their
    # own initialization scheme
    for m in self.modules():
        if isinstance(m, MultiScaleDeformableAttention):
            m.init_weights()
        elif isinstance(m, MultiScaleDeformableAttentionFp16):
            m.init_weights()
    if not self.as_two_stage:
        xavier_init(self.reference_points_embed, distribution='uniform', bias=0.)
    normal_(self.level_embeds)
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
    """Build normalized per-level reference points for the decoder.

    Args:
        spatial_shapes (Tensor): (num_levels, 2) feature-map sizes (H, W).
        valid_ratios (Tensor): (bs, num_levels, 2) valid-region ratios.
        device (obj:`device`): device on which to create the points.

    Returns:
        Tensor: (bs, num_keys, num_levels, 2) reference points in [0, 1],
        scaled by the per-level valid ratios.
    """
    per_level_refs = []
    for lvl, (H, W) in enumerate(spatial_shapes):
        # cell-center grid (the 0.5 offset targets pixel centers)
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(
                0.5, H - 0.5, H, dtype=torch.float32, device=device),
            torch.linspace(
                0.5, W - 0.5, W, dtype=torch.float32, device=device))
        norm_y = grid_y.reshape(-1)[None] / (
            valid_ratios[:, None, lvl, 1] * H)
        norm_x = grid_x.reshape(-1)[None] / (
            valid_ratios[:, None, lvl, 0] * W)
        per_level_refs.append(torch.stack((norm_x, norm_y), -1))
    all_refs = torch.cat(per_level_refs, 1)
    return all_refs[:, :, None] * valid_ratios[:, None]
def get_valid_ratio(self, mask):
    """Fraction of each feature map that is valid (not padded).

    Args:
        mask (Tensor): (bs, H, W) padding mask, True at padded cells.

    Returns:
        Tensor: (bs, 2) ratios ordered as (valid_W / W, valid_H / H).
    """
    _, H, W = mask.shape
    # count un-padded rows along the first column and un-padded columns
    # along the first row (padding is assumed bottom/right aligned)
    ratio_h = torch.sum(~mask[:, :, 0], 1).float() / H
    ratio_w = torch.sum(~mask[:, 0, :], 1).float() / W
    return torch.stack([ratio_w, ratio_h], -1)
def get_proposal_pos_embed(self,
                           proposals,
                           num_pos_feats=128,
                           temperature=10000):
    """Sinusoidal position embedding of proposal boxes.

    Args:
        proposals (Tensor): (N, L, 4) unnormalized proposals; a sigmoid
            is applied before encoding.
        num_pos_feats (int): features per coordinate. Defaults to 128.
        temperature (int): frequency base. Defaults to 10000.

    Returns:
        Tensor: (N, L, 4 * num_pos_feats) embedding.
    """
    scale = 2 * math.pi
    dim_t = torch.arange(
        num_pos_feats, dtype=torch.float32, device=proposals.device)
    dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats)
    # map proposals into (0, scale), then interleave sin/cos pairs
    phase = proposals.sigmoid() * scale
    pos = phase[:, :, :, None] / dim_t
    return torch.stack(
        (pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()),
        dim=4).flatten(2)
def forward(self,
            mlvl_feats,
            mlvl_masks,
            query_embed,
            mlvl_pos_embeds,
            reg_branches=None,
            cls_branches=None,
            **kwargs):
    """Forward function for `Transformer`.

    Args:
        mlvl_feats (list(Tensor)): Input queries from different levels,
            each of shape [bs, embed_dims, h, w].
        mlvl_masks (list(Tensor)): Key padding masks per level, each of
            shape [bs, h, w].
        query_embed (Tensor): Query embedding for the decoder, shape
            [num_query, c] (first half positional, second half content).
        mlvl_pos_embeds (list(Tensor)): Positional encodings per level,
            each of shape [bs, embed_dims, h, w].
        reg_branches (obj:`nn.ModuleList`): Regression heads per decoder
            layer; only passed when `with_box_refine` is True.
        cls_branches (obj:`nn.ModuleList`): Classification heads; only
            relevant when `as_two_stage` is True. Unused here.

    Returns:
        tuple: (inter_states, init_reference_out, inter_references_out,
        None, None). The last two slots are the two-stage encoder outputs,
        always None in this single-stage variant.
    """
    assert self.as_two_stage or query_embed is not None
    # Flatten all levels into one (bs, sum(h*w), c) sequence.
    feat_flatten = []
    mask_flatten = []
    lvl_pos_embed_flatten = []
    spatial_shapes = []
    for lvl, (feat, mask, pos_embed) in enumerate(
            zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
        bs, c, h, w = feat.shape
        spatial_shape = (h, w)
        spatial_shapes.append(spatial_shape)
        feat = feat.flatten(2).transpose(1, 2)
        mask = mask.flatten(1)
        pos_embed = pos_embed.flatten(2).transpose(1, 2)
        # add the learnable per-level embedding to the positional encoding
        lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
        lvl_pos_embed_flatten.append(lvl_pos_embed)
        feat_flatten.append(feat)
        mask_flatten.append(mask)
    feat_flatten = torch.cat(feat_flatten, 1)
    mask_flatten = torch.cat(mask_flatten, 1)
    lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
    spatial_shapes = torch.as_tensor(
        spatial_shapes, dtype=torch.long, device=feat_flatten.device)
    level_start_index = torch.cat((spatial_shapes.new_zeros(
        (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
    valid_ratios = torch.stack(
        [self.get_valid_ratio(m) for m in mlvl_masks], 1)
    # The self-attention encoder is deliberately bypassed (see
    # PlaceHolderEncoder): flattened features are used directly as memory.
    feat_flatten = feat_flatten.permute(1, 0, 2)  # (H*W, bs, embed_dims)
    memory = feat_flatten.permute(1, 0, 2)
    bs, _, c = memory.shape
    # split query_embed into positional and content halves
    query_pos, query = torch.split(query_embed, c, dim=-1)
    reference_points = self.reference_points_embed(query_pos).sigmoid()
    init_reference_out = reference_points
    # decoder expects (num_query, bs, c) layout
    query = query.permute(1, 0, 2)
    memory = memory.permute(1, 0, 2)
    query_pos = query_pos.permute(1, 0, 2)
    inter_states, inter_references = self.decoder(
        query=query,
        key=None,
        value=memory,
        query_pos=query_pos,
        key_padding_mask=mask_flatten,
        reference_points=reference_points,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index,
        valid_ratios=valid_ratios,
        reg_branches=reg_branches,
        **kwargs)
    inter_references_out = inter_references
    # Bug fix: this (truncated) copy previously fell off the end without a
    # return, so callers received None. Return the standard DeformableDETR
    # tuple; the two two-stage outputs are None in this variant.
    return inter_states, init_reference_out, \
        inter_references_out, None, None
# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_norm_layer, xavier_init
from mmcv.cnn.bricks.registry import (TRANSFORMER_LAYER,
TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import (BaseTransformerLayer,
TransformerLayerSequence,
build_transformer_layer_sequence)
from mmcv.runner.base_module import BaseModule
from torch.nn.init import normal_
from mmdet.models.utils.builder import TRANSFORMER
from mmdet.models.utils.transformer import Transformer
try:
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
except ImportError:
warnings.warn(
'`MultiScaleDeformableAttention` in MMCV has been moved to '
'`mmcv.ops.multi_scale_deform_attn`, please update your MMCV')
from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention
from .fp16_dattn import MultiScaleDeformableAttentionFp16
def inverse_sigmoid(x, eps=1e-5):
    """Numerically-stable logit: log(x / (1 - x)) with clamping.

    `x` is first clamped into [0, 1]; the numerator and denominator are
    each clamped to at least `eps` before taking the log, so the result is
    always finite. Returns a tensor of the same shape as the input.
    """
    x = x.clamp(min=0, max=1)
    num = x.clamp(min=eps)
    den = (1 - x).clamp(min=eps)
    return torch.log(num) - torch.log(den)
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DeformableDetrTransformerDecoder_(TransformerLayerSequence):
    """Implements the decoder in DETR transformer.

    Iteratively refines the reference points across decoder layers.
    NOTE(review): duplicate definition — an identical class appears earlier
    in this concatenated file; this later registration wins.

    Args:
        return_intermediate (bool): Whether to return intermediate outputs.
        coord_dim (int): Dimensionality of the regressed coordinates
            (2 for xy, 3 for xyz).
        kp_coord_dim (int): Dimensionality of the keypoint coordinates fed
            to deformable attention (typically 2).
    """

    def __init__(self, *args,
                 return_intermediate=False, coord_dim=2, kp_coord_dim=2, **kwargs):
        super(DeformableDetrTransformerDecoder_, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate
        self.coord_dim = coord_dim
        self.kp_coord_dim = kp_coord_dim

    def forward(self,
                query,
                *args,
                reference_points=None,
                valid_ratios=None,
                reg_branches=None,
                **kwargs):
        """Forward function for `TransformerDecoder`.

        Args:
            query (Tensor): Input query with shape
                `(num_query, bs, embed_dims)`.
            reference_points (Tensor): The reference points of offset,
                shape (bs, num_query, coord_dim); updated IN PLACE when
                `reg_branches` is given.
            valid_ratios (Tensor): The ratios of valid points on the
                feature map, has shape (bs, num_levels, 2).
            reg_branches (obj:`nn.ModuleList`): Used for refining the
                regression results. Only passed when with_box_refine is
                True, otherwise None.

        Returns:
            tuple: (output, reference_points); when `return_intermediate`
            is True both are stacked over decoder layers.
        """
        output = query
        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            # Scale the first kp_coord_dim reference coords by per-level
            # valid ratios -> (bs, num_query, num_levels, kp_coord_dim).
            reference_points_input = \
                reference_points[:, :, None, :self.kp_coord_dim] * \
                valid_ratios[:, None, :, :self.kp_coord_dim]
            output = layer(
                output,
                *args,
                reference_points=reference_points_input[..., :self.kp_coord_dim],
                **kwargs)
            # (num_query, bs, c) -> (bs, num_query, c) for the reg branch
            output = output.permute(1, 0, 2)
            if reg_branches is not None:
                # Iterative refinement: predicted offsets are added in
                # logit space, then squashed back through sigmoid.
                tmp = reg_branches[lid](output)
                new_reference_points = tmp
                new_reference_points[..., :self.kp_coord_dim] = tmp[
                    ..., :self.kp_coord_dim] + inverse_sigmoid(reference_points)
                new_reference_points = new_reference_points.sigmoid()
                # 3-D refs with 2-D keypoints: the z channel is taken
                # directly from the branch prediction.
                if reference_points.shape[-1] == 3 and self.kp_coord_dim == 2:
                    reference_points[..., -1] = tmp[..., -1].sigmoid().detach()
                # NOTE(review): reference_points is updated IN PLACE and
                # detached, so refinement is not backpropagated through.
                reference_points[..., :self.coord_dim] = new_reference_points.detach()
            # back to (num_query, bs, c)
            output = output.permute(1, 0, 2)
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)
        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)
        return output, reference_points
@TRANSFORMER.register_module()
class DeformableDetrTransformer_(Transformer):
"""Implements the DeformableDETR transformer.
Args:
as_two_stage (bool): Generate query from encoder features.
Default: False.
num_feature_levels (int): Number of feature maps from FPN:
Default: 4.
two_stage_num_proposals (int): Number of proposals when set
`as_two_stage` as True. Default: 300.
"""
def __init__(self,
             as_two_stage=False,
             num_feature_levels=1,
             two_stage_num_proposals=300,
             coord_dim=2,
             **kwargs):
    """Initialize the deformable-DETR transformer.

    Args:
        as_two_stage (bool): Generate query from encoder features.
            Defaults to False.
        num_feature_levels (int): Number of feature levels. Defaults to 1.
        two_stage_num_proposals (int): Number of proposals when
            `as_two_stage` is True. Defaults to 300.
        coord_dim (int): Coordinate dimensionality of the initial
            reference points (2 or 3). Defaults to 2.
    """
    super(DeformableDetrTransformer_, self).__init__(**kwargs)
    self.as_two_stage = as_two_stage
    self.num_feature_levels = num_feature_levels
    self.two_stage_num_proposals = two_stage_num_proposals
    # the base Transformer has already built self.encoder from its cfg
    self.embed_dims = self.encoder.embed_dims
    self.coord_dim = coord_dim
    self.init_layers()
def init_layers(self):
"""Initialize layers of the DeformableDetrTransformer."""
self.level_embeds = nn.Parameter(
torch.Tensor(self.num_feature_levels, self.embed_dims))
if self.as_two_stage:
self.enc_output = nn.Linear(self.embed_dims, self.embed_dims)
self.enc_output_norm = nn.LayerNorm(self.embed_dims)
self.pos_trans = nn.Linear(self.embed_dims * 2,
self.embed_dims * 2)
self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2)
else:
self.reference_points_embed = nn.Linear(self.embed_dims, self.coord_dim)
def init_weights(self):
"""Initialize the transformer weights."""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if isinstance(m, MultiScaleDeformableAttention):
m.init_weights()
elif isinstance(m,MultiScaleDeformableAttentionFp16):
m.init_weights()
if not self.as_two_stage:
xavier_init(self.reference_points_embed, distribution='uniform', bias=0.)
normal_(self.level_embeds)
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
"""Get the reference points used in decoder.
Args:
spatial_shapes (Tensor): The shape of all
feature maps, has shape (num_level, 2).
valid_ratios (Tensor): The radios of valid
points on the feature map, has shape
(bs, num_levels, 2)
device (obj:`device`): The device where
reference_points should be.
Returns:
Tensor: reference points used in decoder, has \
shape (bs, num_keys, num_levels, 2).
"""
reference_points_list = []
for lvl, (H, W) in enumerate(spatial_shapes):
# TODO check this 0.5
ref_y, ref_x = torch.meshgrid(
torch.linspace(
0.5, H - 0.5, H, dtype=torch.float32, device=device),
torch.linspace(
0.5, W - 0.5, W, dtype=torch.float32, device=device))
ref_y = ref_y.reshape(-1)[None] / (
valid_ratios[:, None, lvl, 1] * H)
ref_x = ref_x.reshape(-1)[None] / (
valid_ratios[:, None, lvl, 0] * W)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
def get_valid_ratio(self, mask):
"""Get the valid radios of feature maps of all level."""
_, H, W = mask.shape
valid_H = torch.sum(~mask[:, :, 0], 1)
valid_W = torch.sum(~mask[:, 0, :], 1)
valid_ratio_h = valid_H.float() / H
valid_ratio_w = valid_W.float() / W
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
return valid_ratio
def get_proposal_pos_embed(self,
proposals,
num_pos_feats=128,
temperature=10000):
"""Get the position embedding of proposal."""
scale = 2 * math.pi
dim_t = torch.arange(
num_pos_feats, dtype=torch.float32, device=proposals.device)
dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats)
# N, L, 4
proposals = proposals.sigmoid() * scale
# N, L, 4, 128
pos = proposals[:, :, :, None] / dim_t
# N, L, 4, 64, 2
pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()),
dim=4).flatten(2)
return pos
def forward(self,
mlvl_feats,
mlvl_masks,
query_embed,
mlvl_pos_embeds,
reg_branches=None,
cls_branches=None,
**kwargs):
"""Forward function for `Transformer`.
Args:
mlvl_feats (list(Tensor)): Input queries from
different level. Each element has shape
[bs, embed_dims, h, w].
mlvl_masks (list(Tensor)): The key_padding_mask from
different level used for encoder and decoder,
each element has shape [bs, h, w].
query_embed (Tensor): The query embedding for decoder,
with shape [num_query, c].
mlvl_pos_embeds (list(Tensor)): The positional encoding
of feats from different level, has the shape
[bs, embed_dims, h, w].
reg_branches (obj:`nn.ModuleList`): Regression heads for
feature maps from each decoder layer. Only would
be passed when
`with_box_refine` is True. Default to None.
cls_branches (obj:`nn.ModuleList`): Classification heads
for feature maps from each decoder layer. Only would
be passed when `as_two_stage`
is True. Default to None.
Returns:
tuple[Tensor]: results of decoder containing the following tensor.
- inter_states: Outputs from decoder. If
return_intermediate_dec is True output has shape \
(num_dec_layers, bs, num_query, embed_dims), else has \
shape (1, bs, num_query, embed_dims).
- init_reference_out: The initial value of reference \
points, has shape (bs, num_queries, 4).
- inter_references_out: The internal value of reference \
points in decoder, has shape \
(num_dec_layers, bs,num_query, embed_dims)
- enc_outputs_class: The classification score of \
proposals generated from \
encoder's feature maps, has shape \
(batch, h*w, num_classes). \
Only would be returned when `as_two_stage` is True, \
otherwise None.
- enc_outputs_coord_unact: The regression results \
generated from encoder's feature maps., has shape \
(batch, h*w, 4). Only would \
be returned when `as_two_stage` is True, \
otherwise None.
"""
assert self.as_two_stage or query_embed is not None
feat_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (feat, mask, pos_embed) in enumerate(
zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
bs, c, h, w = feat.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
feat = feat.flatten(2).transpose(1, 2)
mask = mask.flatten(1)
pos_embed = pos_embed.flatten(2).transpose(1, 2)
lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
lvl_pos_embed_flatten.append(lvl_pos_embed)
feat_flatten.append(feat)
mask_flatten.append(mask)
feat_flatten = torch.cat(feat_flatten, 1)
mask_flatten = torch.cat(mask_flatten, 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
spatial_shapes = torch.as_tensor(
spatial_shapes, dtype=torch.long, device=feat_flatten.device)
level_start_index = torch.cat((spatial_shapes.new_zeros(
(1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack(
[self.get_valid_ratio(m) for m in mlvl_masks], 1)
# reference_points = \
# self.get_reference_points(spatial_shapes,
# valid_ratios,
# device=feat.device)
feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims)
# lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(
# 1, 0, 2) # (H*W, bs, embed_dims)
# memory = self.encoder(
# query=feat_flatten,
# key=None,
# value=None,
# query_pos=lvl_pos_embed_flatten,
# query_key_padding_mask=mask_flatten,
# spatial_shapes=spatial_shapes,
# reference_points=reference_points,
# level_start_index=level_start_index,
# valid_ratios=valid_ratios,
# **kwargs)
memory = feat_flatten.permute(1, 0, 2)
bs, _, c = memory.shape
query_pos, query = torch.split(query_embed, c, dim=-1)
reference_points = self.reference_points_embed(query_pos).sigmoid()
init_reference_out = reference_points
# decoder
query = query.permute(1, 0, 2)
memory = memory.permute(1, 0, 2)
query_pos = query_pos.permute(1, 0, 2)
inter_states, inter_references = self.decoder(
query=query,
key=None,
value=memory,
query_pos=query_pos,
key_padding_mask=mask_flatten,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
reg_branches=reg_branches,
**kwargs)
inter_references_out = inter_references
return inter_states, init_reference_out, inter_references_out
\ No newline at end of file
from turtle import forward
import warnings
try:
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
except ImportError:
warnings.warn(
'`MultiScaleDeformableAttention` in MMCV has been moved to '
'`mmcv.ops.multi_scale_deform_attn`, please update your MMCV')
from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention
from mmcv.runner import force_fp32, auto_fp16
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.cnn.bricks.transformer import build_attention
import math
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.function import Function, once_differentiable
from mmcv import deprecated_api_warning
from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.runner import BaseModule
from mmcv.utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
from torch.cuda.amp import custom_bwd, custom_fwd
@ATTENTION.register_module()
class MultiScaleDeformableAttentionFp16(BaseModule):
    """Precision wrapper around a configurable deformable-attention module.

    The ``force_fp32`` decorator casts the listed tensor arguments to fp32
    before the wrapped attention runs, keeping the attention arithmetic in
    full precision even under mixed-precision training.

    Args:
        attn_cfg (dict): Config for the wrapped attention, passed to
            ``build_attention``. Default: None.
        init_cfg (dict): mmcv BaseModule init config. Default: None.
    """
    def __init__(self, attn_cfg=None,init_cfg=None,**kwarg):
        super(MultiScaleDeformableAttentionFp16,self).__init__(init_cfg)
        # Build the inner attention from config and initialize it eagerly.
        self.deformable_attention = build_attention(attn_cfg)
        self.deformable_attention.init_weights()
        # mmcv fp16 bookkeeping flag; presumably toggled by the fp16 hooks
        # when AMP is enabled — TODO confirm against mmcv's fp16_utils.
        self.fp16_enabled = False
    @force_fp32(apply_to=('query', 'key', 'value', 'query_pos', 'reference_points','identity'))
    def forward(self, query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
        """Delegate to the wrapped attention; tensor args arrive as fp32."""
        return self.deformable_attention(query,
                                        key=key,
                                        value=value,
                                        identity=identity,
                                        query_pos=query_pos,
                                        key_padding_mask=key_padding_mask,
                                        reference_points=reference_points,
                                        spatial_shapes=spatial_shapes,
                                        level_start_index=level_start_index,**kwargs)
class MultiScaleDeformableAttnFunctionFp32(Function):
    """Autograd Function wrapping the CUDA multi-scale deformable attention
    kernels, with all inputs force-cast to fp32 via ``custom_fwd``."""
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights, im2col_step):
        """GPU version of multi-scale deformable attention.

        Args:
            value (Tensor): The value has shape
                (bs, num_keys, mum_heads, embed_dims//num_heads)
            value_spatial_shapes (Tensor): Spatial shape of
                each feature map, has shape (num_levels, 2),
                last dimension 2 represent (h, w)
            sampling_locations (Tensor): The location of sampling points,
                has shape
                (bs ,num_queries, num_heads, num_levels, num_points, 2),
                the last dimension 2 represent (x, y).
            attention_weights (Tensor): The weight of sampling points used
                when calculate the attention, has shape
                (bs ,num_queries, num_heads, num_levels, num_points),
            im2col_step (Tensor): The step used in image to column.
        Returns:
            Tensor: has shape (bs, num_queries, embed_dims)
        """
        # Stash im2col_step on ctx so backward can reuse the same step size.
        ctx.im2col_step = im2col_step
        output = ext_module.ms_deform_attn_forward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            im2col_step=ctx.im2col_step)
        # Save order must match the unpack order in backward().
        ctx.save_for_backward(value, value_spatial_shapes,
                              value_level_start_index, sampling_locations,
                              attention_weights)
        return output
    @staticmethod
    @once_differentiable
    @custom_bwd
    def backward(ctx, grad_output):
        """GPU version of backward function.

        Args:
            grad_output (Tensor): Gradient
                of output tensor of forward.
        Returns:
            Tuple[Tensor]: Gradient
                of input tensors in forward.
        """
        value, value_spatial_shapes, value_level_start_index,\
            sampling_locations, attention_weights = ctx.saved_tensors
        # The ext kernel writes gradients in place into these buffers.
        grad_value = torch.zeros_like(value)
        grad_sampling_loc = torch.zeros_like(sampling_locations)
        grad_attn_weight = torch.zeros_like(attention_weights)
        ext_module.ms_deform_attn_backward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            grad_output.contiguous(),
            grad_value,
            grad_sampling_loc,
            grad_attn_weight,
            im2col_step=ctx.im2col_step)
        # None for the non-tensor inputs (shapes, start index, im2col_step).
        return grad_value, None, None, \
            grad_sampling_loc, grad_attn_weight, None
def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
                                        sampling_locations, attention_weights):
    """CPU version of multi-scale deformable attention.

    Pure-PyTorch fallback: bilinearly samples each level's value map at the
    predicted locations and blends the samples with the attention weights.

    Args:
        value (Tensor): The value has shape
            (bs, num_keys, mum_heads, embed_dims//num_heads)
        value_spatial_shapes (Tensor): Spatial shape of
            each feature map, has shape (num_levels, 2),
            last dimension 2 represent (h, w)
        sampling_locations (Tensor): The location of sampling points,
            has shape
            (bs ,num_queries, num_heads, num_levels, num_points, 2),
            the last dimension 2 represent (x, y).
        attention_weights (Tensor): The weight of sampling points used
            when calculate the attention, has shape
            (bs ,num_queries, num_heads, num_levels, num_points),
    Returns:
        Tensor: has shape (bs, num_queries, embed_dims)
    """
    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ = \
        sampling_locations.shape
    # Split the flattened keys back into one chunk per pyramid level.
    level_sizes = [H_ * W_ for H_, W_ in value_spatial_shapes]
    value_list = value.split(level_sizes, dim=1)
    # grid_sample expects coordinates in [-1, 1].
    sampling_grids = 2 * sampling_locations - 1
    sampled_per_level = []
    for level, (H_, W_) in enumerate(value_spatial_shapes):
        # (bs, H*W, heads, dim) -> (bs*heads, dim, H, W) for grid_sample.
        value_l_ = (value_list[level]
                    .flatten(2)
                    .transpose(1, 2)
                    .reshape(bs * num_heads, embed_dims, H_, W_))
        # (bs, q, heads, pts, 2) -> (bs*heads, q, pts, 2)
        grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # -> (bs*heads, dim, q, pts)
        sampled_per_level.append(
            F.grid_sample(
                value_l_,
                grid_l_,
                mode='bilinear',
                padding_mode='zeros',
                align_corners=False))
    # Weights: (bs, q, heads, lvls, pts) -> (bs*heads, 1, q, lvls*pts)
    attention_weights = attention_weights.transpose(1, 2).reshape(
        bs * num_heads, 1, num_queries, num_levels * num_points)
    # Weighted sum over all (level, point) samples, then back to
    # (bs, num_queries, heads*dim).
    output = (torch.stack(sampled_per_level, dim=-2).flatten(-2) *
              attention_weights).sum(-1).view(
                  bs, num_heads * embed_dims, num_queries)
    return output.transpose(1, 2).contiguous()
@ATTENTION.register_module()
class MultiScaleDeformableAttentionFP32(BaseModule):
    """An attention module used in Deformable-Detr. `Deformable DETR:
    Deformable Transformers for End-to-End Object Detection.
    <https://arxiv.org/pdf/2010.04159.pdf>`_.

    Args:
        embed_dims (int): The embedding dimension of Attention.
            Default: 256.
        num_heads (int): Parallel attention heads. Default: 8.
        num_levels (int): The number of feature map used in
            Attention. Default: 4.
        num_points (int): The number of sampling points for
            each query in each head. Default: 4.
        im2col_step (int): The step used in image_to_column.
            Default: 64.
        dropout (float): A Dropout layer on `inp_identity`.
            Default: 0.1.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to False.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """
    def __init__(self,
                 embed_dims=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=4,
                 im2col_step=64,
                 dropout=0.1,
                 batch_first=False,
                 norm_cfg=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        if embed_dims % num_heads != 0:
            raise ValueError(f'embed_dims must be divisible by num_heads, '
                             f'but got {embed_dims} and {num_heads}')
        dim_per_head = embed_dims // num_heads
        self.norm_cfg = norm_cfg
        self.dropout = nn.Dropout(dropout)
        self.batch_first = batch_first
        # you'd better set dim_per_head to a power of 2
        # which is more efficient in the CUDA implementation
        def _is_power_of_2(n):
            if (not isinstance(n, int)) or (n < 0):
                raise ValueError(
                    'invalid input for _is_power_of_2: {} (type: {})'.format(
                        n, type(n)))
            return (n & (n - 1) == 0) and n != 0
        if not _is_power_of_2(dim_per_head):
            warnings.warn(
                "You'd better set embed_dims in "
                'MultiScaleDeformAttention to make '
                'the dimension of each attention head a power of 2 '
                'which is more efficient in our CUDA implementation.')
        self.im2col_step = im2col_step
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_heads = num_heads
        self.num_points = num_points
        # Per-query sampling offsets and weights for each (head, level, point).
        self.sampling_offsets = nn.Linear(
            embed_dims, num_heads * num_levels * num_points * 2)
        self.attention_weights = nn.Linear(embed_dims,
                                           num_heads * num_levels * num_points)
        self.value_proj = nn.Linear(embed_dims, embed_dims)
        self.output_proj = nn.Linear(embed_dims, embed_dims)
        self.init_weights()
    def init_weights(self):
        """Default initialization for Parameters of Module."""
        constant_init(self.sampling_offsets, 0.)
        # Bias the initial offsets toward a ring of points, one direction
        # per head, growing with the point index.
        thetas = torch.arange(
            self.num_heads,
            dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init /
                     grid_init.abs().max(-1, keepdim=True)[0]).view(
                         self.num_heads, 1, 1,
                         2).repeat(1, self.num_levels, self.num_points, 1)
        for i in range(self.num_points):
            grid_init[:, :, i, :] *= i + 1
        self.sampling_offsets.bias.data = grid_init.view(-1)
        constant_init(self.attention_weights, val=0., bias=0.)
        xavier_init(self.value_proj, distribution='uniform', bias=0.)
        xavier_init(self.output_proj, distribution='uniform', bias=0.)
        self._is_init = True
    @deprecated_api_warning({'residual': 'identity'},
                            cls_name='MultiScaleDeformableAttention')
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
        """Forward Function of MultiScaleDeformAttention.

        Args:
            query (Tensor): Query of Transformer with shape
                (num_query, bs, embed_dims).
            key (Tensor): The key tensor with shape
                `(num_key, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_key, bs, embed_dims)`.
            identity (Tensor): The tensor used for addition, with the
                same shape as `query`. Default None. If None,
                `query` will be used.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`. Default
                None.
            reference_points (Tensor): The normalized reference
                points with shape (bs, num_query, num_levels, 2),
                all elements is range in [0, 1], top-left (0,0),
                bottom-right (1, 1), including padding area.
                or (N, Length_{query}, num_levels, 4), add
                additional two dimensions is (w, h) to
                form reference boxes.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_key].
            spatial_shapes (Tensor): Spatial shape of features in
                different levels. With shape (num_levels, 2),
                last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape ``(num_levels, )`` and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        if value is None:
            value = query
        if identity is None:
            identity = query
        if query_pos is not None:
            query = query + query_pos
        if not self.batch_first:
            # change to (bs, num_query ,embed_dims)
            query = query.permute(1, 0, 2)
            value = value.permute(1, 0, 2)
        bs, num_query, _ = query.shape
        bs, num_value, _ = value.shape
        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
        value = self.value_proj(value)
        if key_padding_mask is not None:
            # Zero out padded keys so they contribute nothing when sampled.
            value = value.masked_fill(key_padding_mask[..., None], 0.0)
        value = value.view(bs, num_value, self.num_heads, -1)
        sampling_offsets = self.sampling_offsets(query).view(
            bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
        attention_weights = self.attention_weights(query).view(
            bs, num_query, self.num_heads, self.num_levels * self.num_points)
        # Softmax jointly over all (level, point) samples per head.
        attention_weights = attention_weights.softmax(-1)
        attention_weights = attention_weights.view(bs, num_query,
                                                   self.num_heads,
                                                   self.num_levels,
                                                   self.num_points)
        if reference_points.shape[-1] == 2:
            # Offsets are in pixels; normalize per level as (1/W, 1/H).
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                + sampling_offsets \
                / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            # Box-style reference: scale the offsets by the box size.
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                + sampling_offsets / self.num_points \
                * reference_points[:, :, None, :, None, 2:] \
                * 0.5
        else:
            raise ValueError(
                f'Last dim of reference_points must be'
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')
        # FIX: only take the CUDA kernel path when the tensors actually live
        # on the GPU; `torch.cuda.is_available()` alone would hand CPU
        # tensors to the CUDA extension.
        if torch.cuda.is_available() and value.is_cuda:
            output = MultiScaleDeformableAttnFunctionFp32.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
        else:
            # FIX: the pure-PyTorch fallback takes only (value, spatial_shapes,
            # sampling_locations, attention_weights); the original call also
            # passed level_start_index and im2col_step, which mis-bound the
            # positional arguments.
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)
        output = self.output_proj(output)
        if not self.batch_first:
            # (num_query, bs ,embed_dims)
            output = output.permute(1, 0, 2)
        # FIX: the original forward ended without a return statement, so it
        # returned None; restore the dropout + residual connection.
        return self.dropout(output) + identity
from turtle import forward
import warnings
try:
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
except ImportError:
warnings.warn(
'`MultiScaleDeformableAttention` in MMCV has been moved to '
'`mmcv.ops.multi_scale_deform_attn`, please update your MMCV')
from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention
from mmcv.runner import force_fp32, auto_fp16
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.cnn.bricks.transformer import build_attention
import math
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.function import Function, once_differentiable
from mmcv import deprecated_api_warning
from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.runner import BaseModule
from mmcv.utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
from torch.cuda.amp import custom_bwd, custom_fwd
@ATTENTION.register_module()
class MultiScaleDeformableAttentionFp16(BaseModule):
    """Precision guard around a configurable deformable-attention module.

    Every tensor argument listed in ``apply_to`` is cast to fp32 by the
    ``force_fp32`` decorator before the wrapped attention runs, so the
    attention math stays in full precision under mixed-precision training.
    """
    def __init__(self, attn_cfg=None, init_cfg=None, **kwarg):
        super(MultiScaleDeformableAttentionFp16, self).__init__(init_cfg)
        # Build the inner attention from config and initialize it up front.
        inner = build_attention(attn_cfg)
        inner.init_weights()
        self.deformable_attention = inner
        self.fp16_enabled = False
    @force_fp32(apply_to=('query', 'key', 'value', 'query_pos', 'reference_points','identity'))
    def forward(self, query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
        """Hand the (now fp32) tensors straight to the wrapped attention."""
        forwarded = dict(
            key=key,
            value=value,
            identity=identity,
            query_pos=query_pos,
            key_padding_mask=key_padding_mask,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index)
        forwarded.update(kwargs)
        return self.deformable_attention(query, **forwarded)
class MultiScaleDeformableAttnFunctionFp32(Function):
    """Autograd Function around the compiled multi-scale deformable attention
    kernels; ``custom_fwd`` casts all inputs to fp32 before the kernel runs."""
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights, im2col_step):
        """GPU version of multi-scale deformable attention.

        Args:
            value (Tensor): The value has shape
                (bs, num_keys, mum_heads, embed_dims//num_heads)
            value_spatial_shapes (Tensor): Spatial shape of
                each feature map, has shape (num_levels, 2),
                last dimension 2 represent (h, w)
            sampling_locations (Tensor): The location of sampling points,
                has shape
                (bs ,num_queries, num_heads, num_levels, num_points, 2),
                the last dimension 2 represent (x, y).
            attention_weights (Tensor): The weight of sampling points used
                when calculate the attention, has shape
                (bs ,num_queries, num_heads, num_levels, num_points),
            im2col_step (Tensor): The step used in image to column.
        Returns:
            Tensor: has shape (bs, num_queries, embed_dims)
        """
        # Remember the step size for the backward kernel call.
        ctx.im2col_step = im2col_step
        output = ext_module.ms_deform_attn_forward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            im2col_step=ctx.im2col_step)
        # Order here must mirror the unpacking in backward().
        ctx.save_for_backward(value, value_spatial_shapes,
                              value_level_start_index, sampling_locations,
                              attention_weights)
        return output
    @staticmethod
    @once_differentiable
    @custom_bwd
    def backward(ctx, grad_output):
        """GPU version of backward function.

        Args:
            grad_output (Tensor): Gradient
                of output tensor of forward.
        Returns:
            Tuple[Tensor]: Gradient
                of input tensors in forward.
        """
        value, value_spatial_shapes, value_level_start_index,\
            sampling_locations, attention_weights = ctx.saved_tensors
        # Output buffers the extension kernel fills in place.
        grad_value = torch.zeros_like(value)
        grad_sampling_loc = torch.zeros_like(sampling_locations)
        grad_attn_weight = torch.zeros_like(attention_weights)
        ext_module.ms_deform_attn_backward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            grad_output.contiguous(),
            grad_value,
            grad_sampling_loc,
            grad_attn_weight,
            im2col_step=ctx.im2col_step)
        # None gradients for the non-tensor forward inputs.
        return grad_value, None, None, \
            grad_sampling_loc, grad_attn_weight, None
def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
                                        sampling_locations, attention_weights):
    """CPU version of multi-scale deformable attention.

    Samples each pyramid level bilinearly at the predicted locations and
    combines the samples with the given attention weights — a pure-PyTorch
    equivalent of the CUDA kernel.

    Args:
        value (Tensor): The value has shape
            (bs, num_keys, mum_heads, embed_dims//num_heads)
        value_spatial_shapes (Tensor): Spatial shape of
            each feature map, has shape (num_levels, 2),
            last dimension 2 represent (h, w)
        sampling_locations (Tensor): The location of sampling points,
            has shape
            (bs ,num_queries, num_heads, num_levels, num_points, 2),
            the last dimension 2 represent (x, y).
        attention_weights (Tensor): The weight of sampling points used
            when calculate the attention, has shape
            (bs ,num_queries, num_heads, num_levels, num_points),
    Returns:
        Tensor: has shape (bs, num_queries, embed_dims)
    """
    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ = \
        sampling_locations.shape
    # One chunk of keys per pyramid level.
    per_level_values = value.split(
        [H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
    # Map normalized [0, 1] locations to grid_sample's [-1, 1] range.
    sampling_grids = 2 * sampling_locations - 1
    samples = []
    for level, (H_, W_) in enumerate(value_spatial_shapes):
        # (bs, H*W, heads, dim) -> (bs*heads, dim, H, W)
        level_value = (per_level_values[level]
                       .flatten(2)
                       .transpose(1, 2)
                       .reshape(bs * num_heads, embed_dims, H_, W_))
        # (bs, q, heads, pts, 2) -> (bs*heads, q, pts, 2)
        level_grid = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # -> (bs*heads, dim, q, pts)
        samples.append(
            F.grid_sample(
                level_value,
                level_grid,
                mode='bilinear',
                padding_mode='zeros',
                align_corners=False))
    # (bs, q, heads, lvls, pts) -> (bs*heads, 1, q, lvls*pts)
    attention_weights = attention_weights.transpose(1, 2).reshape(
        bs * num_heads, 1, num_queries, num_levels * num_points)
    # Weighted sum across every (level, point) sample.
    merged = (torch.stack(samples, dim=-2).flatten(-2) *
              attention_weights).sum(-1).view(
                  bs, num_heads * embed_dims, num_queries)
    return merged.transpose(1, 2).contiguous()
@ATTENTION.register_module()
class MultiScaleDeformableAttentionFP32(BaseModule):
    """An attention module used in Deformable-Detr. `Deformable DETR:
    Deformable Transformers for End-to-End Object Detection.
    <https://arxiv.org/pdf/2010.04159.pdf>`_.

    Args:
        embed_dims (int): The embedding dimension of Attention.
            Default: 256.
        num_heads (int): Parallel attention heads. Default: 8.
        num_levels (int): The number of feature map used in
            Attention. Default: 4.
        num_points (int): The number of sampling points for
            each query in each head. Default: 4.
        im2col_step (int): The step used in image_to_column.
            Default: 64.
        dropout (float): A Dropout layer on `inp_identity`.
            Default: 0.1.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to False.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """
    def __init__(self,
                 embed_dims=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=4,
                 im2col_step=64,
                 dropout=0.1,
                 batch_first=False,
                 norm_cfg=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        if embed_dims % num_heads != 0:
            raise ValueError(f'embed_dims must be divisible by num_heads, '
                             f'but got {embed_dims} and {num_heads}')
        dim_per_head = embed_dims // num_heads
        self.norm_cfg = norm_cfg
        self.dropout = nn.Dropout(dropout)
        self.batch_first = batch_first
        # you'd better set dim_per_head to a power of 2
        # which is more efficient in the CUDA implementation
        def _is_power_of_2(n):
            if (not isinstance(n, int)) or (n < 0):
                raise ValueError(
                    'invalid input for _is_power_of_2: {} (type: {})'.format(
                        n, type(n)))
            return (n & (n - 1) == 0) and n != 0
        if not _is_power_of_2(dim_per_head):
            warnings.warn(
                "You'd better set embed_dims in "
                'MultiScaleDeformAttention to make '
                'the dimension of each attention head a power of 2 '
                'which is more efficient in our CUDA implementation.')
        self.im2col_step = im2col_step
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_heads = num_heads
        self.num_points = num_points
        # Per-query sampling offsets and weights for each (head, level, point).
        self.sampling_offsets = nn.Linear(
            embed_dims, num_heads * num_levels * num_points * 2)
        self.attention_weights = nn.Linear(embed_dims,
                                           num_heads * num_levels * num_points)
        self.value_proj = nn.Linear(embed_dims, embed_dims)
        self.output_proj = nn.Linear(embed_dims, embed_dims)
        self.init_weights()
    def init_weights(self):
        """Default initialization for Parameters of Module."""
        constant_init(self.sampling_offsets, 0.)
        # Bias the initial offsets toward a ring of points, one direction
        # per head, growing with the point index.
        thetas = torch.arange(
            self.num_heads,
            dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init /
                     grid_init.abs().max(-1, keepdim=True)[0]).view(
                         self.num_heads, 1, 1,
                         2).repeat(1, self.num_levels, self.num_points, 1)
        for i in range(self.num_points):
            grid_init[:, :, i, :] *= i + 1
        self.sampling_offsets.bias.data = grid_init.view(-1)
        constant_init(self.attention_weights, val=0., bias=0.)
        xavier_init(self.value_proj, distribution='uniform', bias=0.)
        xavier_init(self.output_proj, distribution='uniform', bias=0.)
        self._is_init = True
    @deprecated_api_warning({'residual': 'identity'},
                            cls_name='MultiScaleDeformableAttention')
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
        """Forward Function of MultiScaleDeformAttention.

        Args:
            query (Tensor): Query of Transformer with shape
                (num_query, bs, embed_dims).
            key (Tensor): The key tensor with shape
                `(num_key, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_key, bs, embed_dims)`.
            identity (Tensor): The tensor used for addition, with the
                same shape as `query`. Default None. If None,
                `query` will be used.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`. Default
                None.
            reference_points (Tensor): The normalized reference
                points with shape (bs, num_query, num_levels, 2),
                all elements is range in [0, 1], top-left (0,0),
                bottom-right (1, 1), including padding area.
                or (N, Length_{query}, num_levels, 4), add
                additional two dimensions is (w, h) to
                form reference boxes.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_key].
            spatial_shapes (Tensor): Spatial shape of features in
                different levels. With shape (num_levels, 2),
                last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape ``(num_levels, )`` and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        if value is None:
            value = query
        if identity is None:
            identity = query
        if query_pos is not None:
            query = query + query_pos
        if not self.batch_first:
            # change to (bs, num_query ,embed_dims)
            query = query.permute(1, 0, 2)
            value = value.permute(1, 0, 2)
        bs, num_query, _ = query.shape
        bs, num_value, _ = value.shape
        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
        value = self.value_proj(value)
        if key_padding_mask is not None:
            # Zero out padded keys so they contribute nothing when sampled.
            value = value.masked_fill(key_padding_mask[..., None], 0.0)
        value = value.view(bs, num_value, self.num_heads, -1)
        sampling_offsets = self.sampling_offsets(query).view(
            bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
        attention_weights = self.attention_weights(query).view(
            bs, num_query, self.num_heads, self.num_levels * self.num_points)
        # Softmax jointly over all (level, point) samples per head.
        attention_weights = attention_weights.softmax(-1)
        attention_weights = attention_weights.view(bs, num_query,
                                                   self.num_heads,
                                                   self.num_levels,
                                                   self.num_points)
        if reference_points.shape[-1] == 2:
            # Offsets are in pixels; normalize per level as (1/W, 1/H).
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                + sampling_offsets \
                / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            # Box-style reference: scale the offsets by the box size.
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                + sampling_offsets / self.num_points \
                * reference_points[:, :, None, :, None, 2:] \
                * 0.5
        else:
            raise ValueError(
                f'Last dim of reference_points must be'
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')
        # FIX: only take the CUDA kernel path when the tensors actually live
        # on the GPU; `torch.cuda.is_available()` alone would hand CPU
        # tensors to the CUDA extension.
        if torch.cuda.is_available() and value.is_cuda:
            output = MultiScaleDeformableAttnFunctionFp32.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
        else:
            # FIX: the pure-PyTorch fallback takes only (value, spatial_shapes,
            # sampling_locations, attention_weights); the original call also
            # passed level_start_index and im2col_step, which mis-bound the
            # positional arguments.
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)
        output = self.output_proj(output)
        if not self.batch_first:
            # (num_query, bs ,embed_dims)
            output = output.permute(1, 0, 2)
        return self.dropout(output) + identity
\ No newline at end of file
#!/usr/bin/env bash
# Distributed testing launcher.
# Usage: [PORT=29500] ./dist_test.sh CONFIG CHECKPOINT GPUS [extra test.py args...]
CONFIG=$1
CHECKPOINT=$2
GPUS=$3
PORT=${PORT:-29500}

# FIX: quote "$0", "$CONFIG", "$CHECKPOINT" and "${@:4}" so paths and
# arguments containing spaces survive word splitting.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node="$GPUS" --master_port="$PORT" \
    "$(dirname "$0")"/test.py "$CONFIG" "$CHECKPOINT" --launcher pytorch "${@:4}"
#!/usr/bin/env bash
# Distributed testing launcher.
# Usage: [PORT=29500] ./dist_test.sh CONFIG CHECKPOINT GPUS [extra test.py args...]
CONFIG=$1
CHECKPOINT=$2
GPUS=$3
PORT=${PORT:-29500}

# FIX: quote "$0", "$CONFIG", "$CHECKPOINT" and "${@:4}" so paths and
# arguments containing spaces survive word splitting.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node="$GPUS" --master_port="$PORT" \
    "$(dirname "$0")"/test.py "$CONFIG" "$CHECKPOINT" --launcher pytorch "${@:4}"
#!/usr/bin/env bash
# Distributed training launcher.
# Usage: [PORT=29500] ./dist_train.sh CONFIG GPUS [extra train.py args...]
CONFIG=$1
GPUS=$2
PORT=${PORT:-29500}

# FIX: quote "$0", "$CONFIG" and "${@:3}" so paths and arguments
# containing spaces survive word splitting.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node="$GPUS" --master_port="$PORT" \
    "$(dirname "$0")"/train.py "$CONFIG" --launcher pytorch "${@:3}"
#!/usr/bin/env bash
# Distributed training launcher.
# Usage: [PORT=29500] ./dist_train.sh CONFIG GPUS [extra train.py args...]
CONFIG=$1
GPUS=$2
PORT=${PORT:-29500}

# FIX: quote "$0", "$CONFIG" and "${@:3}" so paths and arguments
# containing spaces survive word splitting.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node="$GPUS" --master_port="$PORT" \
    "$(dirname "$0")"/train.py "$CONFIG" --launcher pytorch "${@:3}"
import sys
import os
sys.path.append(os.path.abspath('.'))
from src.datasets.evaluation.vector_eval import VectorEvaluate
import argparse
def parse_args():
    """Parse CLI arguments for the standalone evaluation script.

    Returns:
        argparse.Namespace: Holds ``submission`` (path to the prediction
        file) and ``gt`` (path to the annotation file).
    """
    parser = argparse.ArgumentParser(
        description='Evaluate a submission file')
    parser.add_argument(
        'submission',
        help='submission file in pickle or json format to be evaluated')
    parser.add_argument('gt', help='gt annotation file')
    return parser.parse_args()
def main(args):
    """Evaluate ``args.submission`` against the ground truth ``args.gt``."""
    # n_workers=0: run the evaluation in the current process.
    evaluator = VectorEvaluate(args.gt, n_workers=0)
    print(evaluator.evaluate(args.submission))
# CLI entry point: parse arguments and run the evaluation.
if __name__ == '__main__':
    args = parse_args()
    main(args)
import sys
import os
sys.path.append(os.path.abspath('.'))
from src.datasets.evaluation.vector_eval import VectorEvaluate
import argparse
def parse_args():
    """Parse CLI arguments for the standalone evaluation script.

    Returns:
        argparse.Namespace: Holds ``submission`` (path to the prediction
        file) and ``gt`` (path to the annotation file).
    """
    parser = argparse.ArgumentParser(
        description='Evaluate a submission file')
    parser.add_argument(
        'submission',
        help='submission file in pickle or json format to be evaluated')
    parser.add_argument('gt', help='gt annotation file')
    return parser.parse_args()
def main(args):
    """Evaluate ``args.submission`` against the ground truth ``args.gt``."""
    # n_workers=0: run the evaluation in the current process.
    evaluator = VectorEvaluate(args.gt, n_workers=0)
    print(evaluator.evaluate(args.submission))
# CLI entry point: parse arguments and run the evaluation.
if __name__ == '__main__':
    args = parse_args()
    main(args)
import os.path as osp
import pickle
import shutil
import tempfile
import time
import mmcv
import torch
import torch.distributed as dist
from mmcv.image import tensor2imgs
from mmcv.runner import get_dist_info
from mmdet.core import encode_mask_results
def single_gpu_test(model,
                    data_loader,
                    show=False,
                    out_dir=None,
                    show_score_thr=0.3):
    """Test a model on a single GPU, optionally visualizing predictions.

    Args:
        model (nn.Module): Wrapped model to be tested (must expose
            ``model.module.show_result`` for visualization).
        data_loader (DataLoader): Pytorch data loader.
        show (bool): Whether to display results on screen.
        out_dir (str | None): Directory to save visualizations to.
        show_score_thr (float): Score threshold passed to ``show_result``.

    Returns:
        list: The prediction results, one entry per sample.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)
        batch_size = len(result)
        if show or out_dir:
            if batch_size == 1 and isinstance(data['img'][0], torch.Tensor):
                img_tensor = data['img'][0]
            else:
                img_tensor = data['img'][0].data[0]
            img_metas = data['img_metas'][0].data[0]
            imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
            assert len(imgs) == len(img_metas)
            # BUGFIX: use a distinct index ``j`` here; the original reused
            # ``i`` and silently shadowed the outer dataloader loop variable.
            for j, (img, img_meta) in enumerate(zip(imgs, img_metas)):
                h, w, _ = img_meta['img_shape']
                img_show = img[:h, :w, :]
                ori_h, ori_w = img_meta['ori_shape'][:-1]
                img_show = mmcv.imresize(img_show, (ori_w, ori_h))
                if out_dir:
                    out_file = osp.join(out_dir, img_meta['ori_filename'])
                else:
                    out_file = None
                model.module.show_result(
                    img_show,
                    result[j],
                    show=show,
                    out_file=out_file,
                    score_thr=show_score_thr)
        # encode mask results (instance-segmentation outputs come as tuples)
        if isinstance(result[0], tuple):
            result = [(bbox_results, encode_mask_results(mask_results))
                      for bbox_results, mask_results in result]
        results.extend(result)
        for _ in range(batch_size):
            prog_bar.update()
    return results
def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
    """Test model with multiple gpus.
    This method tests model with multiple gpus and collects the results
    under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
    it encodes results to gpu tensors and use gpu communication for results
    collection. On cpu mode it saves the results on different gpus to 'tmpdir'
    and collects them by the rank 0 worker.
    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): Pytorch data loader.
        tmpdir (str): Path of directory to save the temporary results from
            different gpus under cpu mode.
        gpu_collect (bool): Option to use either gpu or cpu to collect results.
    Returns:
        list: The prediction results.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    rank, world_size = get_dist_info()
    # only rank 0 drives the progress bar
    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))
    time.sleep(2)  # This line can prevent deadlock problem in some cases.
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)
        # encode mask results
        # if isinstance(result[0], tuple):
        #     result = [(bbox_results, encode_mask_results(mask_results))
        #               for bbox_results, mask_results in result]
        results.extend(result)
        if rank == 0:
            # advance by batch * world_size, assuming every rank processes
            # an equally sized batch per step
            batch_size = len(result)
            for _ in range(batch_size * world_size):
                prog_bar.update()
    # collect results from all ranks
    if gpu_collect:
        results = collect_results_gpu(results, len(dataset))
    else:
        results = collect_results_cpu(results, len(dataset), tmpdir)
    return results
def collect_results_cpu(result_part, size, tmpdir=None):
    """Collect per-rank results via a shared temp directory on disk.

    Each rank dumps its partial results to ``tmpdir/part_{rank}.pkl``;
    rank 0 then loads every part, interleaves them back into dataset order
    and removes the directory.

    Args:
        result_part (list): Results produced by this rank.
        size (int): Total dataset size, used to drop dataloader padding.
        tmpdir (str | None): Directory for the temporary part files. When
            None, rank 0 creates one under ``.dist_test`` and broadcasts
            its path to all ranks.

    Returns:
        list | None: Ordered results on rank 0; None on all other ranks.
    """
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
        MAX_LEN = 512
        # 32 is whitespace
        # the directory path is broadcast as a fixed-size uint8 tensor,
        # padded with spaces that are stripped after decoding
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            mmcv.mkdir_or_exist('.dist_test')
            tmpdir = tempfile.mkdtemp(dir='.dist_test')
            tmpdir = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
    # wait until every rank has written its part
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            part_file = osp.join(tmpdir, f'part_{i}.pkl')
            part_list.append(mmcv.load(part_file))
        # sort the results: round-robin interleave across ranks to restore
        # the original sampler order
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)
        return ordered_results
def collect_results_gpu(result_part, size):
    """Collect per-rank results with GPU-side ``all_gather``.

    Each rank pickles its partial results into a uint8 CUDA tensor, pads
    to the longest payload across ranks, and all-gathers shapes and
    payloads. Rank 0 unpickles every part, interleaves them back into
    dataset order and truncates dataloader padding.

    Args:
        result_part (list): Results produced by this rank.
        size (int): Total dataset size, used to drop dataloader padding.

    Returns:
        list | None: Ordered results on rank 0; None on all other ranks.
    """
    rank, world_size = get_dist_info()
    # dump result part to tensor with pickle
    part_tensor = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
    # gather all result part tensor shape
    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
    shape_list = [shape_tensor.clone() for _ in range(world_size)]
    dist.all_gather(shape_list, shape_tensor)
    # padding result part tensor to max length (all_gather requires equal sizes)
    shape_max = torch.tensor(shape_list).max()
    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
    part_send[:shape_tensor[0]] = part_tensor
    part_recv_list = [
        part_tensor.new_zeros(shape_max) for _ in range(world_size)
    ]
    # gather all result part
    dist.all_gather(part_recv_list, part_send)
    if rank == 0:
        part_list = []
        for recv, shape in zip(part_recv_list, shape_list):
            # slice off the padding before unpickling
            part_list.append(
                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
        # sort the results: round-robin interleave across ranks to restore
        # the original sampler order
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        return ordered_results
import os.path as osp
import pickle
import shutil
import tempfile
import time
import mmcv
import torch
import torch.distributed as dist
from mmcv.image import tensor2imgs
from mmcv.runner import get_dist_info
from mmdet.core import encode_mask_results
def single_gpu_test(model,
                    data_loader,
                    show=False,
                    out_dir=None,
                    show_score_thr=0.3):
    """Test a model on a single GPU, optionally visualizing predictions.

    Args:
        model (nn.Module): Wrapped model to be tested (must expose
            ``model.module.show_result`` for visualization).
        data_loader (DataLoader): Pytorch data loader.
        show (bool): Whether to display results on screen.
        out_dir (str | None): Directory to save visualizations to.
        show_score_thr (float): Score threshold passed to ``show_result``.

    Returns:
        list: The prediction results, one entry per sample.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)
        batch_size = len(result)
        if show or out_dir:
            if batch_size == 1 and isinstance(data['img'][0], torch.Tensor):
                img_tensor = data['img'][0]
            else:
                img_tensor = data['img'][0].data[0]
            img_metas = data['img_metas'][0].data[0]
            imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
            assert len(imgs) == len(img_metas)
            # BUGFIX: use a distinct index ``j`` here; the original reused
            # ``i`` and silently shadowed the outer dataloader loop variable.
            for j, (img, img_meta) in enumerate(zip(imgs, img_metas)):
                h, w, _ = img_meta['img_shape']
                img_show = img[:h, :w, :]
                ori_h, ori_w = img_meta['ori_shape'][:-1]
                img_show = mmcv.imresize(img_show, (ori_w, ori_h))
                if out_dir:
                    out_file = osp.join(out_dir, img_meta['ori_filename'])
                else:
                    out_file = None
                model.module.show_result(
                    img_show,
                    result[j],
                    show=show,
                    out_file=out_file,
                    score_thr=show_score_thr)
        # encode mask results (instance-segmentation outputs come as tuples)
        if isinstance(result[0], tuple):
            result = [(bbox_results, encode_mask_results(mask_results))
                      for bbox_results, mask_results in result]
        results.extend(result)
        for _ in range(batch_size):
            prog_bar.update()
    return results
def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
    """Test model with multiple gpus.
    This method tests model with multiple gpus and collects the results
    under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
    it encodes results to gpu tensors and use gpu communication for results
    collection. On cpu mode it saves the results on different gpus to 'tmpdir'
    and collects them by the rank 0 worker.
    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): Pytorch data loader.
        tmpdir (str): Path of directory to save the temporary results from
            different gpus under cpu mode.
        gpu_collect (bool): Option to use either gpu or cpu to collect results.
    Returns:
        list: The prediction results.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    rank, world_size = get_dist_info()
    # only rank 0 drives the progress bar
    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))
    time.sleep(2)  # This line can prevent deadlock problem in some cases.
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)
        # encode mask results
        # if isinstance(result[0], tuple):
        #     result = [(bbox_results, encode_mask_results(mask_results))
        #               for bbox_results, mask_results in result]
        results.extend(result)
        if rank == 0:
            # advance by batch * world_size, assuming every rank processes
            # an equally sized batch per step
            batch_size = len(result)
            for _ in range(batch_size * world_size):
                prog_bar.update()
    # collect results from all ranks
    if gpu_collect:
        results = collect_results_gpu(results, len(dataset))
    else:
        results = collect_results_cpu(results, len(dataset), tmpdir)
    return results
def collect_results_cpu(result_part, size, tmpdir=None):
    """Collect per-rank results via a shared temp directory on disk.

    Each rank dumps its partial results to ``tmpdir/part_{rank}.pkl``;
    rank 0 then loads every part, interleaves them back into dataset order
    and removes the directory.

    Args:
        result_part (list): Results produced by this rank.
        size (int): Total dataset size, used to drop dataloader padding.
        tmpdir (str | None): Directory for the temporary part files. When
            None, rank 0 creates one under ``.dist_test`` and broadcasts
            its path to all ranks.

    Returns:
        list | None: Ordered results on rank 0; None on all other ranks.
    """
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
        MAX_LEN = 512
        # 32 is whitespace
        # the directory path is broadcast as a fixed-size uint8 tensor,
        # padded with spaces that are stripped after decoding
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            mmcv.mkdir_or_exist('.dist_test')
            tmpdir = tempfile.mkdtemp(dir='.dist_test')
            tmpdir = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
    # wait until every rank has written its part
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            part_file = osp.join(tmpdir, f'part_{i}.pkl')
            part_list.append(mmcv.load(part_file))
        # sort the results: round-robin interleave across ranks to restore
        # the original sampler order
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)
        return ordered_results
def collect_results_gpu(result_part, size):
    """Collect per-rank results with GPU-side ``all_gather``.

    Each rank pickles its partial results into a uint8 CUDA tensor, pads
    to the longest payload across ranks, and all-gathers shapes and
    payloads. Rank 0 unpickles every part, interleaves them back into
    dataset order and truncates dataloader padding.

    Args:
        result_part (list): Results produced by this rank.
        size (int): Total dataset size, used to drop dataloader padding.

    Returns:
        list | None: Ordered results on rank 0; None on all other ranks.
    """
    rank, world_size = get_dist_info()
    # dump result part to tensor with pickle
    part_tensor = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
    # gather all result part tensor shape
    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
    shape_list = [shape_tensor.clone() for _ in range(world_size)]
    dist.all_gather(shape_list, shape_tensor)
    # padding result part tensor to max length (all_gather requires equal sizes)
    shape_max = torch.tensor(shape_list).max()
    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
    part_send[:shape_tensor[0]] = part_tensor
    part_recv_list = [
        part_tensor.new_zeros(shape_max) for _ in range(world_size)
    ]
    # gather all result part
    dist.all_gather(part_recv_list, part_send)
    if rank == 0:
        part_list = []
        for recv, shape in zip(part_recv_list, shape_list):
            # slice off the padding before unpickling
            part_list.append(
                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
        # sort the results: round-robin interleave across ranks to restore
        # the original sampler order
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        return ordered_results
import random
import warnings
import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
Fp16OptimizerHook, OptimizerHook, build_optimizer,
build_runner)
from mmcv.utils import build_from_cfg
from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import (build_dataloader, build_dataset,
replace_ImageToTensor)
from mmdet.utils import get_root_logger
def set_random_seed(seed, deterministic=False):
    """Seed every RNG used during training.

    Seeds Python's ``random``, NumPy and PyTorch (all CUDA devices
    included). When ``deterministic`` is True, cuDNN is additionally
    switched to deterministic kernels and its auto-tuner (benchmark
    mode) is disabled.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to force deterministic cuDNN
            behaviour. Default: False.
    """
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    """Build dataloaders, optimizer and runner from ``cfg``, then train.

    Args:
        model (nn.Module): Model to train.
        dataset (Dataset | list[Dataset]): Training dataset(s); one
            dataloader is built per dataset and run via ``cfg.workflow``.
        cfg (mmcv.Config): Full training config (data, optimizer, runner,
            hooks, work_dir, resume/load paths, ...).
        distributed (bool): Wrap the model in MMDistributedDataParallel
            when True, otherwise MMDataParallel.
        validate (bool): Whether to register an evaluation hook.
        timestamp (str | None): Timestamp used to name the log files.
        meta (dict | None): Extra meta info forwarded to the runner.
    """
    logger = get_root_logger(cfg.log_level)
    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    # backward compat: honor the deprecated "imgs_per_gpu" key
    if 'imgs_per_gpu' in cfg.data:
        logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
                       'Please use "samples_per_gpu" instead')
        if 'samples_per_gpu' in cfg.data:
            logger.warning(
                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
                f'={cfg.data.imgs_per_gpu} is used in this experiments')
        else:
            logger.warning(
                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
                f'{cfg.data.imgs_per_gpu} in this experiments')
        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed) for ds in dataset
    ]
    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    # backward compat: synthesize a runner section from "total_epochs"
    if 'runner' not in cfg:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)
    else:
        if 'total_epochs' in cfg:
            assert cfg.total_epochs == cfg.runner.max_epochs
    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp
    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config
    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())
    # register eval hooks
    if validate:
        # Support batch_size > 1 in validation
        val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
        if val_samples_per_gpu > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.val.pipeline = replace_ImageToTensor(
                cfg.data.val.pipeline)
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=val_samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)
    # resume takes precedence over load
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
import random
import warnings
import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
Fp16OptimizerHook, OptimizerHook, build_optimizer,
build_runner)
from mmcv.utils import build_from_cfg
from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import (build_dataloader, build_dataset,
replace_ImageToTensor)
from mmdet.utils import get_root_logger
def set_random_seed(seed, deterministic=False):
    """Seed every RNG used during training.

    Seeds Python's ``random``, NumPy and PyTorch (all CUDA devices
    included). When ``deterministic`` is True, cuDNN is additionally
    switched to deterministic kernels and its auto-tuner (benchmark
    mode) is disabled.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to force deterministic cuDNN
            behaviour. Default: False.
    """
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    """Build dataloaders, optimizer and runner from ``cfg``, then train.

    Args:
        model (nn.Module): Model to train.
        dataset (Dataset | list[Dataset]): Training dataset(s); one
            dataloader is built per dataset and run via ``cfg.workflow``.
        cfg (mmcv.Config): Full training config (data, optimizer, runner,
            hooks, work_dir, resume/load paths, ...).
        distributed (bool): Wrap the model in MMDistributedDataParallel
            when True, otherwise MMDataParallel.
        validate (bool): Whether to register an evaluation hook.
        timestamp (str | None): Timestamp used to name the log files.
        meta (dict | None): Extra meta info forwarded to the runner.
    """
    logger = get_root_logger(cfg.log_level)
    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    # backward compat: honor the deprecated "imgs_per_gpu" key
    if 'imgs_per_gpu' in cfg.data:
        logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
                       'Please use "samples_per_gpu" instead')
        if 'samples_per_gpu' in cfg.data:
            logger.warning(
                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
                f'={cfg.data.imgs_per_gpu} is used in this experiments')
        else:
            logger.warning(
                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
                f'{cfg.data.imgs_per_gpu} in this experiments')
        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed) for ds in dataset
    ]
    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    # backward compat: synthesize a runner section from "total_epochs"
    if 'runner' not in cfg:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)
    else:
        if 'total_epochs' in cfg:
            assert cfg.total_epochs == cfg.runner.max_epochs
    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp
    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config
    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())
    # register eval hooks
    if validate:
        # Support batch_size > 1 in validation
        val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
        if val_samples_per_gpu > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.val.pipeline = replace_ImageToTensor(
                cfg.data.val.pipeline)
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=val_samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)
    # resume takes precedence over load
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
import argparse
import mmcv
import os
import os.path as osp
import torch
import warnings
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
wrap_fp16_model)
from mmdet3d.apis import single_gpu_test
from mmdet3d.datasets import build_dataloader, build_dataset
from mmdet3d.models import build_model
from mmdet_test import multi_gpu_test
from mmdet_train import set_random_seed
from mmdet.datasets import replace_ImageToTensor
def parse_args():
    """Parse CLI arguments for the distributed test/eval entry point.

    Also mirrors ``--local_rank`` into the ``LOCAL_RANK`` environment
    variable when the launcher did not set it itself.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', type=str, help='checkpoint file')
    parser.add_argument('--split', type=str, required=True,
                        help='which split to test on')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument('--fuse-conv-bn', action='store_true',
                        help='Whether to fuse conv and bn, this will slightly increase'
                             'the inference speed')
    parser.add_argument('--format-only', action='store_true',
                        help='Format the output results without perform evaluation. It is'
                             'useful when you want to format the result to a specific format and '
                             'submit it to the test server')
    parser.add_argument('--eval', action='store_true',
                        help='whether to run evaluation.')
    parser.add_argument('--gpu-collect', action='store_true',
                        help='whether to use gpu to collect results.')
    parser.add_argument('--tmpdir',
                        help='tmp directory used for collecting results from multiple '
                             'workers, available when gpu-collect is not specified')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--deterministic', action='store_true',
                        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch', 'slurm', 'mpi'],
                        default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    # torch.distributed.launch passes the rank via --local_rank; export it
    # only when the environment does not already define LOCAL_RANK.
    os.environ.setdefault('LOCAL_RANK', str(args.local_rank))
    return args
def main():
    """Entry point: build dataset/model from the config, run test or eval.

    Exactly one of ``--eval`` / ``--format-only`` must be given;
    evaluation is refused on the test split.
    """
    args = parse_args()
    # validate mutually exclusive CLI options up front
    if args.split not in ['val', 'test']:
        raise ValueError('Please choose "val" or "test" split for testing')
    if (args.eval and args.format_only) or (not args.eval and not args.format_only):
        raise ValueError('Please specify exactly one operation (eval/format) '
                         'with the argument "--eval" or "--format-only"')
    if args.eval and args.split == 'test':
        raise ValueError('Cannot evaluate on test set')
    cfg = Config.fromfile(args.config)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # import modules from plguin/xx, registry will be updated
    import sys
    sys.path.append(os.path.abspath('.'))
    if hasattr(cfg, 'plugin'):
        if cfg.plugin:
            import importlib
            if hasattr(cfg, 'plugin_dir'):
                def import_path(plugin_dir):
                    # turn a dir path like "a/b/" into the dotted module
                    # path "a.b" and import it to register its modules
                    _module_dir = os.path.dirname(plugin_dir)
                    _module_dir = _module_dir.split('/')
                    _module_path = _module_dir[0]
                    for m in _module_dir[1:]:
                        _module_path = _module_path + '.' + m
                    print(f'importing {_module_path}/')
                    plg_lib = importlib.import_module(_module_path)
                plugin_dirs = cfg.plugin_dir
                if not isinstance(plugin_dirs,list):
                    plugin_dirs = [plugin_dirs,]
                for plugin_dir in plugin_dirs:
                    import_path(plugin_dir)
            else:
                # import dir is the dirpath for the config file
                _module_dir = os.path.dirname(args.config)
                _module_dir = _module_dir.split('/')
                _module_path = _module_dir[0]
                for m in _module_dir[1:]:
                    _module_path = _module_path + '.' + m
                print(f'importing {_module_path}/')
                plg_lib = importlib.import_module(_module_path)
    # select the dataset config for the requested split
    cfg_data_dict = cfg.data.get(args.split)
    cfg.model.pretrained = None
    # in case the test dataset is concatenated
    samples_per_gpu = 1
    cfg_data_dict.test_mode = True
    samples_per_gpu = cfg_data_dict.pop('samples_per_gpu', 1)
    if samples_per_gpu > 1:
        # Replace 'ImageToTensor' to 'DefaultFormatBundle'
        cfg_data_dict.pipeline = replace_ImageToTensor(
            cfg_data_dict.pipeline)
    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)
    # build the dataloader
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    cfg_data_dict.work_dir = cfg.work_dir
    print('work_dir: ',cfg.work_dir)
    dataset = build_dataset(cfg_data_dict)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)
    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_model(cfg.model, test_cfg=cfg.get('test_cfg'))
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)
    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader)
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                 args.gpu_collect)
    # only rank 0 formats or evaluates the gathered predictions
    rank, _ = get_dist_info()
    if rank == 0:
        if args.format_only:
            dataset.format_results(outputs, prefix=cfg.work_dir)
        elif args.eval:
            print('start evaluation!')
            print(dataset.evaluate(outputs))
# Script entry point.
if __name__ == '__main__':
    main()
import argparse
import mmcv
import os
import os.path as osp
import torch
import warnings
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
wrap_fp16_model)
from mmdet3d.apis import single_gpu_test
from mmdet3d.datasets import build_dataloader, build_dataset
from mmdet3d.models import build_model
from mmdet_test import multi_gpu_test
from mmdet_train import set_random_seed
from mmdet.datasets import replace_ImageToTensor
def parse_args():
    """Parse CLI arguments for the distributed test/eval entry point.

    Also mirrors ``--local_rank`` into the ``LOCAL_RANK`` environment
    variable when the launcher did not set it itself.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', type=str, help='checkpoint file')
    parser.add_argument('--split', type=str, required=True,
                        help='which split to test on')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument('--fuse-conv-bn', action='store_true',
                        help='Whether to fuse conv and bn, this will slightly increase'
                             'the inference speed')
    parser.add_argument('--format-only', action='store_true',
                        help='Format the output results without perform evaluation. It is'
                             'useful when you want to format the result to a specific format and '
                             'submit it to the test server')
    parser.add_argument('--eval', action='store_true',
                        help='whether to run evaluation.')
    parser.add_argument('--gpu-collect', action='store_true',
                        help='whether to use gpu to collect results.')
    parser.add_argument('--tmpdir',
                        help='tmp directory used for collecting results from multiple '
                             'workers, available when gpu-collect is not specified')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--deterministic', action='store_true',
                        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch', 'slurm', 'mpi'],
                        default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    # torch.distributed.launch passes the rank via --local_rank; export it
    # only when the environment does not already define LOCAL_RANK.
    os.environ.setdefault('LOCAL_RANK', str(args.local_rank))
    return args
def main():
    """Entry point for testing/evaluating a model on the val or test split.

    Pipeline: parse args -> validate the (--eval | --format-only) choice ->
    load the config and import plugin modules so registries are populated ->
    build the dataloader for the requested split -> build the model and load
    the checkpoint -> run (multi-)GPU inference -> format or evaluate results
    on rank 0.

    Raises:
        ValueError: if the split is invalid, if neither/both of --eval and
            --format-only are given, or if --eval is requested on 'test'.
    """
    args = parse_args()

    if args.split not in ['val', 'test']:
        raise ValueError('Please choose "val" or "test" split for testing')
    # Exactly one of --eval / --format-only must be specified.
    if (args.eval and args.format_only) or (not args.eval and not args.format_only):
        raise ValueError('Please specify exactly one operation (eval/format) '
                         'with the argument "--eval" or "--format-only"')
    # The test split has no ground truth, so it can only be formatted.
    if args.eval and args.split == 'test':
        raise ValueError('Cannot evaluate on test set')

    cfg = Config.fromfile(args.config)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # import modules from plugin/xx so their registry entries are available
    import sys
    sys.path.append(os.path.abspath('.'))
    if hasattr(cfg, 'plugin'):
        if cfg.plugin:
            import importlib
            if hasattr(cfg, 'plugin_dir'):
                def import_path(plugin_dir):
                    # Turn a filesystem path like 'plugin/foo/' into a dotted
                    # module path and import it for its registration side effects.
                    _module_dir = os.path.dirname(plugin_dir)
                    _module_dir = _module_dir.split('/')
                    _module_path = _module_dir[0]
                    for m in _module_dir[1:]:
                        _module_path = _module_path + '.' + m
                    print(f'importing {_module_path}/')
                    plg_lib = importlib.import_module(_module_path)

                plugin_dirs = cfg.plugin_dir
                if not isinstance(plugin_dirs, list):
                    plugin_dirs = [plugin_dirs, ]
                for plugin_dir in plugin_dirs:
                    import_path(plugin_dir)
            else:
                # import dir is the dirpath for the config file
                _module_dir = os.path.dirname(args.config)
                _module_dir = _module_dir.split('/')
                _module_path = _module_dir[0]
                for m in _module_dir[1:]:
                    _module_path = _module_path + '.' + m
                print(f'importing {_module_path}/')
                plg_lib = importlib.import_module(_module_path)

    cfg_data_dict = cfg.data.get(args.split)
    cfg.model.pretrained = None
    # in case the test dataset is concatenated
    cfg_data_dict.test_mode = True
    # NOTE: the previous `samples_per_gpu = 1` dead store was removed; the
    # pop() below already defaults to 1 when the key is absent.
    samples_per_gpu = cfg_data_dict.pop('samples_per_gpu', 1)
    if samples_per_gpu > 1:
        # Replace 'ImageToTensor' to 'DefaultFormatBundle'
        cfg_data_dict.pipeline = replace_ImageToTensor(
            cfg_data_dict.pipeline)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # work_dir priority: CLI > config > derived from config filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    cfg_data_dict.work_dir = cfg.work_dir
    print('work_dir: ', cfg.work_dir)

    # build the dataloader
    dataset = build_dataset(cfg_data_dict)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_model(cfg.model, test_cfg=cfg.get('test_cfg'))
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    # return value (checkpoint meta) is intentionally unused here
    load_checkpoint(model, args.checkpoint, map_location='cpu')
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader)
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                 args.gpu_collect)

    # only rank 0 formats/evaluates the gathered results
    rank, _ = get_dist_info()
    if rank == 0:
        if args.format_only:
            dataset.format_results(outputs, prefix=cfg.work_dir)
        elif args.eval:
            print('start evaluation!')
            print(dataset.evaluate(outputs))


if __name__ == '__main__':
    main()
from __future__ import division
import argparse
import copy
import mmcv
import os
import time
import torch
import warnings
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from os import path as osp
from mmdet import __version__ as mmdet_version
from mmdet3d import __version__ as mmdet3d_version
from mmdet3d.apis import train_model
from mmdet3d.datasets import build_dataset
from mmdet3d.utils import collect_env, get_root_logger
from mmseg import __version__ as mmseg_version
# warper
from mmdet_train import set_random_seed
# from builder import build_model
from mmdet3d.models import build_model
def parse_args():
    """Build and parse the training command-line interface.

    Also mirrors ``--local_rank`` into the ``LOCAL_RANK`` environment
    variable (expected by distributed launchers) and migrates the
    deprecated ``--options`` flag onto ``--cfg-options``.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    arg_parser = argparse.ArgumentParser(description='Train a detector')
    arg_parser.add_argument('config', help='train config file path')
    arg_parser.add_argument('--work-dir', help='the dir to save logs and models')
    arg_parser.add_argument('--resume-from',
                            help='the checkpoint file to resume from')
    arg_parser.add_argument(
        '--no-validate', action='store_true',
        help='whether not to evaluate the checkpoint during training')
    # --gpus and --gpu-ids are two mutually exclusive ways to pick devices.
    gpu_group = arg_parser.add_mutually_exclusive_group()
    gpu_group.add_argument(
        '--gpus', type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    gpu_group.add_argument(
        '--gpu-ids', type=int, nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')
    arg_parser.add_argument('--seed', type=int, default=0, help='random seed')
    arg_parser.add_argument(
        '--deterministic', action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    arg_parser.add_argument(
        '--options', nargs='+', action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file (deprecate), '
        'change to --cfg-options instead.')
    arg_parser.add_argument(
        '--cfg-options', nargs='+', action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    arg_parser.add_argument(
        '--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none', help='job launcher')
    arg_parser.add_argument('--local_rank', type=int, default=0)
    arg_parser.add_argument(
        '--autoscale-lr', action='store_true',
        help='automatically scale lr with the number of gpus')
    args = arg_parser.parse_args()

    os.environ.setdefault('LOCAL_RANK', str(args.local_rank))

    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both specified, '
            '--options is deprecated in favor of --cfg-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options')
        args.cfg_options = args.options
    return args
def main():
    """Entry point for training.

    Pipeline: parse CLI args -> load and merge the mmcv Config -> import
    plugin modules so their registry entries exist -> resolve work_dir,
    resume checkpoint, GPU ids and LR scaling -> (optionally) init the
    distributed environment -> set up logging and run metadata -> build the
    model and dataset(s) -> hand off to ``train_model``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # import modules, registry will be updated
    import sys
    sys.path.append(os.path.abspath('.'))
    if hasattr(cfg, 'plugin'):
        if cfg.plugin:
            import importlib
            if hasattr(cfg, 'plugin_dir'):
                def import_path(plugin_dir):
                    # Convert a filesystem path to a dotted module path and
                    # import it for its registration side effects.
                    _module_dir = os.path.dirname(plugin_dir)
                    _module_dir = _module_dir.split('/')
                    _module_path = _module_dir[0]
                    for m in _module_dir[1:]:
                        _module_path = _module_path + '.' + m
                    print(f'importing {_module_path}/')
                    plg_lib = importlib.import_module(_module_path)

                plugin_dirs = cfg.plugin_dir
                if not isinstance(plugin_dirs, list):
                    plugin_dirs = [plugin_dirs, ]
                for plugin_dir in plugin_dirs:
                    import_path(plugin_dir)
            else:
                # import dir is the dirpath for the config file
                _module_dir = os.path.dirname(args.config)
                _module_dir = _module_dir.split('/')
                _module_path = _module_dir[0]
                for m in _module_dir[1:]:
                    _module_path = _module_path + '.' + m
                print(f'importing {_module_path}/')
                plg_lib = importlib.import_module(_module_path)
    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        # default to a single GPU when neither --gpus nor --gpu-ids is given
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8
    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)
    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    # specify logger name, if we still use 'mmdet', the output info will be
    # filtered and won't be saved in the log_file
    # TODO: ugly workaround to judge whether we are training det or seg model
    if cfg.model.type in ['EncoderDecoder3D']:
        logger_name = 'mmseg'
    else:
        logger_name = 'mmdet'
    logger = get_root_logger(
        log_file=log_file, log_level=cfg.log_level, name=logger_name)
    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text
    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')
    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['exp_name'] = osp.basename(args.config)
    model = build_model(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    model.init_weights()
    logger.info(f'Model:\n{model}')
    # datasets read the work_dir (e.g. for visualization/cache output)
    cfg.data.train.work_dir = cfg.work_dir
    cfg.data.val.work_dir = cfg.work_dir
    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        # a 2-stage workflow also runs a val pass, so build the val dataset
        val_dataset = copy.deepcopy(cfg.data.val)
        # in case we use a dataset wrapper
        if 'dataset' in cfg.data.train:
            val_dataset.pipeline = cfg.data.train.dataset.pipeline
        else:
            val_dataset.pipeline = cfg.data.train.pipeline
        # set test_mode=False here in deep copied config
        # which do not affect AP/AR calculation later
        # refer to https://mmdetection3d.readthedocs.io/en/latest/tutorials/customize_runtime.html#customize-workflow  # noqa
        val_dataset.test_mode = False
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=mmdet_version,
            mmseg_version=mmseg_version,
            mmdet3d_version=mmdet3d_version,
            config=cfg.pretty_text,
            CLASSES=None,
            PALETTE=datasets[0].PALETTE  # for segmentors
            if hasattr(datasets[0], 'PALETTE') else None)
    # add an attribute for visualization convenience
    # model.CLASSES = datasets[0].CLASSES
    train_model(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    main()
from __future__ import division
import argparse
import copy
import mmcv
import os
import time
import torch
import warnings
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from os import path as osp
from mmdet import __version__ as mmdet_version
from mmdet3d import __version__ as mmdet3d_version
from mmdet3d.apis import train_model
from mmdet3d.datasets import build_dataset
from mmdet3d.utils import collect_env, get_root_logger
from mmseg import __version__ as mmseg_version
# warper
from mmdet_train import set_random_seed
# from builder import build_model
from mmdet3d.models import build_model
def parse_args():
    """Build and parse the training command-line interface.

    Also mirrors ``--local_rank`` into the ``LOCAL_RANK`` environment
    variable (expected by distributed launchers) and migrates the
    deprecated ``--options`` flag onto ``--cfg-options``.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    arg_parser = argparse.ArgumentParser(description='Train a detector')
    arg_parser.add_argument('config', help='train config file path')
    arg_parser.add_argument('--work-dir', help='the dir to save logs and models')
    arg_parser.add_argument('--resume-from',
                            help='the checkpoint file to resume from')
    arg_parser.add_argument(
        '--no-validate', action='store_true',
        help='whether not to evaluate the checkpoint during training')
    # --gpus and --gpu-ids are two mutually exclusive ways to pick devices.
    gpu_group = arg_parser.add_mutually_exclusive_group()
    gpu_group.add_argument(
        '--gpus', type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    gpu_group.add_argument(
        '--gpu-ids', type=int, nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')
    arg_parser.add_argument('--seed', type=int, default=0, help='random seed')
    arg_parser.add_argument(
        '--deterministic', action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    arg_parser.add_argument(
        '--options', nargs='+', action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file (deprecate), '
        'change to --cfg-options instead.')
    arg_parser.add_argument(
        '--cfg-options', nargs='+', action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    arg_parser.add_argument(
        '--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none', help='job launcher')
    arg_parser.add_argument('--local_rank', type=int, default=0)
    arg_parser.add_argument(
        '--autoscale-lr', action='store_true',
        help='automatically scale lr with the number of gpus')
    args = arg_parser.parse_args()

    os.environ.setdefault('LOCAL_RANK', str(args.local_rank))

    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both specified, '
            '--options is deprecated in favor of --cfg-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options')
        args.cfg_options = args.options
    return args
def main():
    """Entry point for training.

    Pipeline: parse CLI args -> load and merge the mmcv Config -> import
    plugin modules so their registry entries exist -> resolve work_dir,
    resume checkpoint, GPU ids and LR scaling -> (optionally) init the
    distributed environment -> set up logging and run metadata -> build the
    model and dataset(s) -> hand off to ``train_model``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # import modules, registry will be updated
    import sys
    sys.path.append(os.path.abspath('.'))
    if hasattr(cfg, 'plugin'):
        if cfg.plugin:
            import importlib
            if hasattr(cfg, 'plugin_dir'):
                def import_path(plugin_dir):
                    # Convert a filesystem path to a dotted module path and
                    # import it for its registration side effects.
                    _module_dir = os.path.dirname(plugin_dir)
                    _module_dir = _module_dir.split('/')
                    _module_path = _module_dir[0]
                    for m in _module_dir[1:]:
                        _module_path = _module_path + '.' + m
                    print(f'importing {_module_path}/')
                    plg_lib = importlib.import_module(_module_path)

                plugin_dirs = cfg.plugin_dir
                if not isinstance(plugin_dirs, list):
                    plugin_dirs = [plugin_dirs, ]
                for plugin_dir in plugin_dirs:
                    import_path(plugin_dir)
            else:
                # import dir is the dirpath for the config file
                _module_dir = os.path.dirname(args.config)
                _module_dir = _module_dir.split('/')
                _module_path = _module_dir[0]
                for m in _module_dir[1:]:
                    _module_path = _module_path + '.' + m
                print(f'importing {_module_path}/')
                plg_lib = importlib.import_module(_module_path)
    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        # default to a single GPU when neither --gpus nor --gpu-ids is given
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8
    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)
    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    # specify logger name, if we still use 'mmdet', the output info will be
    # filtered and won't be saved in the log_file
    # TODO: ugly workaround to judge whether we are training det or seg model
    if cfg.model.type in ['EncoderDecoder3D']:
        logger_name = 'mmseg'
    else:
        logger_name = 'mmdet'
    logger = get_root_logger(
        log_file=log_file, log_level=cfg.log_level, name=logger_name)
    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text
    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')
    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['exp_name'] = osp.basename(args.config)
    model = build_model(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    model.init_weights()
    logger.info(f'Model:\n{model}')
    # datasets read the work_dir (e.g. for visualization/cache output)
    cfg.data.train.work_dir = cfg.work_dir
    cfg.data.val.work_dir = cfg.work_dir
    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        # a 2-stage workflow also runs a val pass, so build the val dataset
        val_dataset = copy.deepcopy(cfg.data.val)
        # in case we use a dataset wrapper
        if 'dataset' in cfg.data.train:
            val_dataset.pipeline = cfg.data.train.dataset.pipeline
        else:
            val_dataset.pipeline = cfg.data.train.pipeline
        # set test_mode=False here in deep copied config
        # which do not affect AP/AR calculation later
        # refer to https://mmdetection3d.readthedocs.io/en/latest/tutorials/customize_runtime.html#customize-workflow  # noqa
        val_dataset.test_mode = False
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=mmdet_version,
            mmseg_version=mmseg_version,
            mmdet3d_version=mmdet3d_version,
            config=cfg.pretty_text,
            CLASSES=None,
            PALETTE=datasets[0].PALETTE  # for segmentors
            if hasattr(datasets[0], 'PALETTE') else None)
    # add an attribute for visualization convenience
    # model.CLASSES = datasets[0].CLASSES
    train_model(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    main()
import os.path as osp
import os
import numpy as np
import copy
import cv2
import matplotlib.pyplot as plt
from PIL import Image
from shapely.geometry import LineString
def remove_nan_values(uv):
    """Drop rows whose u (col 0) or v (col 1) coordinate is NaN.

    Args:
        uv (np.ndarray): array of shape (N, >=2); extra columns are kept.

    Returns:
        np.ndarray: the subset of rows with finite u and v.
    """
    invalid = np.isnan(uv[:, 0]) | np.isnan(uv[:, 1])
    return uv[~invalid]
def points_ego2img(pts_ego, extrinsics, intrinsics):
    """Project ego-frame 3D points into pixel coordinates.

    Args:
        pts_ego (np.ndarray): (N, 3) points in the ego frame.
        extrinsics (np.ndarray): (4, 4) ego-to-camera transform.
        intrinsics (np.ndarray): (3, 3) camera intrinsic matrix.

    Returns:
        tuple: (uv, depth) where uv is (M, 2) pixel coordinates and depth is
        (M,), with NaN projections removed (M <= N).
    """
    # Homogeneous coordinates: append a ones column.
    pts_homo = np.concatenate([pts_ego, np.ones([len(pts_ego), 1])], axis=-1)
    pts_cam = extrinsics @ pts_homo.T
    uvd = (intrinsics @ pts_cam[:3, :]).T
    uvd = remove_nan_values(uvd)
    depth = uvd[:, 2]
    # Perspective divide by depth to get pixel coordinates.
    uv = uvd[:, :2] / uvd[:, 2].reshape(-1, 1)
    return uv, depth
def interp_fixed_dist(line, sample_dist):
    ''' Interpolate a line at fixed interval.
    Args:
        line (LineString): line
        sample_dist (float): sample interval
    Returns:
        points (array): interpolated points, shape (N, 2)
    '''
    interior = np.arange(sample_dist, line.length, sample_dist).tolist()
    # Always include both endpoints so at least two points are sampled even
    # when sample_dist > line.length.
    distances = [0] + interior + [line.length]
    sampled_points = np.array(
        [list(line.interpolate(d).coords) for d in distances]).squeeze()
    return sampled_points
def draw_polyline_ego_on_img(polyline_ego, img_bgr, extrinsics, intrinsics, color_bgr, thickness):
    """Project an ego-frame polyline into an image and draw its visible runs.

    The polyline is densified, projected to pixels, and split into contiguous
    runs of points that fall inside the image with positive depth; each run of
    at least two points is drawn with anti-aliased line segments.

    Args:
        polyline_ego (np.ndarray): (N, 2) or (N, 3) polyline in the ego frame
            (z=0 is assumed for 2D input).
        img_bgr (np.ndarray): (H, W, 3) BGR image, modified in place.
        extrinsics (np.ndarray): (4, 4) ego-to-camera transform.
        intrinsics (np.ndarray): (3, 3) camera intrinsic matrix.
        color_bgr (tuple): BGR color.
        thickness (int): line thickness in pixels.
    """
    # if 2-dimension, assume z=0
    if polyline_ego.shape[1] == 2:
        zeros = np.zeros((polyline_ego.shape[0], 1))
        polyline_ego = np.concatenate([polyline_ego, zeros], axis=1)

    polyline_ego = interp_fixed_dist(line=LineString(polyline_ego), sample_dist=0.2)
    uv, depth = points_ego2img(polyline_ego, extrinsics, intrinsics)

    h, w, _ = img_bgr.shape
    is_valid_x = np.logical_and(0 <= uv[:, 0], uv[:, 0] < w - 1)
    is_valid_y = np.logical_and(0 <= uv[:, 1], uv[:, 1] < h - 1)
    is_valid_z = depth > 0
    is_valid_points = np.logical_and.reduce([is_valid_x, is_valid_y, is_valid_z])
    if is_valid_points.sum() == 0:
        return

    def _flush(segment):
        # Draw one contiguous run of visible points; runs shorter than two
        # points cannot form a segment and are skipped.
        if len(segment) < 2:
            return
        seg = np.round(np.stack(segment)).astype(np.int32)
        draw_visible_polyline_cv2(
            copy.deepcopy(seg),
            # Fixed: mask length must match the segment (was len(uv), and
            # shape (K, 1) vs the documented (K,)). All points in a run are
            # visible, so the mask is all True.
            valid_pts_bool=np.ones(len(seg), dtype=bool),
            image=img_bgr,
            color=color_bgr,
            thickness_px=thickness,
        )

    # Split the projected points into contiguous visible runs and draw each.
    segment = []
    for point, valid in zip(uv, is_valid_points):
        if valid:
            segment.append(point)
        else:
            _flush(segment)
            segment = []
    _flush(segment)
def draw_visible_polyline_cv2(line, valid_pts_bool, image, color, thickness_px):
    """Draw the valid segments of a polyline onto a BGR image.

    Args:
        line: Array of shape (K, 2) with the polyline coordinates.
        valid_pts_bool: Array of shape (K,) marking which coordinates may be
            rendered (e.g. occluded points can be marked invalid). A segment
            touching an invalid vertex is not drawn.
        image: Array of shape (H, W, 3), a 3-channel BGR image.
        color: BGR color tuple.
        thickness_px: line thickness in pixels.
    """
    pts = np.round(line).astype(int)  # type: ignore
    for idx in range(len(pts) - 1):
        # Skip any segment that touches an invalid vertex.
        if not (valid_pts_bool[idx] and valid_pts_bool[idx + 1]):
            continue
        start = (pts[idx][0], pts[idx][1])
        end = (pts[idx + 1][0], pts[idx + 1][1])
        # Anti-aliased rendering for smooth curves.
        image = cv2.line(image, pt1=start, pt2=end, color=color,
                         thickness=thickness_px, lineType=cv2.LINE_AA)
# Per-category BGR colors (OpenCV convention) used when drawing map elements
# on camera images.
COLOR_MAPS_BGR = {
    # bgr colors
    'divider': (0, 0, 255),
    'boundary': (0, 255, 0),
    'ped_crossing': (255, 0, 0),
    'centerline': (51, 183, 255),
    'drivable_area': (171, 255, 255)
}
# Per-category matplotlib color shorthands used for the BEV plot.
COLOR_MAPS_PLT = {
    'divider': 'r',
    'boundary': 'g',
    'ped_crossing': 'b',
    'centerline': 'orange',
    'drivable_area': 'y',
}
# Argoverse 2 ring-camera names; the order must match the per-camera image /
# extrinsic / intrinsic sequences passed to Renderer.
CAM_NAMES_AV2 = ['ring_front_center', 'ring_front_right', 'ring_front_left',
    'ring_rear_right','ring_rear_left', 'ring_side_right', 'ring_side_left',
    ]
class Renderer(object):
    """Render vectorized map elements on BEV and camera views.

    Args:
        roi_size (tuple): bev range
    """

    def __init__(self, roi_size):
        self.roi_size = roi_size

    def render_bev_from_vectors(self, vectors, out_dir):
        '''Plot vectorized map elements on BEV.

        Args:
            vectors (dict): dict of vectorized map elements.
            out_dir (str): output directory
        '''
        car_img = Image.open('resources/images/car.png')
        map_path = os.path.join(out_dir, 'map.jpg')
        half_w = self.roi_size[0] / 2
        half_h = self.roi_size[1] / 2
        plt.figure(figsize=(self.roi_size[0], self.roi_size[1]))
        plt.xlim(-half_w - 1, half_w + 1)
        plt.ylim(-half_h - 1, half_h + 1)
        plt.axis('off')
        # Ego-vehicle icon at the origin.
        plt.imshow(car_img, extent=[-1.5, 1.5, -1.2, 1.2])
        for category, polylines in vectors.items():
            line_color = COLOR_MAPS_PLT[category]
            for polyline in polylines:
                pts = np.array(polyline)[:, :2]
                plt.plot(pts[:, 0], pts[:, 1], color=line_color, linewidth=5,
                         marker='o', linestyle='-', markersize=20)
        plt.savefig(map_path, bbox_inches='tight', dpi=40)
        plt.close()

    def render_camera_views_from_vectors(self, vectors, imgs, extrinsics,
                                         intrinsics, thickness, out_dir):
        '''Project vectorized map elements to camera views.

        Args:
            vectors (dict): dict of vectorized map elements.
            imgs (tensor): images in bgr color.
            extrinsics (array): per-camera ego2img extrinsics, each (4, 4)
            intrinsics (array): per-camera intrinsics, each (3, 3)
            thickness (int): thickness of lines to draw on images.
            out_dir (str): output directory
        '''
        for cam_idx, img in enumerate(imgs):
            extrinsic = extrinsics[cam_idx]
            intrinsic = intrinsics[cam_idx]
            canvas = copy.deepcopy(img)
            for category, polylines in vectors.items():
                line_color = COLOR_MAPS_BGR[category]
                for polyline in polylines:
                    # cv2 drawing requires a C-contiguous array.
                    canvas = np.ascontiguousarray(canvas)
                    pts = np.array(polyline)
                    if pts.shape[1] > 3:
                        # keep only x, y, z; drop extra per-point attributes
                        pts = pts[:, :3]
                    draw_polyline_ego_on_img(pts, canvas, extrinsic,
                                             intrinsic, line_color, thickness)
            out_path = osp.join(out_dir, CAM_NAMES_AV2[cam_idx]) + '.jpg'
            cv2.imwrite(out_path, canvas)
import os.path as osp
import os
import numpy as np
import copy
import cv2
import matplotlib.pyplot as plt
from PIL import Image
from shapely.geometry import LineString
def remove_nan_values(uv):
    """Drop rows whose u (col 0) or v (col 1) coordinate is NaN.

    Args:
        uv (np.ndarray): array of shape (N, >=2); extra columns are kept.

    Returns:
        np.ndarray: the subset of rows with finite u and v.
    """
    invalid = np.isnan(uv[:, 0]) | np.isnan(uv[:, 1])
    return uv[~invalid]
def points_ego2img(pts_ego, extrinsics, intrinsics):
    """Project ego-frame 3D points into pixel coordinates.

    Args:
        pts_ego (np.ndarray): (N, 3) points in the ego frame.
        extrinsics (np.ndarray): (4, 4) ego-to-camera transform.
        intrinsics (np.ndarray): (3, 3) camera intrinsic matrix.

    Returns:
        tuple: (uv, depth) where uv is (M, 2) pixel coordinates and depth is
        (M,), with NaN projections removed (M <= N).
    """
    # Homogeneous coordinates: append a ones column.
    pts_homo = np.concatenate([pts_ego, np.ones([len(pts_ego), 1])], axis=-1)
    pts_cam = extrinsics @ pts_homo.T
    uvd = (intrinsics @ pts_cam[:3, :]).T
    uvd = remove_nan_values(uvd)
    depth = uvd[:, 2]
    # Perspective divide by depth to get pixel coordinates.
    uv = uvd[:, :2] / uvd[:, 2].reshape(-1, 1)
    return uv, depth
def interp_fixed_dist(line, sample_dist):
    ''' Interpolate a line at fixed interval.
    Args:
        line (LineString): line
        sample_dist (float): sample interval
    Returns:
        points (array): interpolated points, shape (N, 2)
    '''
    interior = np.arange(sample_dist, line.length, sample_dist).tolist()
    # Always include both endpoints so at least two points are sampled even
    # when sample_dist > line.length.
    distances = [0] + interior + [line.length]
    sampled_points = np.array(
        [list(line.interpolate(d).coords) for d in distances]).squeeze()
    return sampled_points
def draw_polyline_ego_on_img(polyline_ego, img_bgr, extrinsics, intrinsics, color_bgr, thickness):
    """Project an ego-frame polyline into an image and draw its visible runs.

    The polyline is densified, projected to pixels, and split into contiguous
    runs of points that fall inside the image with positive depth; each run of
    at least two points is drawn with anti-aliased line segments.

    Args:
        polyline_ego (np.ndarray): (N, 2) or (N, 3) polyline in the ego frame
            (z=0 is assumed for 2D input).
        img_bgr (np.ndarray): (H, W, 3) BGR image, modified in place.
        extrinsics (np.ndarray): (4, 4) ego-to-camera transform.
        intrinsics (np.ndarray): (3, 3) camera intrinsic matrix.
        color_bgr (tuple): BGR color.
        thickness (int): line thickness in pixels.
    """
    # if 2-dimension, assume z=0
    if polyline_ego.shape[1] == 2:
        zeros = np.zeros((polyline_ego.shape[0], 1))
        polyline_ego = np.concatenate([polyline_ego, zeros], axis=1)

    polyline_ego = interp_fixed_dist(line=LineString(polyline_ego), sample_dist=0.2)
    uv, depth = points_ego2img(polyline_ego, extrinsics, intrinsics)

    h, w, _ = img_bgr.shape
    is_valid_x = np.logical_and(0 <= uv[:, 0], uv[:, 0] < w - 1)
    is_valid_y = np.logical_and(0 <= uv[:, 1], uv[:, 1] < h - 1)
    is_valid_z = depth > 0
    is_valid_points = np.logical_and.reduce([is_valid_x, is_valid_y, is_valid_z])
    if is_valid_points.sum() == 0:
        return

    def _flush(segment):
        # Draw one contiguous run of visible points; runs shorter than two
        # points cannot form a segment and are skipped.
        if len(segment) < 2:
            return
        seg = np.round(np.stack(segment)).astype(np.int32)
        draw_visible_polyline_cv2(
            copy.deepcopy(seg),
            # Fixed: mask length must match the segment (was len(uv), and
            # shape (K, 1) vs the documented (K,)). All points in a run are
            # visible, so the mask is all True.
            valid_pts_bool=np.ones(len(seg), dtype=bool),
            image=img_bgr,
            color=color_bgr,
            thickness_px=thickness,
        )

    # Split the projected points into contiguous visible runs and draw each.
    segment = []
    for point, valid in zip(uv, is_valid_points):
        if valid:
            segment.append(point)
        else:
            _flush(segment)
            segment = []
    _flush(segment)
def draw_visible_polyline_cv2(line, valid_pts_bool, image, color, thickness_px):
    """Draw the valid segments of a polyline onto a BGR image.

    Args:
        line: Array of shape (K, 2) with the polyline coordinates.
        valid_pts_bool: Array of shape (K,) marking which coordinates may be
            rendered (e.g. occluded points can be marked invalid). A segment
            touching an invalid vertex is not drawn.
        image: Array of shape (H, W, 3), a 3-channel BGR image.
        color: BGR color tuple.
        thickness_px: line thickness in pixels.
    """
    pts = np.round(line).astype(int)  # type: ignore
    for idx in range(len(pts) - 1):
        # Skip any segment that touches an invalid vertex.
        if not (valid_pts_bool[idx] and valid_pts_bool[idx + 1]):
            continue
        start = (pts[idx][0], pts[idx][1])
        end = (pts[idx + 1][0], pts[idx + 1][1])
        # Anti-aliased rendering for smooth curves.
        image = cv2.line(image, pt1=start, pt2=end, color=color,
                         thickness=thickness_px, lineType=cv2.LINE_AA)
# Per-category BGR colors (OpenCV convention) used when drawing map elements
# on camera images.
COLOR_MAPS_BGR = {
    # bgr colors
    'divider': (0, 0, 255),
    'boundary': (0, 255, 0),
    'ped_crossing': (255, 0, 0),
    'centerline': (51, 183, 255),
    'drivable_area': (171, 255, 255)
}
# Per-category matplotlib color shorthands used for the BEV plot.
COLOR_MAPS_PLT = {
    'divider': 'r',
    'boundary': 'g',
    'ped_crossing': 'b',
    'centerline': 'orange',
    'drivable_area': 'y',
}
# Argoverse 2 ring-camera names; the order must match the per-camera image /
# extrinsic / intrinsic sequences passed to Renderer.
CAM_NAMES_AV2 = ['ring_front_center', 'ring_front_right', 'ring_front_left',
    'ring_rear_right','ring_rear_left', 'ring_side_right', 'ring_side_left',
    ]
class Renderer(object):
    """Render vectorized map elements on BEV and camera views.

    Args:
        roi_size (tuple): bev range
    """

    def __init__(self, roi_size):
        self.roi_size = roi_size

    def render_bev_from_vectors(self, vectors, out_dir):
        '''Plot vectorized map elements on BEV.

        Args:
            vectors (dict): dict of vectorized map elements.
            out_dir (str): output directory
        '''
        car_img = Image.open('resources/images/car.png')
        map_path = os.path.join(out_dir, 'map.jpg')
        half_w = self.roi_size[0] / 2
        half_h = self.roi_size[1] / 2
        plt.figure(figsize=(self.roi_size[0], self.roi_size[1]))
        plt.xlim(-half_w - 1, half_w + 1)
        plt.ylim(-half_h - 1, half_h + 1)
        plt.axis('off')
        # Ego-vehicle icon at the origin.
        plt.imshow(car_img, extent=[-1.5, 1.5, -1.2, 1.2])
        for category, polylines in vectors.items():
            line_color = COLOR_MAPS_PLT[category]
            for polyline in polylines:
                pts = np.array(polyline)[:, :2]
                plt.plot(pts[:, 0], pts[:, 1], color=line_color, linewidth=5,
                         marker='o', linestyle='-', markersize=20)
        plt.savefig(map_path, bbox_inches='tight', dpi=40)
        plt.close()

    def render_camera_views_from_vectors(self, vectors, imgs, extrinsics,
                                         intrinsics, thickness, out_dir):
        '''Project vectorized map elements to camera views.

        Args:
            vectors (dict): dict of vectorized map elements.
            imgs (tensor): images in bgr color.
            extrinsics (array): per-camera ego2img extrinsics, each (4, 4)
            intrinsics (array): per-camera intrinsics, each (3, 3)
            thickness (int): thickness of lines to draw on images.
            out_dir (str): output directory
        '''
        for cam_idx, img in enumerate(imgs):
            extrinsic = extrinsics[cam_idx]
            intrinsic = intrinsics[cam_idx]
            canvas = copy.deepcopy(img)
            for category, polylines in vectors.items():
                line_color = COLOR_MAPS_BGR[category]
                for polyline in polylines:
                    # cv2 drawing requires a C-contiguous array.
                    canvas = np.ascontiguousarray(canvas)
                    pts = np.array(polyline)
                    if pts.shape[1] > 3:
                        # keep only x, y, z; drop extra per-point attributes
                        pts = pts[:, :3]
                    draw_polyline_ego_on_img(pts, canvas, extrinsic,
                                             intrinsic, line_color, thickness)
            out_path = osp.join(out_dir, CAM_NAMES_AV2[cam_idx]) + '.jpg'
            cv2.imwrite(out_path, canvas)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment