Commit 41b18fd8 authored by zhe chen

Use pre-commit to reformat code

parent ff20ea39
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import math
import warnings

import torch
import torch.nn as nn
from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.cnn.bricks.transformer import build_attention
from mmcv.ops.multi_scale_deform_attn import \
    multi_scale_deformable_attn_pytorch
from mmcv.runner import force_fp32
from mmcv.runner.base_module import BaseModule
from mmcv.utils import ext_loader

from .multi_scale_deformable_attn_function import \
    MultiScaleDeformableAttnFunction_fp32

ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])


@ATTENTION.register_module()
class SpatialCrossAttention(BaseModule):
    """An attention module used in BEVFormer.
    Args:
        embed_dims (int): The embedding dimension of Attention.
            Default: 256.
        num_cams (int): The number of cameras. Default: 6.
        dropout (float): Drop probability of the Dropout layer applied
            to `inp_residual`. Default: 0.1.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        deformable_attention (dict): The config for the deformable attention
            used in SCA.
    """

    def __init__(self,
                 embed_dims=256,
                 num_cams=6,
                 pc_range=None,
                 dropout=0.1,
                 init_cfg=None,
                 batch_first=False,
                 deformable_attention=dict(
                     type='MSDeformableAttention3D',
                     embed_dims=256,
                     num_levels=4),
                 **kwargs
                 ):
        super(SpatialCrossAttention, self).__init__(init_cfg)

        self.init_cfg = init_cfg
        self.dropout = nn.Dropout(dropout)
        self.pc_range = pc_range
        self.fp16_enabled = False
        self.deformable_attention = build_attention(deformable_attention)
        self.embed_dims = embed_dims
        self.num_cams = num_cams
        self.output_proj = nn.Linear(embed_dims, embed_dims)
        self.batch_first = batch_first
        self.init_weight()

    def init_weight(self):
        """Default initialization for Parameters of Module."""
        xavier_init(self.output_proj, distribution='uniform', bias=0.)
    @force_fp32(apply_to=('query', 'key', 'value', 'query_pos', 'reference_points_cam'))
    def forward(self,
                query,
                key,
                value,
                residual=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                reference_points_cam=None,
                bev_mask=None,
                level_start_index=None,
                flag='encoder',
                **kwargs):
        """Forward function of SpatialCrossAttention.
        Args:
            query (Tensor): Query of Transformer with shape
                (bs, num_query, embed_dims).
            key (Tensor): The key tensor with shape
                `(num_key, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_key, bs, embed_dims)`, i.e. multi-camera features
                flattened from (B, N, C, H, W).
            residual (Tensor): The tensor used for addition, with the
                same shape as `query`. Default None. If None, `query`
                will be used.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            reference_points (Tensor): The normalized reference
                points with shape (bs, num_query, 4), all elements in
                range [0, 1], top-left (0,0), bottom-right (1, 1),
                including the padding area. Or with shape
                (N, Length_{query}, num_levels, 4), where the two
                additional dimensions (w, h) form reference boxes.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_key].
            spatial_shapes (Tensor): Spatial shapes of features at
                different levels. With shape (num_levels, 2), the last
                dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor of shape ``(num_levels, )`` that can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
        Returns:
            Tensor: forwarded results with shape [bs, num_query, embed_dims].
        """

        if key is None:
            key = query
        if value is None:
            value = key

        if residual is None:
            inp_residual = query
            slots = torch.zeros_like(query)
        if query_pos is not None:
            query = query + query_pos

        bs, num_query, _ = query.size()

        D = reference_points_cam.size(3)
        indexes = []
        for i, mask_per_img in enumerate(bev_mask):
            index_query_per_img = mask_per_img[0].sum(-1).nonzero().squeeze(-1)
            indexes.append(index_query_per_img)
        max_len = max([len(each) for each in indexes])
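
        # NOTE: illustrative shape walkthrough (hypothetical sizes): with
        # bs=1, num_cams=6, a 200x200 BEV grid and D=4 height anchors,
        # `bev_mask` has shape (num_cams, bs, 40000, D). `indexes[i]` holds
        # the indices of the BEV queries that project into camera i, so
        # `max_len` is the largest per-camera hit count, usually far smaller
        # than bev_h*bev_w, which is why the rebatching below saves memory.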

        # each camera only interacts with its corresponding BEV queries. This step can greatly save GPU memory.
        queries_rebatch = query.new_zeros(
            [bs, self.num_cams, max_len, self.embed_dims])
        reference_points_rebatch = reference_points_cam.new_zeros(
            [bs, self.num_cams, max_len, D, 2])

        for j in range(bs):
            for i, reference_points_per_img in enumerate(reference_points_cam):
                index_query_per_img = indexes[i]
                queries_rebatch[j, i, :len(index_query_per_img)] = query[j, index_query_per_img]
                reference_points_rebatch[j, i, :len(index_query_per_img)] = reference_points_per_img[
                    j, index_query_per_img]

        num_cams, l, bs, embed_dims = key.shape

        key = key.permute(2, 0, 1, 3).reshape(
            bs * self.num_cams, l, self.embed_dims)
        value = value.permute(2, 0, 1, 3).reshape(
            bs * self.num_cams, l, self.embed_dims)

        queries = self.deformable_attention(query=queries_rebatch.view(bs * self.num_cams, max_len, self.embed_dims),
                                            key=key, value=value,
                                            reference_points=reference_points_rebatch.view(bs * self.num_cams, max_len,
                                                                                           D, 2),
                                            spatial_shapes=spatial_shapes,
                                            level_start_index=level_start_index).view(bs, self.num_cams, max_len,
                                                                                      self.embed_dims)
        for j in range(bs):
            for i, index_query_per_img in enumerate(indexes):
                slots[j, index_query_per_img] += queries[j, i, :len(index_query_per_img)]

        count = bev_mask.sum(-1) > 0
        count = count.permute(1, 2, 0).sum(-1)
        count = torch.clamp(count, min=1.0)
        slots = slots / count[..., None]
        slots = self.output_proj(slots)

        return self.dropout(slots) + inp_residual
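

def _demo_spatial_cross_attention():
    """Illustrative usage sketch, not part of the original file.

    All sizes below are assumptions chosen for illustration; only the
    argument names and shape conventions follow the forward() above.
    """
    sca = build_attention(dict(
        type='SpatialCrossAttention',
        embed_dims=256,
        num_cams=6,
        deformable_attention=dict(
            type='MSDeformableAttention3D', embed_dims=256, num_levels=1)))
    bs, bev_h, bev_w, D = 1, 50, 50, 4          # hypothetical BEV grid
    num_cams, h, w = 6, 32, 88                  # hypothetical feature map
    query = torch.rand(bs, bev_h * bev_w, 256)  # BEV queries
    feats = torch.rand(num_cams, h * w, bs, 256)  # flattened camera features
    reference_points_cam = torch.rand(num_cams, bs, bev_h * bev_w, D, 2)
    bev_mask = torch.rand(num_cams, bs, bev_h * bev_w, D) > 0.5
    out = sca(query, feats, feats,
              reference_points_cam=reference_points_cam,
              bev_mask=bev_mask,
              spatial_shapes=torch.tensor([[h, w]]),
              level_start_index=torch.tensor([0]))
    assert out.shape == (bs, bev_h * bev_w, 256)
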

@ATTENTION.register_module()
class MSDeformableAttention3D(BaseModule):
    """An attention module used in BEVFormer based on Deformable-Detr.
    `Deformable DETR: Deformable Transformers for End-to-End Object Detection.
    <https://arxiv.org/pdf/2010.04159.pdf>`_.
    Args:
        embed_dims (int): The embedding dimension of Attention.
            Default: 256.
        num_heads (int): Parallel attention heads. Default: 8.
        num_levels (int): The number of feature maps used in
            Attention. Default: 4.
        num_points (int): The number of sampling points for
            each query in each head. Default: 8.
        im2col_step (int): The step used in image_to_column.
            Default: 64.
        dropout (float): A Dropout layer on `inp_identity`.
            Default: 0.1.
        batch_first (bool): Whether Key, Query and Value have shape
            (batch, n, embed_dim) or (n, batch, embed_dim).
            Default to True.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=8,
                 im2col_step=64,
                 dropout=0.1,
                 batch_first=True,
                 norm_cfg=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        if embed_dims % num_heads != 0:
            raise ValueError(f'embed_dims must be divisible by num_heads, '
                             f'but got {embed_dims} and {num_heads}')
        dim_per_head = embed_dims // num_heads
        self.norm_cfg = norm_cfg
        self.batch_first = batch_first
        self.output_proj = None
        self.fp16_enabled = False

        # you'd better set dim_per_head to a power of 2
        # which is more efficient in the CUDA implementation
        def _is_power_of_2(n):
            if (not isinstance(n, int)) or (n < 0):
                raise ValueError(
                    'invalid input for _is_power_of_2: {} (type: {})'.format(
                        n, type(n)))
            return (n & (n - 1) == 0) and n != 0

        if not _is_power_of_2(dim_per_head):
            warnings.warn(
                "You'd better set embed_dims in "
                'MultiScaleDeformAttention to make '
                'the dimension of each attention head a power of 2 '
                'which is more efficient in our CUDA implementation.')

        self.im2col_step = im2col_step
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_heads = num_heads
        self.num_points = num_points
        self.sampling_offsets = nn.Linear(
            embed_dims, num_heads * num_levels * num_points * 2)
        self.attention_weights = nn.Linear(embed_dims,
                                           num_heads * num_levels * num_points)
        self.value_proj = nn.Linear(embed_dims, embed_dims)

        self.init_weights()

    def init_weights(self):
        """Default initialization for Parameters of Module."""
        constant_init(self.sampling_offsets, 0.)
        thetas = torch.arange(
            self.num_heads,
            dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init /
                     grid_init.abs().max(-1, keepdim=True)[0]).view(
            self.num_heads, 1, 1,
            2).repeat(1, self.num_levels, self.num_points, 1)
        for i in range(self.num_points):
            grid_init[:, :, i, :] *= i + 1

        self.sampling_offsets.bias.data = grid_init.view(-1)
        constant_init(self.attention_weights, val=0., bias=0.)
        xavier_init(self.value_proj, distribution='uniform', bias=0.)
        xavier_init(self.output_proj, distribution='uniform', bias=0.)
        self._is_init = True
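
    # NOTE: an illustrative reading of the bias initialization above
    # (hypothetical numbers). With num_heads=8, `thetas` sweeps
    # 0, pi/4, ..., 7*pi/4, so each head starts with a unit offset pointing
    # in a distinct direction; dividing by the max absolute coordinate maps
    # the circle onto the unit square, and scaling by (i + 1) pushes the
    # i-th sampling point of every head progressively further from its
    # reference point.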

    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
        """Forward function of MultiScaleDeformAttention.
        Args:
            query (Tensor): Query of Transformer with shape
                (bs, num_query, embed_dims).
            key (Tensor): The key tensor with shape
                `(bs, num_key, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(bs, num_key, embed_dims)`.
            identity (Tensor): The tensor used for addition, with the
                same shape as `query`. Default None. If None,
                `query` will be used.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            reference_points (Tensor): The normalized reference
                points with shape (bs, num_query, num_levels, 2),
                all elements in range [0, 1], top-left (0,0),
                bottom-right (1, 1), including the padding area.
                Or with shape (N, Length_{query}, num_levels, 4), where
                the two additional dimensions (w, h) form reference boxes.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_key].
            spatial_shapes (Tensor): Spatial shapes of features at
                different levels. With shape (num_levels, 2), the last
                dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor of shape ``(num_levels, )`` that can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
        Returns:
            Tensor: forwarded results with shape [bs, num_query, embed_dims]
                when `batch_first` is True, otherwise
                [num_query, bs, embed_dims].
        """

        if value is None:
            value = query
        if identity is None:
            identity = query
        if query_pos is not None:
            query = query + query_pos

        if not self.batch_first:
            # change to (bs, num_query, embed_dims)
            query = query.permute(1, 0, 2)
            value = value.permute(1, 0, 2)

        bs, num_query, _ = query.shape
        bs, num_value, _ = value.shape
        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value

        value = self.value_proj(value)
        if key_padding_mask is not None:
            value = value.masked_fill(key_padding_mask[..., None], 0.0)
        value = value.view(bs, num_value, self.num_heads, -1)
        sampling_offsets = self.sampling_offsets(query).view(
            bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
        attention_weights = self.attention_weights(query).view(
            bs, num_query, self.num_heads, self.num_levels * self.num_points)

        attention_weights = attention_weights.softmax(-1)

        attention_weights = attention_weights.view(bs, num_query,
                                                   self.num_heads,
                                                   self.num_levels,
                                                   self.num_points)

        if reference_points.shape[-1] == 2:
            """
            Each BEV query owns `num_Z_anchors` anchors in 3D space that
            have different heights. After projection, each BEV query has
            `num_Z_anchors` reference points in each 2D image.
            For each reference point, we sample `num_points` sampling points.
            For `num_Z_anchors` reference points, there are overall
            `num_points * num_Z_anchors` sampling points.
            """
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)

            bs, num_query, num_Z_anchors, xy = reference_points.shape
            reference_points = reference_points[:, :, None, None, None, :, :]
            sampling_offsets = sampling_offsets / \
                offset_normalizer[None, None, None, :, None, :]
            bs, num_query, num_heads, num_levels, num_all_points, xy = sampling_offsets.shape
            sampling_offsets = sampling_offsets.view(
                bs, num_query, num_heads, num_levels, num_all_points // num_Z_anchors, num_Z_anchors, xy)
            sampling_locations = reference_points + sampling_offsets
            bs, num_query, num_heads, num_levels, num_points, num_Z_anchors, xy = sampling_locations.shape
            assert num_all_points == num_points * num_Z_anchors

            sampling_locations = sampling_locations.view(
                bs, num_query, num_heads, num_levels, num_all_points, xy)

        elif reference_points.shape[-1] == 4:
            assert False
        else:
            raise ValueError(
                f'Last dim of reference_points must be'
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')

        # sampling_locations.shape: bs, num_query, num_heads, num_levels, num_all_points, 2
        # attention_weights.shape: bs, num_query, num_heads, num_levels, num_all_points

        if torch.cuda.is_available() and value.is_cuda:
            # using fp16 deformable attention is unstable because it performs many sum operations
            if value.dtype == torch.float16:
                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
            else:
                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
            output = MultiScaleDeformableAttnFunction.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)
        if not self.batch_first:
            output = output.permute(1, 0, 2)

        return output
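
# Illustrative sketch (hypothetical sizes, not part of the original file):
# how the reshaping in the reference_points.shape[-1] == 2 branch above
# combines `num_points` offsets with `num_Z_anchors` projected anchors.
#
#     bs, num_query, num_heads, num_levels = 1, 100, 8, 1
#     num_Z_anchors, num_all_points = 4, 8      # 2 offsets per anchor
#     reference_points = torch.rand(bs, num_query, num_Z_anchors, 2)
#     offsets = torch.rand(bs, num_query, num_heads, num_levels,
#                          num_all_points, 2)
#     offsets = offsets.view(bs, num_query, num_heads, num_levels,
#                            num_all_points // num_Z_anchors, num_Z_anchors, 2)
#     locations = reference_points[:, :, None, None, None, :, :] + offsets
#     # -> (1, 100, 8, 1, 2, 4, 2), flattened back to num_all_points
#     # before being handed to the deformable-attention kernel.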


# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import math
import warnings

import torch
import torch.nn as nn
from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.ops.multi_scale_deform_attn import \
    multi_scale_deformable_attn_pytorch
from mmcv.runner.base_module import BaseModule
from mmcv.utils import ext_loader

from .multi_scale_deformable_attn_function import \
    MultiScaleDeformableAttnFunction_fp32

ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])


@ATTENTION.register_module()
class TemporalSelfAttention(BaseModule):
    """An attention module used in BEVFormer based on Deformable-Detr.

    `Deformable DETR: Deformable Transformers for End-to-End Object Detection.
    <https://arxiv.org/pdf/2010.04159.pdf>`_.

    Args:
        embed_dims (int): The embedding dimension of Attention.
            Default: 256.
        num_heads (int): Parallel attention heads. Default: 8.
        num_levels (int): The number of feature maps used in
            Attention. Default: 4.
        num_points (int): The number of sampling points for
            each query in each head. Default: 4.
        im2col_step (int): The step used in image_to_column.
            Default: 64.
        dropout (float): A Dropout layer on `inp_identity`.
            Default: 0.1.
        batch_first (bool): Whether Key, Query and Value have shape
            (batch, n, embed_dim) or (n, batch, embed_dim).
            Default to True.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        num_bev_queue (int): In this version, we only use one historical
            BEV and the current BEV, so the length of the BEV queue is 2.
    """

    def __init__(self,
                 embed_dims=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=4,
                 num_bev_queue=2,
                 im2col_step=64,
                 dropout=0.1,
                 batch_first=True,
                 norm_cfg=None,
                 init_cfg=None):

        super().__init__(init_cfg)
        if embed_dims % num_heads != 0:
            raise ValueError(f'embed_dims must be divisible by num_heads, '
                             f'but got {embed_dims} and {num_heads}')
        dim_per_head = embed_dims // num_heads
        self.norm_cfg = norm_cfg
        self.dropout = nn.Dropout(dropout)
        self.batch_first = batch_first
        self.fp16_enabled = False

        # you'd better set dim_per_head to a power of 2
        # which is more efficient in the CUDA implementation
        def _is_power_of_2(n):
            if (not isinstance(n, int)) or (n < 0):
                raise ValueError(
                    'invalid input for _is_power_of_2: {} (type: {})'.format(
                        n, type(n)))
            return (n & (n - 1) == 0) and n != 0

        if not _is_power_of_2(dim_per_head):
            warnings.warn(
                "You'd better set embed_dims in "
                'MultiScaleDeformAttention to make '
                'the dimension of each attention head a power of 2 '
                'which is more efficient in our CUDA implementation.')

        self.im2col_step = im2col_step
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_heads = num_heads
        self.num_points = num_points
        self.num_bev_queue = num_bev_queue
        self.sampling_offsets = nn.Linear(
            embed_dims * self.num_bev_queue, num_bev_queue * num_heads * num_levels * num_points * 2)
        self.attention_weights = nn.Linear(embed_dims * self.num_bev_queue,
                                           num_bev_queue * num_heads * num_levels * num_points)
        self.value_proj = nn.Linear(embed_dims, embed_dims)
        self.output_proj = nn.Linear(embed_dims, embed_dims)
        self.init_weights()

    def init_weights(self):
        """Default initialization for Parameters of Module."""
        constant_init(self.sampling_offsets, 0.)
        thetas = torch.arange(
            self.num_heads,
            dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init /
                     grid_init.abs().max(-1, keepdim=True)[0]).view(
            self.num_heads, 1, 1,
            2).repeat(1, self.num_levels * self.num_bev_queue, self.num_points, 1)

        for i in range(self.num_points):
            grid_init[:, :, i, :] *= i + 1

        self.sampling_offsets.bias.data = grid_init.view(-1)
        constant_init(self.attention_weights, val=0., bias=0.)
        xavier_init(self.value_proj, distribution='uniform', bias=0.)
        xavier_init(self.output_proj, distribution='uniform', bias=0.)
        self._is_init = True
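
    # NOTE: illustrative parameter shapes (hypothetical config values): with
    # embed_dims=256, num_heads=8, num_levels=1, num_points=4 and
    # num_bev_queue=2, `sampling_offsets` maps the concatenated
    # (previous BEV, current query) feature of size 512 to
    # 2*8*1*4*2 = 128 offset coordinates, and `attention_weights` maps it
    # to 2*8*1*4 = 64 weights, one set per queue entry.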

    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                flag='decoder',
                **kwargs):
        """Forward function of MultiScaleDeformAttention.

        Args:
            query (Tensor): Query of Transformer with shape
                (bs, num_query, embed_dims) when `batch_first` is True,
                otherwise (num_query, bs, embed_dims).
            key (Tensor): The key tensor with shape
                `(num_key, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_key, bs, embed_dims)`.
            identity (Tensor): The tensor used for addition, with the
                same shape as `query`. Default None. If None,
                `query` will be used.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            reference_points (Tensor): The normalized reference
                points with shape (bs, num_query, num_levels, 2),
                all elements in range [0, 1], top-left (0,0),
                bottom-right (1, 1), including the padding area.
                Or with shape (N, Length_{query}, num_levels, 4), where
                the two additional dimensions (w, h) form reference boxes.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_key].
            spatial_shapes (Tensor): Spatial shapes of features at
                different levels. With shape (num_levels, 2), the last
                dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor of shape ``(num_levels, )`` that can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].

        Returns:
            Tensor: forwarded results with the same layout as `query`.
        """

        if value is None:
            assert self.batch_first
            bs, len_bev, c = query.shape
            value = torch.stack([query, query], 1).reshape(bs * 2, len_bev, c)
            # value = torch.cat([query, query], 0)

        if identity is None:
            identity = query
        if query_pos is not None:
            query = query + query_pos
        if not self.batch_first:
            # change to (bs, num_query, embed_dims)
            query = query.permute(1, 0, 2)
            value = value.permute(1, 0, 2)
        bs, num_query, embed_dims = query.shape
        _, num_value, _ = value.shape
        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
        assert self.num_bev_queue == 2

        query = torch.cat([value[::2], query], -1)
        value_ = value.clone()
        value_[:bs] = value[::2]
        value_[bs:] = value[1::2]
        value = self.value_proj(value)

        if key_padding_mask is not None:
            value = value.masked_fill(key_padding_mask[..., None], 0.0)

        value = value.reshape(bs * self.num_bev_queue,
                              num_value, self.num_heads, -1)

        sampling_offsets = self.sampling_offsets(query)
        sampling_offsets = sampling_offsets.view(
            bs, num_query, self.num_heads, self.num_bev_queue, self.num_levels, self.num_points, 2)
        attention_weights = self.attention_weights(query).view(
            bs, num_query, self.num_heads, self.num_bev_queue, self.num_levels * self.num_points)
        attention_weights = attention_weights.softmax(-1)

        attention_weights = attention_weights.view(bs, num_query,
                                                   self.num_heads,
                                                   self.num_bev_queue,
                                                   self.num_levels,
                                                   self.num_points)

        attention_weights = attention_weights.permute(0, 3, 1, 2, 4, 5) \
            .reshape(bs * self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points).contiguous()
        sampling_offsets = sampling_offsets.permute(0, 3, 1, 2, 4, 5, 6) \
            .reshape(bs * self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points, 2)

        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                + sampling_offsets \
                / offset_normalizer[None, None, None, :, None, :]

        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                + sampling_offsets / self.num_points \
                * reference_points[:, :, None, :, None, 2:] \
                * 0.5
        else:
            raise ValueError(
                f'Last dim of reference_points must be'
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')
        if torch.cuda.is_available() and value.is_cuda:

            # using fp16 deformable attention is unstable because it performs many sum operations
            if value.dtype == torch.float16:
                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
            else:
                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
            output = MultiScaleDeformableAttnFunction.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)

        # output shape (bs*num_bev_queue, num_query, embed_dims)
        # (bs*num_bev_queue, num_query, embed_dims) -> (num_query, embed_dims, bs*num_bev_queue)
        output = output.permute(1, 2, 0)

        # fuse history value and current value
        # (num_query, embed_dims, bs*num_bev_queue) -> (num_query, embed_dims, bs, num_bev_queue)
        output = output.view(num_query, embed_dims, bs, self.num_bev_queue)
        output = output.mean(-1)

        # (num_query, embed_dims, bs) -> (bs, num_query, embed_dims)
        output = output.permute(2, 0, 1)

        output = self.output_proj(output)

        if not self.batch_first:
            output = output.permute(1, 0, 2)

        return self.dropout(output) + identity
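

def _demo_temporal_self_attention():
    """Illustrative usage sketch (hypothetical sizes, not from the source).

    Shows the shapes TemporalSelfAttention expects when no history is
    available and `value` falls back to the duplicated current queries.
    """
    tsa = TemporalSelfAttention(embed_dims=256, num_levels=1, num_points=4)
    bs, bev_h, bev_w = 1, 50, 50
    query = torch.rand(bs, bev_h * bev_w, 256)
    # one reference point per query at a single level, normalized to [0, 1];
    # the leading dim covers both entries of the BEV queue
    reference_points = torch.rand(bs * tsa.num_bev_queue, bev_h * bev_w, 1, 2)
    out = tsa(query,
              reference_points=reference_points,
              spatial_shapes=torch.tensor([[bev_h, bev_w]]),
              level_start_index=torch.tensor([0]))
    assert out.shape == (bs, bev_h * bev_w, 256)
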

# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import torch
import torch.nn as nn
from mmcv.cnn import xavier_init
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
from mmcv.runner import auto_fp16
from mmcv.runner.base_module import BaseModule
from mmdet.models.utils.builder import TRANSFORMER

from .decoder import CustomMSDeformableAttention
from .spatial_cross_attention import MSDeformableAttention3D
from .temporal_self_attention import TemporalSelfAttention


@TRANSFORMER.register_module()
class PerceptionTransformer(BaseModule):
    """Implements the Detr3D transformer.
    Args:
        as_two_stage (bool): Generate query from encoder features.
            Default: False.
        num_feature_levels (int): Number of feature maps from FPN.
            Default: 4.
        two_stage_num_proposals (int): Number of proposals when set
            `as_two_stage` as True. Default: 300.
    """

    def __init__(self,
                 decoder=None,
                 embed_dims=256,
                 **kwargs):
        super(PerceptionTransformer, self).__init__(**kwargs)
        self.decoder = build_transformer_layer_sequence(decoder)
        self.embed_dims = embed_dims
        self.fp16_enabled = False
        self.init_layers()

    def init_layers(self):
        """Initialize layers of the Detr3DTransformer."""
        self.reference_points = nn.Linear(self.embed_dims, 3)

    def init_weights(self):
        """Initialize the transformer weights."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformableAttention3D) or isinstance(m, TemporalSelfAttention) \
                    or isinstance(m, CustomMSDeformableAttention):
                try:
                    m.init_weight()
                except AttributeError:
                    m.init_weights()
        xavier_init(self.reference_points, distribution='uniform', bias=0.)

    @auto_fp16(apply_to=('mlvl_feats', 'bev_queries', 'object_query_embed', 'prev_bev', 'bev_pos'))
    def forward(self,
                mlvl_feats,
                bev_embed,
                object_query_embed,
                bev_h,
                bev_w,
                reg_branches=None,
                cls_branches=None,
                **kwargs):
        """Forward function for `Detr3DTransformer`.
        Args:
            mlvl_feats (list(Tensor)): Input features from
                different levels. Each element has shape
                [bs, num_cams, embed_dims, h, w].
            bev_embed (Tensor): BEV features used as the decoder value,
                with shape (bs, bev_h*bev_w, embed_dims).
            object_query_embed (Tensor): The query embedding for decoder,
                with shape [num_query, c].
            reg_branches (obj:`nn.ModuleList`): Regression heads for
                feature maps from each decoder layer. Only passed
                when `with_box_refine` is True. Default to None.
        Returns:
            tuple[Tensor]: results of decoder containing the following tensors.
                - inter_states: Outputs from decoder. If
                    return_intermediate_dec is True output has shape \
                    (num_dec_layers, bs, num_query, embed_dims), else has \
                    shape (1, bs, num_query, embed_dims).
                - init_reference_out: The initial value of reference \
                    points, has shape (bs, num_queries, 3).
                - inter_references_out: The internal value of reference \
                    points in decoder, has shape \
                    (num_dec_layers, bs, num_query, embed_dims).
        """

        bs = mlvl_feats[0].size(0)
        query_pos, query = torch.split(
            object_query_embed, self.embed_dims, dim=1)
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
        query = query.unsqueeze(0).expand(bs, -1, -1)
        reference_points = self.reference_points(query_pos)
        reference_points = reference_points.sigmoid()
        init_reference_out = reference_points

        query = query.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        bev_embed = bev_embed.permute(1, 0, 2)
        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=bev_embed,
            query_pos=query_pos,
            reference_points=reference_points,
            reg_branches=reg_branches,
            cls_branches=cls_branches,
            spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device),
            level_start_index=torch.tensor([0], device=query.device),
            **kwargs)

        inter_references_out = inter_references

        return inter_states, init_reference_out, inter_references_out
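

# Illustrative config sketch (hypothetical values, not from this repo's
# configs): how PerceptionTransformer is typically instantiated through the
# registry, with a DETR-style decoder attending over the flattened BEV grid.
#
#     transformer = dict(
#         type='PerceptionTransformer',
#         embed_dims=256,
#         decoder=dict(
#             type='DetectionTransformerDecoder',
#             num_layers=6,
#             return_intermediate=True,
#             transformerlayers=dict(
#                 type='DetrTransformerDecoderLayer',
#                 attn_cfgs=[
#                     dict(type='MultiheadAttention', embed_dims=256,
#                          num_heads=8, dropout=0.1),
#                     dict(type='CustomMSDeformableAttention',
#                          embed_dims=256, num_levels=1),
#                 ],
#                 feedforward_channels=512,
#                 ffn_dropout=0.1,
#                 operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
#                                  'ffn', 'norm'))))
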
from .custom_fpn import *
from .custom_ipm_view_transformer import *
# ==============================================================================
# Binaries and/or source for the following packages or projects
# are presented under one or more of the following open source licenses:
# custom_fpn.py    The OpenLane-V2 Dataset Authors    Apache License, Version 2.0
#
@@ -22,7 +22,6 @@
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from mmdet3d.models import NECKS
@@ -34,7 +33,7 @@ class CustomFPN(BaseModule):
    Notes
    -----
    Adapted from https://github.com/HuangJunJie2017/BEVDet/blob/dev2.0/mmdet3d/models/necks/fpn.py#L11.

    Feature Pyramid Network.
    This is an implementation of paper `Feature Pyramid Networks for Object
    Detection <https://arxiv.org/abs/1612.03144>`_.
...
# ==============================================================================
# Binaries and/or source for the following packages or projects
# are presented under one or more of the following open source licenses:
# custom_ipm_view_transformer.py    The OpenLane-V2 Dataset Authors    Apache License, Version 2.0
#
@@ -20,12 +20,9 @@
# limitations under the License.
# ==============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import BaseModule
from mmdet3d.models import NECKS
@@ -33,7 +30,7 @@ from mmdet3d.models import NECKS
def get_campos(reference_points, ego2cam, img_shape):
    '''
    Find each reference point's corresponding pixel in each camera.
    Args:
        reference_points: [B, num_query, 3]
        ego2cam: (B, num_cam, 4, 4)
    Outs:
@@ -63,7 +60,7 @@ def get_campos(reference_points, ego2cam, img_shape):
    eps = 1e-9
    mask = (reference_points_cam[..., 2:3] > eps)
    reference_points_cam = \
        reference_points_cam[..., 0:2] / \
        reference_points_cam[..., 2:3] + eps
@@ -74,16 +71,17 @@ def get_campos(reference_points, ego2cam, img_shape):
    reference_points_cam = (reference_points_cam - 0.5) * 2
    mask = (mask & (reference_points_cam[..., 0:1] > -1.0)
            & (reference_points_cam[..., 0:1] < 1.0)
            & (reference_points_cam[..., 1:2] > -1.0)
            & (reference_points_cam[..., 1:2] < 1.0))
    # (B, num_cam, num_query)
    mask = mask.view(B, num_cam, num_query)
    reference_points_cam = reference_points_cam.view(B * num_cam, num_query, 2)

    return reference_points_cam, mask
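For readers tracing the math, here is a self-contained sketch of the same ego-to-pixel projection. The exact shapes and the clamp-based z guard are assumptions read off the signature and comments above, not the repository's verbatim implementation (the ego2cam matrices are assumed to already fold in the camera intrinsics, as the normalization by img_shape suggests):

import torch

def project_to_cam(reference_points, ego2cam, img_shape, eps=1e-9):
    # reference_points: (B, num_query, 3) in the ego frame
    # ego2cam: (B, num_cam, 4, 4) combined extrinsics + intrinsics (assumed)
    B, num_query = reference_points.shape[:2]
    num_cam = ego2cam.shape[1]
    ones = torch.ones_like(reference_points[..., :1])
    pts = torch.cat([reference_points, ones], dim=-1).view(B, 1, num_query, 4, 1)
    cam_pts = torch.matmul(ego2cam.view(B, num_cam, 1, 4, 4), pts).squeeze(-1)
    in_front = cam_pts[..., 2:3] > eps                         # drop points behind the camera
    pix = cam_pts[..., :2] / cam_pts[..., 2:3].clamp(min=eps)  # perspective division
    pix = pix / pix.new_tensor([img_shape[1], img_shape[0]])   # normalize by (W, H)
    pix = (pix - 0.5) * 2                                      # [-1, 1], as grid_sample expects
    on_image = (pix > -1.0) & (pix < 1.0)
    mask = in_front & on_image.all(-1, keepdim=True)
    return pix.reshape(B * num_cam, num_query, 2), mask.reshape(B, num_cam, num_query)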
def construct_plane_grid(xbound, ybound, height: float, dtype=torch.float32):
    '''
    Returns:
@@ -108,6 +106,7 @@ def construct_plane_grid(xbound, ybound, height: float, dtype=torch.float32):
    return plane
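The body of construct_plane_grid is collapsed in this diff. A plausible minimal version, inferred from the call sites and therefore an assumption, builds a regular xy grid at a fixed height; note that the buffer comment below says 'nlvl, bH, bW, 2', so the real function may store only (x, y) and attach z elsewhere:

import torch

def construct_plane_grid_sketch(xbound, ybound, height, dtype=torch.float32):
    # xbound/ybound are [min, max, step]; returns (bH, bW, 3) ego-frame points at z=height
    xs = torch.arange(xbound[0], xbound[1], xbound[2], dtype=dtype)
    ys = torch.arange(ybound[0], ybound[1], ybound[2], dtype=dtype)
    yy, xx = torch.meshgrid(ys, xs, indexing='ij')  # rows sweep y, columns sweep x (torch >= 1.10)
    zz = torch.full_like(xx, height)
    return torch.stack([xx, yy, zz], dim=-1)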
@NECKS.register_module()
class CustomIPMViewTransformer(BaseModule):
    r"""
@@ -116,8 +115,9 @@ class CustomIPMViewTransformer(BaseModule):
    Adapted from https://github.com/Mrmoore98/VectorMapNet_code/blob/mian/plugin/models/backbones/ipm_backbone.py#L238.
    """

    def __init__(self,
                 num_cam,
                 xbound,
                 ybound,
                 zbound,
@@ -126,24 +126,24 @@ class CustomIPMViewTransformer(BaseModule):
        super().__init__()
        self.x_bound = xbound
        self.y_bound = ybound
        heights = [zbound[0] + i * zbound[2] for i in range(int((zbound[1] - zbound[0]) // zbound[2]) + 1)]
        self.heights = heights
        self.num_cam = num_cam
        self.outconvs = \
            nn.Conv2d((out_channels + 3) * len(heights), out_channels,
                      kernel_size=3, stride=1, padding=1)  # 'same' padding
        # bev_plane
        bev_planes = [construct_plane_grid(
            xbound, ybound, h) for h in self.heights]
        self.register_buffer('bev_planes', torch.stack(
            bev_planes))  # nlvl, bH, bW, 2
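As a concrete check of the heights arithmetic, with the zbound used in the configs later in this commit:

zbound = [-3.0, 2.0, 0.5]
heights = [zbound[0] + i * zbound[2] for i in range(int((zbound[1] - zbound[0]) // zbound[2]) + 1)]
assert heights == [-3.0 + 0.5 * i for i in range(11)]  # 11 planes from -3.0 m to 2.0 m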
    def forward(self, cam_feat, ego2cam, img_shape):
        '''
        Inverse projection.
        Args:
            cam_feat: B*ncam, C, cH, cW
            img_shape: tuple(H, W)
@@ -161,7 +161,7 @@ class CustomIPMViewTransformer(BaseModule):
        # bev_grid_pos: B*ncam, nlvl*bH*bW, 2
        bev_grid_pos, bev_cam_mask = get_campos(bev_grid, ego2cam, img_shape)
        # B*cam, nlvl*bH, bW, 2
        bev_grid_pos = bev_grid_pos.unflatten(-2, (nlvl * bH, bW))
        # project feat from 2D to the BEV plane
        projected_feature = F.grid_sample(
@@ -173,11 +173,11 @@ class CustomIPMViewTransformer(BaseModule):
        # eliminate the ncam dimension:
        # the BEV feature is the sum over the cameras
        bev_feat_mask = bev_feat_mask.unsqueeze(2)
        projected_feature = (projected_feature * bev_feat_mask).sum(1)
        num_feat = bev_feat_mask.sum(1)
        projected_feature = projected_feature / \
            num_feat.masked_fill(num_feat == 0, 1)
        # concatenate position information
        # projected_feature: B, bH, bW, nlvl, C+3
...
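The fusion step above is a masked mean over cameras: after grid_sample pulls image features onto the BEV planes, each BEV cell averages only the cameras that actually see it, and masked_fill avoids dividing by zero where no camera does. A standalone sketch with assumed, purely illustrative shapes:

import torch

B, ncam, C, H, W = 2, 7, 8, 33, 100                           # illustrative only
projected_feature = torch.randn(B, ncam, C, H, W)             # per-camera sampled BEV features
bev_feat_mask = (torch.rand(B, ncam, 1, H, W) > 0.5).float()  # visibility per camera

fused = (projected_feature * bev_feat_mask).sum(1)      # sum over the camera axis
num_feat = bev_feat_mask.sum(1)                         # how many cameras saw each cell
fused = fused / num_feat.masked_fill(num_feat == 0, 1)  # masked mean, no division by zero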
custom_imports = dict(imports=['projects.openlanev2.baseline'])

method_para = dict(n_control=5)  # number of control points per curve
_dim_ = 128
@@ -19,26 +19,26 @@ model = dict(
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18')),
    img_neck=dict(
        type='CustomFPN',
        in_channels=[_dim_ * 2, _dim_ * 4],
        out_channels=_dim_,
        num_outs=1,
        start_level=0,
        out_ids=[0]),
    img_view_transformer=dict(
        type='CustomIPMViewTransformer',
        num_cam=7,
        xbound=[-50.0, 50.0, 1.0],
        ybound=[-25.0, 25.0, 1.0],
        zbound=[-3.0, 2.0, 0.5],
        out_channels=_dim_),
    lc_head=dict(
        type='CustomDETRHead',
        num_classes=1,
        in_channels=_dim_,
        num_query=50,
        object_type='lane',
        num_layers=1,
        num_reg_dim=method_para['n_control'] * 3,
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
@@ -46,17 +46,17 @@ model = dict(
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=2.5),
        loss_iou=dict(type='GIoULoss', loss_weight=0.0),  # dummy
        train_cfg=dict(
            assigner=dict(
                type='LaneHungarianAssigner',
                cls_cost=dict(type='FocalLossCost', weight=1.0),
                reg_cost=dict(type='LaneL1Cost', weight=2.5),
                iou_cost=dict(type='IoUCost', weight=0.0))),  # dummy
        bev_range=[-50.0, -25.0, -3.0, 50.0, 25.0, 2.0]),
    te_head=dict(
        type='CustomDETRHead',
        num_classes=13,
        in_channels=_dim_,
        num_query=30,
        object_type='bbox',
@@ -120,7 +120,7 @@ train_pipeline = [
            'gt_topology_lclc', 'gt_topology_lcte',
        ],
        meta_keys=[
            'scene_token', 'sample_idx', 'img_paths',
            'img_shape', 'scale_factor', 'pad_shape',
            'lidar2img', 'can_bus',
        ],
@@ -138,7 +138,7 @@ test_pipeline = [
            'img',
        ],
        meta_keys=[
            'scene_token', 'sample_idx', 'img_paths',
            'img_shape', 'scale_factor', 'pad_shape',
            'lidar2img', 'can_bus',
        ],
...
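A config like the one above is consumed through the usual mmcv/mmdet3d machinery; a minimal loading sketch (the path is a placeholder, and recent mmcv versions import the custom_imports modules automatically inside Config.fromfile):

from mmcv import Config
from mmdet3d.models import build_detector

cfg = Config.fromfile('projects/openlanev2/baseline/config.py')  # placeholder path
model = build_detector(cfg.model)  # resolves registered types, e.g. CustomIPMViewTransformer
print(type(model).__name__)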
@@ -18,14 +18,13 @@ input_modality = dict(
    use_external=False)
num_cams = 7
Map_size = [(-50, 50), (-25, 25)]
method_para = dict(n_control=5)  # number of control points per curve
code_size = 3 * method_para['n_control']
_dim_ = 256
_pos_dim_ = _dim_ // 2
_ffn_dim_ = _dim_ * 2
_ffn_cfg_ = dict(
    type='FFN',
    embed_dims=_dim_,
@@ -71,7 +70,7 @@ model = dict(
        pc_range=point_cloud_range,
        bev_h=bev_h_,
        bev_w=bev_w_,
        rotate_center=[bev_h_ // 2, bev_w_ // 2],
        encoder=dict(
            type='BEVFormerEncoder',
            num_layers=3,
@@ -99,7 +98,7 @@ model = dict(
            ],
            ffn_cfgs=_ffn_cfg_,
            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                             'ffn', 'norm'))),
    positional_encoding=dict(
        type='LearnedPositionalEncoding',
        num_feats=_pos_dim_,
@@ -169,7 +168,7 @@ model = dict(
        with_box_refine=False,
        with_shared_param=False,
        code_size=code_size,
        code_weights=[1.0 for i in range(code_size)],
        pc_range=point_cloud_range,
        transformer=dict(
            type='PerceptionTransformer',
@@ -186,7 +185,7 @@ model = dict(
                        embed_dims=_dim_,
                        num_heads=8,
                        dropout=0.1),
                    dict(
                        type='CustomMSDeformableAttention',
                        embed_dims=_dim_,
                        num_levels=1),
@@ -240,9 +239,8 @@ model = dict(
                type='LaneHungarianAssigner',
                cls_cost=dict(type='FocalLossCost', weight=1.5),
                reg_cost=dict(type='LaneL1Cost', weight=0.0075),
                iou_cost=dict(type='IoUCost', weight=0.0),  # fake cost, only to stay compatible with the DETR head
            ))))

train_pipeline = [
    dict(type='CustomLoadMultiViewImageFromFiles', to_float32=True),
@@ -261,7 +259,7 @@ train_pipeline = [
            'gt_topology_lclc', 'gt_topology_lcte',
        ],
        meta_keys=[
            'scene_token', 'sample_idx', 'img_paths',
            'img_shape', 'scale_factor', 'pad_shape',
            'lidar2img', 'can_bus',
        ],
@@ -279,7 +277,7 @@ test_pipeline = [
            'img',
        ],
        meta_keys=[
            'scene_token', 'sample_idx', 'img_paths',
            'img_shape', 'scale_factor', 'pad_shape',
            'lidar2img', 'can_bus',
        ],
@@ -350,4 +348,4 @@ dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
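For concreteness, the regression-head sizing in these configs works out as follows (5 control points per curve, 3 coordinates each):

method_para = dict(n_control=5)
code_size = 3 * method_para['n_control']        # = 15 regression targets per lane
code_weights = [1.0 for i in range(code_size)]  # as written in the config above
assert code_size == 15 and len(code_weights) == 15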
@@ -18,14 +18,13 @@ input_modality = dict(
    use_external=False)
num_cams = 7
Map_size = [(-50, 50), (-25, 25)]
method_para = dict(n_control=5)  # number of control points per curve
code_size = 3 * method_para['n_control']
_dim_ = 256
_pos_dim_ = _dim_ // 2
_ffn_dim_ = _dim_ * 2
_ffn_cfg_ = dict(
    type='FFN',
    embed_dims=_dim_,
@@ -78,7 +77,7 @@ model = dict(
        pc_range=point_cloud_range,
        bev_h=bev_h_,
        bev_w=bev_w_,
        rotate_center=[bev_h_ // 2, bev_w_ // 2],
        encoder=dict(
            type='BEVFormerEncoder',
            num_layers=3,
@@ -106,7 +105,7 @@ model = dict(
            ],
            ffn_cfgs=_ffn_cfg_,
            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                             'ffn', 'norm'))),
    positional_encoding=dict(
        type='LearnedPositionalEncoding',
        num_feats=_pos_dim_,
@@ -176,7 +175,7 @@ model = dict(
        with_box_refine=False,
        with_shared_param=False,
        code_size=code_size,
        code_weights=[1.0 for i in range(code_size)],
        pc_range=point_cloud_range,
        transformer=dict(
            type='PerceptionTransformer',
@@ -193,7 +192,7 @@ model = dict(
                        embed_dims=_dim_,
                        num_heads=8,
                        dropout=0.1),
                    dict(
                        type='CustomMSDeformableAttention',
                        embed_dims=_dim_,
                        num_levels=1),
@@ -247,9 +246,8 @@ model = dict(
                type='LaneHungarianAssigner',
                cls_cost=dict(type='FocalLossCost', weight=1.5),
                reg_cost=dict(type='LaneL1Cost', weight=0.0075),
                iou_cost=dict(type='IoUCost', weight=0.0),  # fake cost, only to stay compatible with the DETR head
            ))))

train_pipeline = [
    dict(type='CustomLoadMultiViewImageFromFiles', to_float32=True),
@@ -268,7 +266,7 @@ train_pipeline = [
            'gt_topology_lclc', 'gt_topology_lcte',
        ],
        meta_keys=[
            'scene_token', 'sample_idx', 'img_paths',
            'img_shape', 'scale_factor', 'pad_shape',
            'lidar2img', 'can_bus',
        ],
@@ -286,7 +284,7 @@ test_pipeline = [
            'img',
        ],
        meta_keys=[
            'scene_token', 'sample_idx', 'img_paths',
            'img_shape', 'scale_factor', 'pad_shape',
            'lidar2img', 'can_bus',
        ],
@@ -357,4 +355,4 @@ dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
chardet
iso3166
jupyter
matplotlib
ninja
numpy >=1.22.0, <1.24.0
opencv-python
openmim
ortools ==9.2.9972
scikit-learn
scipy ==1.8.0
similaritymeasures
tqdm
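Note the tight pins: numpy is held below 1.24, which removed long-deprecated aliases such as np.bool that the KITTI tools later in this commit still use, and scipy and ortools are pinned exactly, presumably for reproducibility.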
# ==============================================================================
# Binaries and/or source for the following packages or projects
# are presented under one or more of the following open source licenses:
# setup.py    The OpenLane-V2 Dataset Authors    Apache License, Version 2.0
#
@@ -20,8 +20,7 @@
# limitations under the License.
# ==============================================================================
from setuptools import find_packages, setup

setup(
    name='openlanev2',
...
@@ -66,7 +66,7 @@ def plot_curve(log_dicts, args):
    else:
        # find the first epoch that runs eval
        x0 = min(epochs) + args.interval - \
            min(epochs) % args.interval
        xs = np.arange(x0, max(epochs) + 1, args.interval)
        ys = []
        for epoch in epochs[args.interval - 1::args.interval]:
@@ -86,7 +86,7 @@ def plot_curve(log_dicts, args):
            xs = []
            ys = []
            num_iters_per_epoch = \
                log_dict[epochs[args.interval - 1]]['iter'][-1]
            for epoch in epochs[args.interval - 1::args.interval]:
                iters = log_dict[epoch]['iter']
                if log_dict[epoch]['mode'][-1] == 'val':
@@ -153,7 +153,7 @@ def add_time_parser(subparsers):
        '--include-outliers',
        action='store_true',
        help='include the first value of every epoch when computing '
             'the average time')


def parse_args():
...
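The x0 expression computes the first epoch at which eval results exist; a quick worked example (the values are assumed for illustration):

min_epoch, interval = 1, 2                        # eval every 2 epochs, epochs start at 1
x0 = min_epoch + interval - min_epoch % interval
assert x0 == 2                                    # so xs = 2, 4, 6, ...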
@@ -6,7 +6,6 @@ import torch
from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint, wrap_fp16_model
from mmdet3d.datasets import build_dataloader, build_dataset
from mmdet3d.models import build_detector
from tools.misc.fuse_conv_bn import fuse_module
@@ -23,7 +22,7 @@ def parse_args():
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn; this will slightly increase '
             'the inference speed')
    args = parser.parse_args()
    return args
...
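The --fuse-conv-bn flag relies on the standard BatchNorm-folding trick: at inference time a frozen BN layer is an affine map, so it can be absorbed into the preceding convolution's weights and bias. A generic sketch of the idea (not necessarily identical to tools.misc.fuse_conv_bn.fuse_module):

import torch
import torch.nn as nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      conv.stride, conv.padding, conv.dilation, conv.groups, bias=True)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # gamma / sqrt(var + eps)
    fused.weight.data = conv.weight.data * scale.reshape(-1, 1, 1, 1)
    bias = conv.bias.data if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.data = (bias - bn.running_mean) * scale + bn.bias.data
    return fused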
@@ -3,7 +3,6 @@ import argparse
import torch
from mmcv import Config, DictAction
from mmdet3d.models import build_model

try:
@@ -32,17 +31,16 @@ def parse_args():
        nargs='+',
        action=DictAction,
        help='override some settings in the used config; the key-value pair '
             'in xxx=yyy format will be merged into the config file. If the value to '
             'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
             'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
             'Note that the quotation marks are necessary and that no white space '
             'is allowed.')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    if args.modality == 'point':
@@ -52,7 +50,7 @@ def main():
        if len(args.shape) == 1:
            input_shape = (3, args.shape[0], args.shape[0])
        elif len(args.shape) == 2:
            input_shape = (3,) + tuple(args.shape)
        else:
            raise ValueError('invalid input shape')
    elif args.modality == 'multi':
...
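So the --shape handling above expands, for example:

# --shape 800        ->  input_shape = (3, 800, 800)
# --shape 1280 720   ->  input_shape = (3, 1280, 720)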
@@ -6,12 +6,11 @@ import mmcv
import numpy as np
from mmcv import track_iter_progress
from mmcv.ops import roi_align
from mmdet3d.core.bbox import box_np_ops as box_np_ops
from mmdet3d.datasets import build_dataset
from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
from pycocotools import mask as maskUtils
from pycocotools.coco import COCO


def _poly2mask(mask_ann, img_h, img_w):
...
@@ -3,7 +3,6 @@ import os
import mmcv
import numpy as np
from tools.data_converter.s3dis_data_utils import S3DISData, S3DISSegData
from tools.data_converter.scannet_data_utils import ScanNetData, ScanNetSegData
from tools.data_converter.sunrgbd_data_utils import SUNRGBDData
...
@@ -4,9 +4,9 @@ from pathlib import Path
import mmcv
import numpy as np
from nuscenes.utils.geometry_utils import view_points

from mmdet3d.core.bbox import box_np_ops, points_cam2img
from .kitti_data_utils import WaymoInfoGatherer, get_kitti_image_info
from .nuscenes_converter import post_process_coords
@@ -507,7 +507,7 @@ def get_2d_boxes(info, occluded, mono3d=True):
    src = np.array([0.5, 1.0, 0.5])
    loc = loc + dim * (dst - src)
    offset = (info['calib']['P2'][0, 3] - info['calib']['P0'][0, 3]) \
        / info['calib']['P2'][0, 0]
    loc_3d = np.copy(loc)
    loc_3d[0, 0] += offset
    gt_bbox_3d = np.concatenate([loc, dim, rot], axis=1).astype(np.float32)
...
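The offset computed above recovers the metric x-translation between cameras 0 and 2 from their KITTI-style projection matrices, where the translation enters the first row scaled by the focal length (this reading is an inference from the formula, not stated in the diff):

# P[0, 3] is approximately fx * tx, and fx = P2[0, 0], so:
# tx2 - tx0 is approximately (P2[0, 3] - P0[0, 3]) / P2[0, 0]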
@@ -151,7 +151,7 @@ def get_label_anno(label_path):
    if len(content) != 0 and len(content[0]) == 16:  # have score
        annotations['score'] = np.array([float(x[15]) for x in content])
    else:
        annotations['score'] = np.zeros((annotations['bbox'].shape[0],))
    index = list(range(num_objects)) + [-1] * (num_gt - num_objects)
    annotations['index'] = np.array(index, dtype=np.int32)
    annotations['group_ids'] = np.arange(num_gt, dtype=np.int32)
@@ -547,9 +547,9 @@ def add_difficulty_to_annos(info):
    occlusion = annos['occluded']
    truncation = annos['truncated']
    diff = []
    easy_mask = np.ones((len(dims),), dtype=np.bool)
    moderate_mask = np.ones((len(dims),), dtype=np.bool)
    hard_mask = np.ones((len(dims),), dtype=np.bool)
    i = 0
    for h, o, t in zip(height, occlusion, truncation):
        if o > max_occlusion[0] or h <= min_height[0] or t > max_trunc[0]:
...
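Note that np.bool above is a deprecated alias for the builtin bool and was removed in NumPy 1.24; presumably this is why the requirements pin numpy <1.24. On newer NumPy the equivalent would be:

easy_mask = np.ones((len(dims),), dtype=bool)  # plain bool replaces the removed np.bool alias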
@@ -6,9 +6,9 @@ from os import path as osp
import mmcv
import numpy as np
from lyft_dataset_sdk.lyftdataset import LyftDataset as Lyft
from pyquaternion import Quaternion

from mmdet3d.datasets import LyftDataset
from .nuscenes_converter import (get_2d_boxes, get_available_scenes,
                                 obtain_sensor2top)
...
@@ -11,7 +11,7 @@ def fix_lyft(root_folder='./data/lyft', version='v1.01'):
    root_folder = os.path.join(root_folder, f'{version}-train')
    lidar_path = os.path.join(root_folder, lidar_path)
    assert os.path.isfile(lidar_path), f'Please download the complete Lyft ' \
                                       f'dataset and make sure {lidar_path} is present.'
    points = np.fromfile(lidar_path, dtype=np.float32, count=-1)
    try:
        points.reshape([-1, 5])
...
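The reshape serves as an integrity check: a Lyft lidar .bin stores float32 records of five values per point (x, y, z, intensity, ring index), so a truncated file fails to factor into rows of five. A minimal standalone version of the same check:

import numpy as np

def lidar_bin_is_valid(path, row_len=5):
    pts = np.fromfile(path, dtype=np.float32, count=-1)
    return pts.size % row_len == 0  # True only when the file splits into whole points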