Unverified Commit ac0302ed authored by Jingwei Zhang, committed by GitHub

Support CenterFormer in `projects` (#2173)

* init centerformer in projects

* add readme and disable_tf32_switch

* using our iou3d and nms3d, using basicblock in mmdet, simplify code in projects

* remove attention.py and sparse_block.py

* only using single fold

* polish code

* polish code and add docstring

* add ut for disable_object_sample_hook

* modify data_root

* add ut

* update readme

* polish code

* fix docstring

* resolve comments

* modify project names

* modify project names and add _forward

* fix docstring

* remove disable_tf32
parent edc468bf
......@@ -67,11 +67,10 @@ def init_model(config: Union[str, Path, Config],
if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
dataset_meta = checkpoint['meta'].get('dataset_meta', None)
# save the dataset_meta in the model for convenience
if 'dataset_meta' in checkpoint.get('meta', {}):
# mmdet3d 1.x
model.dataset_meta = dataset_meta
model.dataset_meta = checkpoint['meta']['dataset_meta']
elif 'CLASSES' in checkpoint.get('meta', {}):
# < mmdet3d 1.x
classes = checkpoint['meta']['CLASSES']
......
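As a minimal sketch of what this tweak provides, assuming a trained checkpoint exists locally (both paths below are placeholders):

```python
from mmdet3d.apis import init_model

# Placeholder paths, shown only for illustration.
config = 'projects/CenterFormer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py'
checkpoint = 'work_dirs/centerformer/epoch_20.pth'

model = init_model(config, checkpoint, device='cuda:0')
# For mmdet3d 1.x checkpoints, the dataset meta saved during training is
# attached to the model for convenience.
print(model.dataset_meta)
```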
......@@ -532,6 +532,8 @@ class LoadPointsFromFile(BaseTransform):
or use_dim=[0, 1, 2, 3] to use the intensity dimension.
shift_height (bool): Whether to use shifted height. Defaults to False.
use_color (bool): Whether to use color features. Defaults to False.
norm_intensity (bool): Whether to normalize the intensity. Defaults to
False.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmengine.fileio.FileClient` for details.
Defaults to dict(backend='disk').
......@@ -544,6 +546,7 @@ class LoadPointsFromFile(BaseTransform):
use_dim: Union[int, List[int]] = [0, 1, 2],
shift_height: bool = False,
use_color: bool = False,
norm_intensity: bool = False,
file_client_args: dict = dict(backend='disk')
) -> None:
self.shift_height = shift_height
......@@ -557,6 +560,7 @@ class LoadPointsFromFile(BaseTransform):
self.coord_type = coord_type
self.load_dim = load_dim
self.use_dim = use_dim
self.norm_intensity = norm_intensity
self.file_client_args = file_client_args.copy()
self.file_client = None
......@@ -599,6 +603,10 @@ class LoadPointsFromFile(BaseTransform):
points = self._load_points(pts_file_path)
points = points.reshape(-1, self.load_dim)
points = points[:, self.use_dim]
if self.norm_intensity:
assert len(self.use_dim) >= 4, \
f'When using intensity norm, expect used dimensions >= 4, got {len(self.use_dim)}' # noqa: E501
points[:, 3] = np.tanh(points[:, 3])
attribute_dims = None
if self.shift_height:
......
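A tiny sketch of the effect of the new `norm_intensity` option, using synthetic values only:

```python
import numpy as np

# x, y, z, intensity; raw intensity values are sensor dependent and unbounded.
points = np.array([[1.0, 2.0, 0.5, 30.0],
                   [3.0, 1.0, 0.2, 255.0]], dtype=np.float32)

# With norm_intensity=True, LoadPointsFromFile squashes the 4th dimension
# into (-1, 1) with tanh, as in the added lines above.
points[:, 3] = np.tanh(points[:, 3])
print(points[:, 3])  # large raw intensities map to values close to 1.0
```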
......@@ -359,6 +359,7 @@ class ObjectSample(BaseTransform):
db_sampler['type'] = 'DataBaseSampler'
self.db_sampler = TRANSFORMS.build(db_sampler)
self.use_ground_plane = use_ground_plane
self.disabled = False
@staticmethod
def remove_points_in_boxes(points: BasePoints,
......@@ -387,6 +388,9 @@ class ObjectSample(BaseTransform):
'points', 'gt_bboxes_3d', 'gt_labels_3d' keys are updated
in the result dict.
"""
if self.disabled:
return input_dict
gt_bboxes_3d = input_dict['gt_bboxes_3d']
gt_labels_3d = input_dict['gt_labels_3d']
......
# Copyright (c) OpenMMLab. All rights reserved.
from .benchmark_hook import BenchmarkHook
from .disable_object_sample_hook import DisableObjectSampleHook
from .visualization_hook import Det3DVisualizationHook
__all__ = ['Det3DVisualizationHook', 'BenchmarkHook']
__all__ = [
'Det3DVisualizationHook', 'BenchmarkHook', 'DisableObjectSampleHook'
]
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.runner import Runner
from mmdet3d.datasets.transforms import ObjectSample
from mmdet3d.registry import HOOKS
@HOOKS.register_module()
class DisableObjectSampleHook(Hook):
"""The hook of disabling augmentations during training.
Args:
disable_after_epoch (int): The number of epochs after which
the ``ObjectSample`` augmentation will be disabled during training.
Defaults to 15.
"""
def __init__(self, disable_after_epoch: int = 15):
self.disable_after_epoch = disable_after_epoch
self._restart_dataloader = False
def before_train_epoch(self, runner: Runner):
"""Close augmentation.
Args:
runner (Runner): The runner.
"""
epoch = runner.epoch
train_loader = runner.train_dataloader
model = runner.model
# TODO: refactor after mmengine using model wrapper
if is_model_wrapper(model):
model = model.module
if epoch == self.disable_after_epoch:
runner.logger.info('Disable ObjectSample')
for transform in runner.train_dataloader.dataset.pipeline.transforms: # noqa: E501
if isinstance(transform, ObjectSample):
assert hasattr(transform, 'disabled')
transform.disabled = True
# The dataset pipeline cannot be updated when persistent_workers
# is True, so we need to force the dataloader's multi-process
# restart. This is a very hacky approach.
if hasattr(train_loader, 'persistent_workers'
) and train_loader.persistent_workers is True:
train_loader._DataLoader__initialized = False
train_loader._iterator = None
self._restart_dataloader = True
else:
# Once the restart is complete, we need to restore
# the initialization flag.
if self._restart_dataloader:
train_loader._DataLoader__initialized = True
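A minimal config sketch for enabling the hook through the standard `custom_hooks` mechanism; the epoch threshold mirrors the default and is only illustrative:

```python
custom_hooks = [
    dict(type='DisableObjectSampleHook', disable_after_epoch=15)
]
```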
......@@ -6,9 +6,7 @@ try:
except ImportError:
IS_SPCONV2_AVAILABLE = False
else:
if hasattr(spconv,
'__version__') and spconv.__version__ >= '2.0.0' and hasattr(
spconv, 'pytorch'):
if hasattr(spconv, '__version__') and spconv.__version__ >= '2.0.0':
IS_SPCONV2_AVAILABLE = register_spconv2()
else:
IS_SPCONV2_AVAILABLE = False
......
# CenterFormer: Center-based Transformer for 3D Object Detection
> [CenterFormer: Center-based Transformer for 3D Object Detection](https://arxiv.org/abs/2209.05588)
<!-- [ALGORITHM] -->
## Abstract
Query-based transformer has shown great potential in constructing long-range attention in many image-domain tasks, but has rarely been considered in LiDAR-based 3D object detection due to the overwhelming size of the point cloud data. In this paper, we propose CenterFormer, a center-based transformer network for 3D object detection. CenterFormer first uses a center heatmap to select center candidates on top of a standard voxel-based point cloud encoder. It then uses the feature of the center candidate as the query embedding in the transformer. To further aggregate features from multiple frames, we design an approach to fuse features through cross-attention. Lastly, regression heads are added to predict the bounding box on the output center feature representation. Our design reduces the convergence difficulty and computational complexity of the transformer structure. The results show significant improvements over the strong baseline of anchor-free object detection networks. CenterFormer achieves state-of-the-art performance for a single model on the Waymo Open Dataset, with 73.7% mAPH on the validation set and 75.6% mAPH on the test set, significantly outperforming all previously published CNN and transformer-based methods. Our code is publicly available at https://github.com/TuSimple/centerformer.
<div align=center>
<img src="https://user-images.githubusercontent.com/34888372/209500088-b707d7cd-d4d5-4f20-8fdf-a2c7ad15df34.png" width="800"/>
</div>
## Introduction
We implement CenterFormer and provide the results and checkpoints on the Waymo dataset.

We follow the below style to name config files. Contributors are advised to follow the same style. `{xxx}` is a required field and `[yyy]` is optional.

- `{model}`: model type like `centerpoint`.
- `{model setting}`: voxel size and voxel type like `01voxel`, `02pillar`.
- `{backbone}`: backbone type like `second`.
- `{neck}`: neck type like `secfpn`.
- `[batch_per_gpu x gpu]`: GPUs and samples per GPU; `4x8` is used by default.
- `{schedule}`: training schedule; options are `1x`, `2x`, `20e`, etc. `1x` and `2x` mean 12 and 24 epochs respectively. `20e` is adopted in cascade models and denotes 20 epochs. For `1x`/`2x`, the initial learning rate decays by a factor of 10 at the 8th/16th and 11th/22nd epochs. For `20e`, the initial learning rate decays by a factor of 10 at the 16th and 19th epochs.
- `{dataset}`: dataset like `nus-3d`, `kitti-3d`, `lyft-3d`, `scannet-3d`, `sunrgbd-3d`. We also indicate the number of classes we are using if there exist multiple settings, e.g., `kitti-3d-3class` and `kitti-3d-car` mean training on the KITTI dataset with 3 classes and a single class, respectively.
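For example, the config name used throughout this project can be read as:

- `centerformer`: model type.
- `voxel01`: voxel-based setting with a 0.1 voxel size.
- `second-attn` / `secfpn-attn`: SECOND backbone and SECONDFPN neck, both with attention.
- `4xb4`: 4 GPUs with 4 samples per GPU.
- `cyclic-20e`: cyclic learning-rate schedule for 20 epochs.
- `waymoD5-3d-class`: Waymo dataset with a load interval of 5.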
## Usage
<!-- For a typical model, this section should contain the commands for training and testing. You are also suggested to dump your environment specification to env.yml by `conda env export > env.yml`. -->
### Training commands
In MMDetection3D's root directory, run the following command to train the model:
```bash
python tools/train.py projects/CenterFormer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py
```
For multi-gpu training, run:
```bash
python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=${NUM_GPUS} --master_port=29506 --master_addr="127.0.0.1" tools/train.py projects/CenterFormer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py
```
### Testing commands
In MMDetection3D's root directory, run the following command to test the model:
```bash
python tools/test.py projects/CenterFormer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py ${CHECKPOINT_PATH}
```
## Results and models
### Waymo
| Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | Mem (GB) | Inf time (fps) | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
| :----------------------------------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :------: | :------------: | :----: | :-----: | :----: | :---------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| [SECFPN_WithAttention](./configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py) | 5 | voxel (0.1) | ✓ | × | 14.8 | | 72.2 | 69.5 | 65.9 | 63.3 | [log](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/centerformer/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class_20221227_205613-70c9ad37.json) |
**Note** that `SECFPN_WithAttention` denotes both SECOND and SECONDFPN with ChannelAttention and SpatialAttention.
## Citation
```latex
@InProceedings{Zhou_centerformer,
title = {CenterFormer: Center-based Transformer for 3D Object Detection},
author = {Zhou, Zixiang and Zhao, Xiangchen and Wang, Yu and Wang, Panqu and Foroosh, Hassan},
booktitle = {ECCV},
year = {2022}
}
```
from .bbox_ops import nms_iou3d
from .centerformer import CenterFormer
from .centerformer_backbone import (DeformableDecoderRPN,
MultiFrameDeformableDecoderRPN)
from .centerformer_head import CenterFormerBboxHead
from .losses import FastFocalLoss
__all__ = [
'CenterFormer', 'DeformableDecoderRPN', 'CenterFormerBboxHead',
'FastFocalLoss', 'nms_iou3d', 'MultiFrameDeformableDecoderRPN'
]
import torch
from mmcv.utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['iou3d_nms3d_forward'])
def nms_iou3d(boxes, scores, thresh, pre_maxsize=None, post_max_size=None):
"""NMS function GPU implementation (using IoU3D). The difference between
this implementation and nms3d in MMCV is that we add `pre_maxsize` and
`post_max_size` before and after NMS respectively.
Args:
boxes (Tensor): Input boxes with the shape of [N, 7]
([cx, cy, cz, l, w, h, theta]).
scores (Tensor): Scores of boxes with the shape of [N].
thresh (float): Overlap threshold of NMS.
pre_maxsize (int, optional): Max number of boxes kept before NMS.
Defaults to None.
post_max_size (int, optional): Max number of boxes kept after NMS.
Defaults to None.
Returns:
Tensor: Indexes after NMS.
"""
# TODO: directly refactor ``nms3d`` in MMCV
assert boxes.size(1) == 7, 'Input boxes shape should be (N, 7)'
order = scores.sort(0, descending=True)[1]
if pre_maxsize is not None:
order = order[:pre_maxsize]
boxes = boxes[order].contiguous()
keep = boxes.new_zeros(boxes.size(0), dtype=torch.long)
num_out = boxes.new_zeros(size=(), dtype=torch.long)
ext_module.iou3d_nms3d_forward(
boxes, keep, num_out, nms_overlap_thresh=thresh)
keep = order[keep[:num_out].to(boxes.device)].contiguous()
if post_max_size is not None:
keep = keep[:post_max_size]
return keep
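A usage sketch, assuming the compiled MMCV CUDA extension and a CUDA device are available, and assuming the project package layout `projects/CenterFormer/centerformer/` for the import path (values are random and only illustrative):

```python
import torch

from projects.CenterFormer.centerformer.bbox_ops import nms_iou3d

# Random boxes in (cx, cy, cz, l, w, h, theta) format, scores in [0, 1).
boxes = torch.rand(100, 7, device='cuda')
scores = torch.rand(100, device='cuda')

# Keep at most 1000 boxes before NMS and 83 afterwards (illustrative values).
keep = nms_iou3d(boxes, scores, thresh=0.7, pre_maxsize=1000, post_max_size=83)
print(keep.numel(), 'boxes kept')
```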
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional
import torch
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm
from mmdet3d.models.detectors import Base3DDetector
from mmdet3d.registry import MODELS
from mmdet3d.structures import Det3DDataSample
@MODELS.register_module()
class CenterFormer(Base3DDetector):
"""Base class of center-based 3D detector.
Args:
voxel_encoder (dict, optional): Point voxelization
encoder layer. Defaults to None.
middle_encoder (dict, optional): Middle encoder layer
of points cloud modality. Defaults to None.
backbone (dict, optional): Backbone of extracting
points features. Defaults to None.
neck (dict, optional): Neck of extracting
points features. Defaults to None.
bbox_head (dict, optional): Bboxes head of
point cloud modality. Defaults to None.
train_cfg (dict, optional): Train config of model.
Defaults to None.
test_cfg (dict, optional): Test config of model.
Defaults to None.
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
data_preprocessor (dict or ConfigDict, optional): The pre-process
config of :class:`Det3DDataPreprocessor`. Defaults to None.
"""
def __init__(self,
voxel_encoder: Optional[dict] = None,
middle_encoder: Optional[dict] = None,
backbone: Optional[dict] = None,
neck: Optional[dict] = None,
bbox_head: Optional[dict] = None,
train_cfg: Optional[dict] = None,
test_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None,
data_preprocessor: Optional[dict] = None,
**kwargs):
super(CenterFormer, self).__init__(
init_cfg=init_cfg, data_preprocessor=data_preprocessor, **kwargs)
if voxel_encoder:
self.voxel_encoder = MODELS.build(voxel_encoder)
if middle_encoder:
self.middle_encoder = MODELS.build(middle_encoder)
if backbone:
backbone.update(train_cfg=train_cfg, test_cfg=test_cfg)
self.backbone = MODELS.build(backbone)
if neck is not None:
self.neck = MODELS.build(neck)
if bbox_head:
bbox_head.update(train_cfg=train_cfg, test_cfg=test_cfg)
self.bbox_head = MODELS.build(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
def init_weights(self):
for m in self.modules():
if isinstance(m, _BatchNorm):
torch.nn.init.uniform_(m.weight)
@property
def with_bbox(self):
"""bool: Whether the detector has a 3D box head."""
return hasattr(self, 'bbox_head') and self.bbox_head is not None
@property
def with_backbone(self):
"""bool: Whether the detector has a 3D backbone."""
return hasattr(self, 'backbone') and self.backbone is not None
@property
def with_voxel_encoder(self):
"""bool: Whether the detector has a voxel encoder."""
return hasattr(self,
'voxel_encoder') and self.voxel_encoder is not None
@property
def with_middle_encoder(self):
"""bool: Whether the detector has a middle encoder."""
return hasattr(self,
'middle_encoder') and self.middle_encoder is not None
def _forward(self):
pass
def extract_feat(self, batch_inputs_dict: dict,
batch_input_metas: List[dict]) -> tuple:
"""Extract features from images and points.
Args:
batch_inputs_dict (dict): Dict of batch inputs. It
contains
- points (List[tensor]): Point cloud of multiple inputs.
- imgs (tensor): Image tensor with shape (B, C, H, W).
batch_input_metas (list[dict]): Meta information of multiple inputs
in a batch.
Returns:
tuple: Two elements in tuple arrange as
image features and point cloud features.
"""
voxel_dict = batch_inputs_dict.get('voxels', None)
voxel_features, feature_coors = self.voxel_encoder(
voxel_dict['voxels'], voxel_dict['coors'])
batch_size = voxel_dict['coors'][-1, 0].item() + 1
x = self.middle_encoder(voxel_features, feature_coors, batch_size)
return x
def loss(self, batch_inputs_dict: Dict[str, torch.Tensor],
batch_data_samples: List[Det3DDataSample],
**kwargs) -> Dict[str, Tensor]:
"""
Args:
batch_inputs_dict (dict): The model input dict which includes
'points' and `imgs` keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- imgs (torch.Tensor): Tensor of batch images, has shape
(B, C, H, W).
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instances_3d`.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
batch_input_metas = [item.metainfo for item in batch_data_samples]
pts_feats = self.extract_feat(batch_inputs_dict, batch_input_metas)
preds, batch_targets = self.backbone(pts_feats, batch_data_samples)
preds = self.bbox_head(preds)
losses = dict()
losses.update(self.bbox_head.loss(preds, batch_targets))
return losses
def predict(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
batch_data_samples: List[Det3DDataSample],
**kwargs) -> List[Det3DDataSample]:
"""Forward of testing.
Args:
batch_inputs_dict (dict): The model input dict which include
'points' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`.
Returns:
list[:obj:`Det3DDataSample`]: Detection results of the
input samples. Each Det3DDataSample usually contains
'pred_instances_3d'. And the ``pred_instances_3d`` usually
contains the following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instances, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bbox_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes,
contains a tensor with shape (num_instances, 7).
"""
batch_input_metas = [item.metainfo for item in batch_data_samples]
pts_feats = self.extract_feat(batch_inputs_dict, batch_input_metas)
preds, _ = self.backbone(pts_feats, batch_data_samples)
preds = self.bbox_head(preds)
results_list_3d = self.bbox_head.predict(preds, batch_input_metas)
detsamples = self.add_pred_to_datasample(batch_data_samples,
results_list_3d)
return detsamples
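As a quick sanity-check sketch, the detector can be built directly from the project config, assuming that config registers the CenterFormer modules through `custom_imports` (the path is the one used in the README above):

```python
from mmengine.config import Config
from mmengine.registry import init_default_scope

from mmdet3d.registry import MODELS

cfg = Config.fromfile(
    'projects/CenterFormer/configs/'
    'centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py')
init_default_scope('mmdet3d')
model = MODELS.build(cfg.model)
print(type(model).__name__)  # CenterFormer
```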
# modified from https://github.com/TuSimple/centerformer/blob/master/det3d/models/necks/rpn_transformer.py # noqa
from typing import List, Tuple
import numpy as np
import torch
from mmcv.cnn import build_norm_layer
from mmdet.models.utils import multi_apply
from mmengine.logging import print_log
from mmengine.structures import InstanceData
from torch import Tensor, nn
from mmdet3d.models.utils import draw_heatmap_gaussian, gaussian_radius
from mmdet3d.registry import MODELS
from mmdet3d.structures import center_to_corner_box2d
from .transformer import DeformableTransformerDecoder
class ChannelAttention(nn.Module):
def __init__(self, in_planes, ratio=16):
super(ChannelAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc = nn.Sequential(
nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
nn.ReLU(),
nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False),
)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = self.fc(self.avg_pool(x))
max_out = self.fc(self.max_pool(x))
out = avg_out + max_out
return self.sigmoid(out) * x
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
super(SpatialAttention, self).__init__()
self.conv1 = nn.Conv2d(
2, 1, kernel_size, padding=kernel_size // 2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
y = torch.cat([avg_out, max_out], dim=1)
y = self.conv1(y)
return self.sigmoid(y) * x
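# Usage sketch (illustrative shapes): both attention modules keep the input
# shape, so they can be chained on a BEV feature map, e.g.
#   feat = torch.rand(2, 128, 188, 188)             # B, C, H, W
#   feat = ChannelAttention(128)(feat)              # channel re-weighting
#   feat = SpatialAttention(kernel_size=7)(feat)    # spatial re-weighting
# after which feat still has shape (2, 128, 188, 188).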
class MultiFrameSpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
super(MultiFrameSpatialAttention, self).__init__()
self.conv1 = nn.Conv2d(
2, 1, kernel_size, padding=kernel_size // 2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, curr, prev):
avg_out = torch.mean(curr, dim=1, keepdim=True)
max_out, _ = torch.max(curr, dim=1, keepdim=True)
y = torch.cat([avg_out, max_out], dim=1)
y = self.conv1(y)
return self.sigmoid(y) * prev
class BaseDecoderRPN(nn.Module):
def __init__(
self,
layer_nums, # [2,2,2]
ds_num_filters, # [128,256,64]
num_input_features, # 256
transformer_config=None,
hm_head_layer=2,
corner_head_layer=2,
corner=False,
assign_label_window_size=1,
classes=3,
use_gt_training=False,
norm_cfg=None,
logger=None,
init_bias=-2.19,
score_threshold=0.1,
obj_num=500,
**kwargs):
super(BaseDecoderRPN, self).__init__()
self._layer_strides = [1, 2, -4]
self._num_filters = ds_num_filters
self._layer_nums = layer_nums
self._num_input_features = num_input_features
self.score_threshold = score_threshold
self.transformer_config = transformer_config
self.corner = corner
self.obj_num = obj_num
self.use_gt_training = use_gt_training
self.window_size = assign_label_window_size**2
self.cross_attention_kernel_size = [3, 3, 3]
self.batch_id = None
if norm_cfg is None:
norm_cfg = dict(type='BN', eps=1e-3, momentum=0.01)
self._norm_cfg = norm_cfg
assert len(self._layer_strides) == len(self._layer_nums)
assert len(self._num_filters) == len(self._layer_nums)
assert self.transformer_config is not None
in_filters = [
self._num_input_features,
self._num_filters[0],
self._num_filters[1],
]
blocks = []
for i, layer_num in enumerate(self._layer_nums):
block, num_out_filters = self._make_layer(
in_filters[i],
self._num_filters[i],
layer_num,
stride=self._layer_strides[i],
)
blocks.append(block)
self.blocks = nn.ModuleList(blocks)
self.up = nn.Sequential(
nn.ConvTranspose2d(
self._num_filters[0],
self._num_filters[2],
2,
stride=2,
bias=False),
build_norm_layer(self._norm_cfg, self._num_filters[2])[1],
nn.ReLU())
# heatmap prediction
hm_head = []
for i in range(hm_head_layer - 1):
hm_head.append(
nn.Conv2d(
self._num_filters[-1] * 2,
64,
kernel_size=3,
stride=1,
padding=1,
bias=True,
))
hm_head.append(build_norm_layer(self._norm_cfg, 64)[1])
hm_head.append(nn.ReLU())
hm_head.append(
nn.Conv2d(
64, classes, kernel_size=3, stride=1, padding=1, bias=True))
hm_head[-1].bias.data.fill_(init_bias)
self.hm_head = nn.Sequential(*hm_head)
if self.corner:
self.corner_head = []
for i in range(corner_head_layer - 1):
self.corner_head.append(
nn.Conv2d(
self._num_filters[-1] * 2,
64,
kernel_size=3,
stride=1,
padding=1,
bias=True,
))
self.corner_head.append(
build_norm_layer(self._norm_cfg, 64)[1])
self.corner_head.append(nn.ReLU())
self.corner_head.append(
nn.Conv2d(
64, 1, kernel_size=3, stride=1, padding=1, bias=True))
self.corner_head[-1].bias.data.fill_(init_bias)
self.corner_head = nn.Sequential(*self.corner_head)
def _make_layer(self, inplanes, planes, num_blocks, stride=1):
if stride > 0:
block = [
nn.ZeroPad2d(1),
nn.Conv2d(inplanes, planes, 3, stride=stride, bias=False),
build_norm_layer(self._norm_cfg, planes)[1],
nn.ReLU(),
]
else:
block = [
nn.ConvTranspose2d(
inplanes, planes, -stride, stride=-stride, bias=False),
build_norm_layer(self._norm_cfg, planes)[1],
nn.ReLU(),
]
for j in range(num_blocks):
block.append(nn.Conv2d(planes, planes, 3, padding=1, bias=False))
block.append(build_norm_layer(self._norm_cfg, planes)[1], )
block.append(nn.ReLU())
block.append(ChannelAttention(planes))
block.append(SpatialAttention())
block = nn.Sequential(*block)
return block, planes
def forward(self, x, example=None):
pass
def get_multi_scale_feature(self, center_pos, feats):
"""
Args:
center_pos: center coor at the lowest scale feature map [B 500 2]
feats: multi scale BEV feature 3*[B C H W]
Returns:
neighbor_feat: [B 500 K C]
neighbor_pos: [B 500 K 2]
"""
kernel_size = self.cross_attention_kernel_size
batch, num_cls, H, W = feats[0].size()
center_num = center_pos.shape[1]
relative_pos_list = []
neighbor_feat_list = []
for i, k in enumerate(kernel_size):
neighbor_coords = torch.arange(-(k // 2), (k // 2) + 1)
neighbor_coords = torch.flatten(
torch.stack(
torch.meshgrid([neighbor_coords, neighbor_coords]), dim=0),
1,
) # [2, k]
neighbor_coords = (neighbor_coords.permute(
1,
0).contiguous().to(center_pos)) # relative coordinate [k, 2]
neighbor_coords = (center_pos[:, :, None, :] // (2**i) +
neighbor_coords[None, None, :, :]
) # coordinates [B, 500, k, 2]
neighbor_coords = torch.clamp(
neighbor_coords, min=0,
max=H // (2**i) - 1) # prevent out of bound
feat_id = (neighbor_coords[:, :, :, 1] * (W // (2**i)) +
neighbor_coords[:, :, :, 0]) # pixel id [B, 500, k]
feat_id = feat_id.reshape(batch, -1) # pixel id [B, 500*k]
selected_feat = (
feats[i].reshape(batch, num_cls, (H * W) // (4**i)).permute(
0, 2, 1).contiguous()[self.batch_id.repeat(1, k**2),
feat_id]) # B, 500*k, C
neighbor_feat_list.append(
selected_feat.reshape(batch, center_num, -1,
num_cls)) # B, 500, k, C
relative_pos_list.append(neighbor_coords * (2**i)) # B, 500, k, 2
neighbor_pos = torch.cat(relative_pos_list, dim=2) # B, 500, K, 2/3
neighbor_feats = torch.cat(neighbor_feat_list, dim=2) # B, 500, K, C
return neighbor_feats, neighbor_pos
def get_multi_scale_feature_multiframe(self, center_pos, feats, timeframe):
"""
Args:
center_pos: center coor at the lowest scale feature map [B 500 2]
feats: multi scale BEV feature (3+k)*[B C H W]
timeframe: timeframe [B,k]
Returns:
neighbor_feat: [B 500 K C]
neighbor_pos: [B 500 K 2]
neighbor_time: [B 500 K 1]
"""
kernel_size = self.cross_attention_kernel_size
batch, num_cls, H, W = feats[0].size()
center_num = center_pos.shape[1]
relative_pos_list = []
neighbor_feat_list = []
timeframe_list = []
for i, k in enumerate(kernel_size):
neighbor_coords = torch.arange(-(k // 2), (k // 2) + 1)
neighbor_coords = torch.flatten(
torch.stack(
torch.meshgrid([neighbor_coords, neighbor_coords]), dim=0),
1,
) # [2, k]
neighbor_coords = (neighbor_coords.permute(
1,
0).contiguous().to(center_pos)) # relative coordinate [k, 2]
neighbor_coords = (center_pos[:, :, None, :] // (2**i) +
neighbor_coords[None, None, :, :]
) # coordinates [B, 500, k, 2]
neighbor_coords = torch.clamp(
neighbor_coords, min=0,
max=H // (2**i) - 1) # prevent out of bound
feat_id = (neighbor_coords[:, :, :, 1] * (W // (2**i)) +
neighbor_coords[:, :, :, 0]) # pixel id [B, 500, k]
feat_id = feat_id.reshape(batch, -1) # pixel id [B, 500*k]
selected_feat = (
feats[i].reshape(batch, num_cls, (H * W) // (4**i)).permute(
0, 2, 1).contiguous()[self.batch_id.repeat(1, k**2),
feat_id]) # B, 500*k, C
neighbor_feat_list.append(
selected_feat.reshape(batch, center_num, -1,
num_cls)) # B, 500, k, C
relative_pos_list.append(neighbor_coords * (2**i)) # B, 500, k, 2
timeframe_list.append(
torch.full_like(neighbor_coords[:, :, :, 0:1], 0)) # B, 500, k
if i == 0:
# add previous frame feature
for frame_num in range(feats[-1].shape[1]):
selected_feat = (feats[-1][:, frame_num, :, :, :].reshape(
batch, num_cls, (H * W) // (4**i)).permute(
0, 2,
1).contiguous()[self.batch_id.repeat(1, k**2),
feat_id]) # B, 500*k, C
neighbor_feat_list.append(
selected_feat.reshape(batch, center_num, -1, num_cls))
relative_pos_list.append(neighbor_coords * (2**i))
time = timeframe[:, frame_num + 1].to(selected_feat) # B
timeframe_list.append(
time[:, None, None, None] * torch.full_like(
neighbor_coords[:, :, :, 0:1], 1)) # B, 500, k
neighbor_pos = torch.cat(relative_pos_list, dim=2) # B, 500, K, 2/3
neighbor_feats = torch.cat(neighbor_feat_list, dim=2) # B, 500, K, C
neighbor_time = torch.cat(timeframe_list, dim=2) # B, 500, K, 1
return neighbor_feats, neighbor_pos, neighbor_time
@MODELS.register_module()
class DeformableDecoderRPN(BaseDecoderRPN):
"""The original implement of CenterFormer modules.
It fuse the backbone, neck and heatmap head into one module. The backbone
is `SECOND` with attention and the neck is `SECONDFPN` with attention.
TODO: split this module into backbone、neck and head.
"""
def __init__(self,
layer_nums,
ds_num_filters,
num_input_features,
tasks=dict(),
transformer_config=None,
hm_head_layer=2,
corner_head_layer=2,
corner=False,
parametric_embedding=False,
assign_label_window_size=1,
classes=3,
use_gt_training=False,
norm_cfg=None,
logger=None,
init_bias=-2.19,
score_threshold=0.1,
obj_num=500,
train_cfg=None,
test_cfg=None,
**kwargs):
super(DeformableDecoderRPN, self).__init__(
layer_nums,
ds_num_filters,
num_input_features,
transformer_config,
hm_head_layer,
corner_head_layer,
corner,
assign_label_window_size,
classes,
use_gt_training,
norm_cfg,
logger,
init_bias,
score_threshold,
obj_num,
)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.tasks = tasks
self.class_names = [t['class_names'] for t in tasks]
self.transformer_decoder = DeformableTransformerDecoder(
self._num_filters[-1] * 2,
depth=transformer_config.depth,
n_heads=transformer_config.n_heads,
dim_single_head=transformer_config.dim_single_head,
dim_ffn=transformer_config.dim_ffn,
dropout=transformer_config.dropout,
out_attention=transformer_config.out_attn,
n_points=transformer_config.get('n_points', 9),
)
self.pos_embedding_type = transformer_config.get(
'pos_embedding_type', 'linear')
if self.pos_embedding_type == 'linear':
self.pos_embedding = nn.Linear(2, self._num_filters[-1] * 2)
else:
raise NotImplementedError()
self.parametric_embedding = parametric_embedding
if self.parametric_embedding:
self.query_embed = nn.Embedding(self.obj_num,
self._num_filters[-1] * 2)
nn.init.uniform_(self.query_embed.weight, -1.0, 1.0)
print_log('Finish RPN_transformer_deformable Initialization',
'current')
def _sigmoid(self, x):
y = torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4)
return y
def forward(self, x, batch_data_samples):
batch_gt_instance_3d = []
for data_sample in batch_data_samples:
batch_gt_instance_3d.append(data_sample.gt_instances_3d)
# FPN
x = self.blocks[0](x)
x_down = self.blocks[1](x)
x_up = torch.cat([self.blocks[2](x_down), self.up(x)], dim=1)
# heatmap head
hm = self.hm_head(x_up)
if self.corner and self.corner_head.training:
corner_hm = self.corner_head(x_up)
corner_hm = self._sigmoid(corner_hm)
# find top K center location
hm = self._sigmoid(hm)
batch, num_cls, H, W = hm.size()
scores, labels = torch.max(
hm.reshape(batch, num_cls, H * W), dim=1) # b,H*W
self.batch_id = torch.from_numpy(np.indices(
(batch, self.obj_num))[0]).to(labels)
if self.training:
heatmaps, anno_boxes, gt_inds, gt_masks, corner_heatmaps, cat_labels = self.get_targets( # noqa: E501
batch_gt_instance_3d)
batch_targets = dict(
ind=gt_inds,
mask=gt_masks,
hm=heatmaps,
anno_box=anno_boxes,
corners=corner_heatmaps,
cat=cat_labels)
inds = gt_inds[0][:, (self.window_size // 2)::self.window_size]
masks = gt_masks[0][:, (self.window_size // 2)::self.window_size]
batch_id_gt = torch.from_numpy(
np.indices((batch, inds.shape[1]))[0]).to(labels)
scores[batch_id_gt, inds] = scores[batch_id_gt, inds] + masks
order = scores.sort(1, descending=True)[1]
order = order[:, :self.obj_num]
scores[batch_id_gt, inds] = scores[batch_id_gt, inds] - masks
else:
order = scores.sort(1, descending=True)[1]
order = order[:, :self.obj_num]
batch_targets = None
scores = torch.gather(scores, 1, order)
labels = torch.gather(labels, 1, order)
mask = scores > self.score_threshold
ct_feat = x_up.reshape(batch, -1, H * W).transpose(2, 1).contiguous()
ct_feat = ct_feat[self.batch_id, order] # B, 500, C
# create position embedding for each center
y_coor = order // W
x_coor = order - y_coor * W
y_coor, x_coor = y_coor.to(ct_feat), x_coor.to(ct_feat)
y_coor, x_coor = y_coor / H, x_coor / W
pos_features = torch.stack([x_coor, y_coor], dim=2)
if self.parametric_embedding:
ct_feat = self.query_embed.weight
ct_feat = ct_feat.unsqueeze(0).expand(batch, -1, -1)
# run transformer
src = torch.cat(
(
x_up.reshape(batch, -1, H * W).transpose(2, 1).contiguous(),
x.reshape(batch, -1,
(H * W) // 4).transpose(2, 1).contiguous(),
x_down.reshape(batch, -1,
(H * W) // 16).transpose(2, 1).contiguous(),
),
dim=1,
) # B ,sum(H*W), C
spatial_shapes = torch.as_tensor(
[(H, W), (H // 2, W // 2), (H // 4, W // 4)],
dtype=torch.long,
device=ct_feat.device,
)
level_start_index = torch.cat((
spatial_shapes.new_zeros((1, )),
spatial_shapes.prod(1).cumsum(0)[:-1],
))
transformer_out = self.transformer_decoder(
ct_feat,
self.pos_embedding,
src,
spatial_shapes,
level_start_index,
center_pos=pos_features,
) # (B,N,C)
ct_feat = (transformer_out['ct_feat'].transpose(2, 1).contiguous()
) # B, C, 500
out_dict = {
'hm': hm,
'scores': scores,
'labels': labels,
'order': order,
'ct_feat': ct_feat,
'mask': mask,
}
if 'out_attention' in transformer_out:
out_dict.update(
{'out_attention': transformer_out['out_attention']})
if self.corner and self.corner_head.training:
out_dict.update({'corner_hm': corner_hm})
return out_dict, batch_targets
def get_targets(
self,
batch_gt_instances_3d: List[InstanceData],
) -> Tuple[List[Tensor]]:
"""Generate targets. How each output is transformed: Each nested list
is transposed so that all same-index elements in each sub-list (1, ...,
N) become the new sub-lists.
[ [a0, a1, a2, ... ], [b0, b1, b2, ... ], ... ]
==> [ [a0, b0, ... ], [a1, b1, ... ], [a2, b2, ... ] ]
The new transposed nested list is converted into a list of N
tensors generated by concatenating tensors in the new sub-lists.
[ tensor0, tensor1, tensor2, ... ]
Args:
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes_3d`` and
``labels_3d`` attributes.
Returns:
tuple[list[torch.Tensor]]: Tuple of targets including
the following results in order.
- list[torch.Tensor]: Heatmap scores.
- list[torch.Tensor]: Ground truth boxes.
- list[torch.Tensor]: Indexes indicating the
position of the valid boxes.
- list[torch.Tensor]: Masks indicating which
boxes are valid.
- list[torch.Tensor]: Category labels.
"""
heatmaps, anno_boxes, inds, masks, corner_heatmaps, cat_labels = multi_apply( # noqa: E501
self.get_targets_single, batch_gt_instances_3d)
# Transpose heatmaps
heatmaps = list(map(list, zip(*heatmaps)))
heatmaps = [torch.stack(hms_) for hms_ in heatmaps]
# Transpose heatmaps
corner_heatmaps = list(map(list, zip(*corner_heatmaps)))
corner_heatmaps = [torch.stack(hms_) for hms_ in corner_heatmaps]
# Transpose anno_boxes
anno_boxes = list(map(list, zip(*anno_boxes)))
anno_boxes = [torch.stack(anno_boxes_) for anno_boxes_ in anno_boxes]
# Transpose inds
inds = list(map(list, zip(*inds)))
inds = [torch.stack(inds_) for inds_ in inds]
# Transpose inds
masks = list(map(list, zip(*masks)))
masks = [torch.stack(masks_) for masks_ in masks]
# Transpose cat_labels
cat_labels = list(map(list, zip(*cat_labels)))
cat_labels = [torch.stack(labels_) for labels_ in cat_labels]
return heatmaps, anno_boxes, inds, masks, corner_heatmaps, cat_labels
def get_targets_single(self,
gt_instances_3d: InstanceData) -> Tuple[Tensor]:
"""Generate training targets for a single sample.
Args:
gt_instances_3d (:obj:`InstanceData`): Gt_instances of
single data sample. It usually includes
``bboxes_3d`` and ``labels_3d`` attributes.
Returns:
tuple[list[torch.Tensor]]: Tuple of target including
the following results in order.
- list[torch.Tensor]: Heatmap scores.
- list[torch.Tensor]: Ground truth boxes.
- list[torch.Tensor]: Indexes indicating the position
of the valid boxes.
- list[torch.Tensor]: Masks indicating which boxes
are valid.
- list[torch.Tensor]: Category labels.
"""
gt_labels_3d = gt_instances_3d.labels_3d
gt_bboxes_3d = gt_instances_3d.bboxes_3d
device = gt_labels_3d.device
gt_bboxes_3d = torch.cat(
(gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]),
dim=1).to(device)
max_objs = self.train_cfg['max_objs'] * self.train_cfg['dense_reg']
grid_size = torch.tensor(self.train_cfg['grid_size'])
pc_range = torch.tensor(self.train_cfg['point_cloud_range'])
voxel_size = torch.tensor(self.train_cfg['voxel_size'])
feature_map_size = grid_size[:2] // self.train_cfg['out_size_factor']
# reorganize the gt_dict by tasks
task_masks = []
flag = 0
for class_name in self.class_names:
task_masks.append([
torch.where(gt_labels_3d == class_name.index(i) + flag)
for i in class_name
])
flag += len(class_name)
task_boxes = []
task_classes = []
flag2 = 0
for idx, mask in enumerate(task_masks):
task_box = []
task_class = []
for m in mask:
task_box.append(gt_bboxes_3d[m])
# 0 is background for each task, so we need to add 1 here.
task_class.append(gt_labels_3d[m] + 1 - flag2)
task_boxes.append(torch.cat(task_box, axis=0).to(device))
task_classes.append(torch.cat(task_class).long().to(device))
flag2 += len(mask)
draw_gaussian = draw_heatmap_gaussian
heatmaps, anno_boxes, inds, masks, corner_heatmaps, cat_labels = [], [], [], [], [], [] # noqa: E501
for idx in range(len(self.tasks)):
heatmap = gt_bboxes_3d.new_zeros(
(len(self.class_names[idx]), feature_map_size[1],
feature_map_size[0]))
corner_heatmap = torch.zeros(
(1, feature_map_size[1], feature_map_size[0]),
dtype=torch.float32,
device=device)
anno_box = gt_bboxes_3d.new_zeros((max_objs, 8),
dtype=torch.float32)
ind = gt_labels_3d.new_zeros((max_objs), dtype=torch.int64)
mask = gt_bboxes_3d.new_zeros((max_objs), dtype=torch.uint8)
cat_label = gt_bboxes_3d.new_zeros((max_objs), dtype=torch.int64)
num_objs = min(task_boxes[idx].shape[0], max_objs)
for k in range(num_objs):
cls_id = task_classes[idx][k] - 1
width = task_boxes[idx][k][3]
length = task_boxes[idx][k][4]
width = width / voxel_size[0] / self.train_cfg[
'out_size_factor']
length = length / voxel_size[1] / self.train_cfg[
'out_size_factor']
if width > 0 and length > 0:
radius = gaussian_radius(
(length, width),
min_overlap=self.train_cfg['gaussian_overlap'])
radius = max(self.train_cfg['min_radius'], int(radius))
# be really careful for the coordinate system of
# your box annotation.
x, y, z = task_boxes[idx][k][0], task_boxes[idx][k][
1], task_boxes[idx][k][2]
coor_x = (
x - pc_range[0]
) / voxel_size[0] / self.train_cfg['out_size_factor']
coor_y = (
y - pc_range[1]
) / voxel_size[1] / self.train_cfg['out_size_factor']
center = torch.tensor([coor_x, coor_y],
dtype=torch.float32,
device=device)
center_int = center.to(torch.int32)
# throw out not in range objects to avoid out of array
# area when creating the heatmap
if not (0 <= center_int[0] < feature_map_size[0]
and 0 <= center_int[1] < feature_map_size[1]):
continue
draw_gaussian(heatmap[cls_id], center_int, radius)
radius = radius // 2
# draw the box center and the four edge-midpoint keypoints. TODO: use torch
rot = task_boxes[idx][k][6]
corner_keypoints = center_to_corner_box2d(
center.unsqueeze(0).cpu().numpy(),
torch.tensor([[width, length]],
dtype=torch.float32).numpy(),
angles=rot,
origin=0.5)
corner_keypoints = torch.from_numpy(corner_keypoints).to(
center)
draw_gaussian(corner_heatmap[0], center_int, radius)
draw_gaussian(
corner_heatmap[0],
(corner_keypoints[0, 0] + corner_keypoints[0, 1]) / 2,
radius)
draw_gaussian(
corner_heatmap[0],
(corner_keypoints[0, 2] + corner_keypoints[0, 3]) / 2,
radius)
draw_gaussian(
corner_heatmap[0],
(corner_keypoints[0, 0] + corner_keypoints[0, 3]) / 2,
radius)
draw_gaussian(
corner_heatmap[0],
(corner_keypoints[0, 1] + corner_keypoints[0, 2]) / 2,
radius)
new_idx = k
x, y = center_int[0], center_int[1]
assert (y * feature_map_size[0] + x <
feature_map_size[0] * feature_map_size[1])
ind[new_idx] = y * feature_map_size[0] + x
mask[new_idx] = 1
cat_label[new_idx] = cls_id
# TODO: support other outdoor dataset
# vx, vy = task_boxes[idx][k][7:]
rot = task_boxes[idx][k][6]
box_dim = task_boxes[idx][k][3:6]
box_dim = box_dim.log()
anno_box[new_idx] = torch.cat([
center - torch.tensor([x, y], device=device),
z.unsqueeze(0), box_dim,
torch.sin(rot).unsqueeze(0),
torch.cos(rot).unsqueeze(0)
])
heatmaps.append(heatmap)
corner_heatmaps.append(corner_heatmap)
anno_boxes.append(anno_box)
masks.append(mask)
inds.append(ind)
cat_labels.append(cat_label)
return heatmaps, anno_boxes, inds, masks, corner_heatmaps, cat_labels
@MODELS.register_module()
class MultiFrameDeformableDecoderRPN(BaseDecoderRPN):
"""The original implementation of CenterFormer modules.
The difference between this module and
`DeformableDecoderRPN` is that this module uses information from multi
frames.
TODO: split this module into backbone、neck and head.
"""
def __init__(
self,
layer_nums, # [2,2,2]
ds_num_filters, # [128,256,64]
num_input_features, # 256
transformer_config=None,
hm_head_layer=2,
corner_head_layer=2,
corner=False,
parametric_embedding=False,
assign_label_window_size=1,
classes=3,
use_gt_training=False,
norm_cfg=None,
logger=None,
init_bias=-2.19,
score_threshold=0.1,
obj_num=500,
frame=1,
**kwargs):
super(MultiFrameDeformableDecoderRPN, self).__init__(
layer_nums,
ds_num_filters,
num_input_features,
transformer_config,
hm_head_layer,
corner_head_layer,
corner,
assign_label_window_size,
classes,
use_gt_training,
norm_cfg,
logger,
init_bias,
score_threshold,
obj_num,
)
self.frame = frame
self.out = nn.Sequential(
nn.Conv2d(
self._num_filters[0] * frame,
self._num_filters[0],
3,
padding=1,
bias=False,
),
build_norm_layer(self._norm_cfg, self._num_filters[0])[1],
nn.ReLU(),
)
self.mtf_attention = MultiFrameSpatialAttention()
self.time_embedding = nn.Linear(1, self._num_filters[0])
self.transformer_decoder = DeformableTransformerDecoder(
self._num_filters[-1] * 2,
depth=transformer_config.depth,
n_heads=transformer_config.n_heads,
n_levels=2 + self.frame,
dim_single_head=transformer_config.dim_single_head,
dim_ffn=transformer_config.dim_ffn,
dropout=transformer_config.dropout,
out_attention=transformer_config.out_attn,
n_points=transformer_config.get('n_points', 9),
)
self.pos_embedding_type = transformer_config.get(
'pos_embedding_type', 'linear')
if self.pos_embedding_type == 'linear':
self.pos_embedding = nn.Linear(2, self._num_filters[-1] * 2)
else:
raise NotImplementedError()
self.parametric_embedding = parametric_embedding
if self.parametric_embedding:
self.query_embed = nn.Embedding(self.obj_num,
self._num_filters[-1] * 2)
nn.init.uniform_(self.query_embed.weight, -1.0, 1.0)
print_log('Finish RPN_transformer_deformable Initialization',
'current')
def forward(self, x, example=None):
# FPN
x = self.blocks[0](x)
x_down = self.blocks[1](x)
x_up = torch.cat([self.blocks[2](x_down), self.up(x)], dim=1)
# take out the BEV feature on current frame
x = torch.split(x, self.frame)
x_up = torch.split(x_up, self.frame)
x_down = torch.split(x_down, self.frame)
x_prev = torch.stack([t[1:] for t in x_up], dim=0) # B,K,C,H,W
x = torch.stack([t[0] for t in x], dim=0)
x_down = torch.stack([t[0] for t in x_down], dim=0)
x_up = torch.stack([t[0] for t in x_up], dim=0) # B,C,H,W
# use spatial attention in current frame on previous feature
x_prev_cat = self.mtf_attention(
x_up,
x_prev.reshape(x_up.shape[0], -1, x_up.shape[2],
x_up.shape[3])) # B,K*C,H,W
# time embedding
x_up_fuse = torch.cat((x_up, x_prev_cat), dim=1) + self.time_embedding(
example['times'][:, :, None].to(x_up)).reshape(
x_up.shape[0], -1, 1, 1)
# fuse mtf feature
x_up_fuse = self.out(x_up_fuse)
# heatmap head
hm = self.hm_head(x_up_fuse)
if self.corner and self.corner_head.training:
corner_hm = self.corner_head(x_up_fuse)
corner_hm = torch.sigmoid(corner_hm)
# find top K center location
hm = torch.sigmoid(hm)
batch, num_cls, H, W = hm.size()
scores, labels = torch.max(
hm.reshape(batch, num_cls, H * W), dim=1) # b,H*W
self.batch_id = torch.from_numpy(np.indices(
(batch, self.obj_num))[0]).to(labels)
if self.use_gt_training and self.hm_head.training:
gt_inds = example['ind'][0][:, (self.window_size //
2)::self.window_size]
gt_masks = example['mask'][0][:, (self.window_size //
2)::self.window_size]
batch_id_gt = torch.from_numpy(
np.indices((batch, gt_inds.shape[1]))[0]).to(labels)
scores[batch_id_gt,
gt_inds] = scores[batch_id_gt, gt_inds] + gt_masks
order = scores.sort(1, descending=True)[1]
order = order[:, :self.obj_num]
scores[batch_id_gt,
gt_inds] = scores[batch_id_gt, gt_inds] - gt_masks
else:
order = scores.sort(1, descending=True)[1]
order = order[:, :self.obj_num]
scores = torch.gather(scores, 1, order)
labels = torch.gather(labels, 1, order)
mask = scores > self.score_threshold
ct_feat = (x_up.reshape(batch, -1,
H * W).transpose(2,
1).contiguous()[self.batch_id,
order]
) # B, 500, C
# create position embedding for each center
y_coor = order // W
x_coor = order - y_coor * W
y_coor, x_coor = y_coor.to(ct_feat), x_coor.to(ct_feat)
y_coor, x_coor = y_coor / H, x_coor / W
pos_features = torch.stack([x_coor, y_coor], dim=2)
if self.parametric_embedding:
ct_feat = self.query_embed.weight
ct_feat = ct_feat.unsqueeze(0).expand(batch, -1, -1)
# run transformer
src_list = [
x_up.reshape(batch, -1, H * W).transpose(2, 1).contiguous(),
x.reshape(batch, -1, (H * W) // 4).transpose(2, 1).contiguous(),
x_down.reshape(batch, -1, (H * W) // 16).transpose(2,
1).contiguous(),
]
for frame in range(x_prev.shape[1]):
src_list.append(x_prev[:, frame].reshape(batch,
-1, (H * W)).transpose(
2, 1).contiguous())
src = torch.cat(src_list, dim=1) # B ,sum(H*W), C
spatial_list = [(H, W), (H // 2, W // 2), (H // 4, W // 4)]
spatial_list += [(H, W) for frame in range(x_prev.shape[1])]
spatial_shapes = torch.as_tensor(
spatial_list, dtype=torch.long, device=ct_feat.device)
level_start_index = torch.cat((
spatial_shapes.new_zeros((1, )),
spatial_shapes.prod(1).cumsum(0)[:-1],
))
transformer_out = self.transformer_decoder(
ct_feat,
self.pos_embedding,
src,
spatial_shapes,
level_start_index,
center_pos=pos_features,
) # (B,N,C)
ct_feat = (transformer_out['ct_feat'].transpose(2, 1).contiguous()
) # B, C, 500
out_dict = {
'hm': hm,
'scores': scores,
'labels': labels,
'order': order,
'ct_feat': ct_feat,
'mask': mask,
}
if 'out_attention' in transformer_out:
out_dict.update(
{'out_attention': transformer_out['out_attention']})
if self.corner and self.corner_head.training:
out_dict.update({'corner_hm': corner_hm})
return out_dict
# ------------------------------------------------------------------------------
# Portions of this code are from
# det3d (https://github.com/poodarchu/Det3D/tree/56402d4761a5b73acd23080f537599b0888cce07) # noqa
# Copyright (c) 2019 朱本金
# Licensed under the MIT License
# ------------------------------------------------------------------------------
import copy
import logging
import numpy as np
import torch
from mmcv.cnn import build_norm_layer
from mmcv.ops import boxes_iou3d
from mmengine.logging import print_log
from mmengine.model import kaiming_init
from mmengine.structures import InstanceData
from torch import nn
from mmdet3d.models.layers import circle_nms, nms_bev
from mmdet3d.registry import MODELS
from .bbox_ops import nms_iou3d
from .losses import FastFocalLoss
class SepHead(nn.Module):
"""TODO: This module is the original implementation in CenterFormer and it
has few differences with ``SeperateHead`` in `mmdet3d` but refactor this
module will lower the performance a little.
"""
def __init__(
self,
in_channels,
heads,
head_conv=64,
final_kernel=1,
bn=False,
init_bias=-2.19,
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
**kwargs,
):
super(SepHead, self).__init__(**kwargs)
self.heads = heads
for head in self.heads:
classes, num_conv = self.heads[head]
fc = []
for i in range(num_conv - 1):
fc.append(
nn.Conv1d(
in_channels,
head_conv,
kernel_size=final_kernel,
stride=1,
padding=final_kernel // 2,
bias=True,
))
if bn:
fc.append(build_norm_layer(norm_cfg, head_conv)[1])
fc.append(nn.ReLU())
fc.append(
nn.Conv1d(
head_conv,
classes,
kernel_size=final_kernel,
stride=1,
padding=final_kernel // 2,
bias=True,
))
if 'hm' in head:
fc[-1].bias.data.fill_(init_bias)
else:
for m in fc:
if isinstance(m, nn.Conv1d):
kaiming_init(m)
fc = nn.Sequential(*fc)
self.__setattr__(head, fc)
def forward(self, x, y):
for head in self.heads:
x[head] = self.__getattr__(head)(y)
return x
@MODELS.register_module()
class CenterFormerBboxHead(nn.Module):
def __init__(self,
in_channels,
tasks,
weight=0.25,
iou_weight=1,
corner_weight=1,
code_weights=[],
common_heads=dict(),
logger=None,
init_bias=-2.19,
share_conv_channel=64,
assign_label_window_size=1,
iou_loss=False,
corner_loss=False,
iou_factor=[1, 1, 4],
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
bbox_code_size=7,
test_cfg=None,
**kwargs):
super(CenterFormerBboxHead, self).__init__()
num_classes = [len(t['class_names']) for t in tasks]
self.class_names = [t['class_names'] for t in tasks]
self.code_weights = code_weights
self.bbox_code_size = 7
self.weight = weight # weight between hm loss and loc loss
self.iou_weight = iou_weight
self.corner_weight = corner_weight
self.iou_factor = iou_factor
self.in_channels = in_channels
self.num_classes = num_classes
self.test_cfg = test_cfg
self.crit = FastFocalLoss(assign_label_window_size)
self.crit_reg = torch.nn.L1Loss(reduction='none')
self.use_iou_loss = iou_loss
if self.use_iou_loss:
self.crit_iou = torch.nn.SmoothL1Loss(reduction='none')
self.corner_loss = corner_loss
if self.corner_loss:
self.corner_crit = torch.nn.MSELoss(reduction='none')
self.box_n_dim = 9 if 'vel' in common_heads else 7
self.use_direction_classifier = False
if not logger:
logger = logging.getLogger('CenterFormerBboxHead')
self.logger = logger
logger.info(f'num_classes: {num_classes}')
# a shared convolution
self.shared_conv = nn.Sequential(
nn.Conv1d(
in_channels, share_conv_channel, kernel_size=1, bias=True),
build_norm_layer(norm_cfg, share_conv_channel)[1],
nn.ReLU(inplace=True),
)
self.tasks = nn.ModuleList()
print_log(f'Use HM Bias: {init_bias}', 'current')
for num_cls in num_classes:
heads = copy.deepcopy(common_heads)
self.tasks.append(
SepHead(
share_conv_channel,
heads,
bn=True,
init_bias=init_bias,
final_kernel=1,
norm_cfg=norm_cfg))
logger.info('Finish CenterHeadIoU Initialization')
def forward(self, x, **kwargs):
ret_dicts = []
y = self.shared_conv(x['ct_feat'].float())
for task in self.tasks:
ret_dicts.append(task(x, y))
return ret_dicts
def _sigmoid(self, x):
y = torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4)
return y
def loss(self, preds_dicts, example, **kwargs):
losses = {}
for task_id, preds_dict in enumerate(preds_dicts):
# heatmap focal loss
hm_loss = self.crit(
preds_dict['hm'],
example['hm'][task_id],
example['ind'][task_id],
example['mask'][task_id],
example['cat'][task_id],
)
target_box = example['anno_box'][task_id]
if self.corner_loss:
corner_loss = self.corner_crit(preds_dict['corner_hm'],
example['corners'][task_id])
corner_mask = (example['corners'][task_id] > 0).to(corner_loss)
corner_loss = (corner_loss * corner_mask).sum() / (
corner_mask.sum() + 1e-4)
losses.update({
f'{task_id}_corner_loss':
corner_loss * self.corner_weight
})
# reconstruct the anno_box from multiple reg heads
if 'vel' in preds_dict:
preds_dict['anno_box'] = torch.cat(
(
preds_dict['reg'],
preds_dict['height'],
preds_dict['dim'],
preds_dict['vel'],
preds_dict['rot'],
),
dim=1,
)
else:
preds_dict['anno_box'] = torch.cat(
(
preds_dict['reg'],
preds_dict['height'],
preds_dict['dim'],
preds_dict['rot'],
),
dim=1,
)
target_box = target_box[..., [0, 1, 2, 3, 4, 5, -2,
-1]] # remove vel target
# Regression loss for dimension, offset, height, rotation
# get corresponding gt box # B, 500
target_box, selected_mask, selected_cls = get_corresponding_box(
preds_dict['order'],
example['ind'][task_id],
example['mask'][task_id],
example['cat'][task_id],
target_box,
)
mask = selected_mask.float().unsqueeze(2)
weights = self.code_weights
box_loss = self.crit_reg(
preds_dict['anno_box'].transpose(1, 2) * mask,
target_box * mask)
box_loss = box_loss / (mask.sum() + 1e-4)
box_loss = box_loss.transpose(2, 0).sum(dim=2).sum(dim=1)
loc_loss = (box_loss * box_loss.new_tensor(weights)).sum()
if self.use_iou_loss:
with torch.no_grad():
preds_box = get_box(
preds_dict['anno_box'],
preds_dict['order'],
self.test_cfg,
preds_dict['hm'].shape[2],
preds_dict['hm'].shape[3],
)
cur_gt = get_box_gt(
target_box,
preds_dict['order'],
self.test_cfg,
preds_dict['hm'].shape[2],
preds_dict['hm'].shape[3],
)
iou_targets = boxes_iou3d(
preds_box.reshape(-1, 7), cur_gt.reshape(
-1, 7))[range(preds_box.reshape(-1, 7).shape[0]),
range(cur_gt.reshape(-1, 7).shape[0])]
iou_targets[torch.isnan(iou_targets)] = 0
iou_targets = 2 * iou_targets - 1
iou_loss = self.crit_iou(preds_dict['iou'].reshape(-1),
iou_targets) * mask.reshape(-1)
iou_loss = iou_loss.sum() / (mask.sum() + 1e-4)
losses.update(
{f'{task_id}_iou_loss': iou_loss * self.iou_weight})
losses.update({
f'{task_id}_hm_loss': hm_loss,
f'{task_id}_loc_loss': loc_loss * self.weight
})
return losses
def predict(self, preds_dicts, batch_input_metas, **kwargs):
"""decode, nms, then return the detection result.
Additionally support double flip testing
"""
rets = []
post_center_range = self.test_cfg.post_center_limit_range
if len(post_center_range) > 0:
post_center_range = torch.tensor(
post_center_range,
dtype=preds_dicts[0]['scores'].dtype,
device=preds_dicts[0]['scores'].device,
)
for task_id, preds_dict in enumerate(preds_dicts):
# convert B C N to B N C
for key, val in preds_dict.items():
if torch.is_tensor(preds_dict[key]):
if len(preds_dict[key].shape) == 3:
preds_dict[key] = val.permute(0, 2, 1).contiguous()
batch_score = preds_dict['scores']
batch_label = preds_dict['labels']
batch_mask = preds_dict['mask']
if self.use_iou_loss:
batch_iou = preds_dict['iou'].squeeze(2)
else:
batch_iou = None
batch_dim = torch.exp(preds_dict['dim'])
batch_rots = preds_dict['rot'][..., 0:1]
batch_rotc = preds_dict['rot'][..., 1:2]
batch_reg = preds_dict['reg']
batch_hei = preds_dict['height']
batch_rot = torch.atan2(batch_rots, batch_rotc)
if self.use_iou_loss:
batch_iou = (batch_iou + 1) * 0.5
batch_iou = torch.clamp(batch_iou, min=0.0, max=1.0)
batch, _, H, W = preds_dict['hm'].size()
ys, xs = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)])
ys = ys.view(1, H, W).repeat(batch, 1, 1).to(batch_score)
xs = xs.view(1, H, W).repeat(batch, 1, 1).to(batch_score)
obj_num = preds_dict['order'].shape[1]
batch_id = np.indices((batch, obj_num))[0]
batch_id = torch.from_numpy(batch_id).to(preds_dict['order'])
xs = (
xs.view(batch, -1, 1)[batch_id, preds_dict['order']] +
batch_reg[:, :, 0:1])
ys = (
ys.view(batch, -1, 1)[batch_id, preds_dict['order']] +
batch_reg[:, :, 1:2])
xs = (
xs * self.test_cfg.out_size_factor *
self.test_cfg.voxel_size[0] + self.test_cfg.pc_range[0])
ys = (
ys * self.test_cfg.out_size_factor *
self.test_cfg.voxel_size[1] + self.test_cfg.pc_range[1])
if 'vel' in preds_dict:
batch_vel = preds_dict['vel']
batch_box_preds = torch.cat(
[xs, ys, batch_hei, batch_dim, batch_vel, batch_rot],
dim=2)
else:
batch_box_preds = torch.cat(
[xs, ys, batch_hei, batch_dim, batch_rot], dim=2)
if self.test_cfg.get('per_class_nms', False):
pass
else:
rets.append(
self.post_processing(
batch_input_metas,
batch_box_preds,
batch_score,
batch_label,
self.test_cfg,
post_center_range,
task_id,
batch_mask,
batch_iou,
))
# Merge the results of all task branches
num_samples = len(rets[0])
ret_list = []
for i in range(num_samples):
temp_instances = InstanceData()
for k in rets[0][i].keys():
if k == 'bboxes':
bboxes = torch.cat([ret[i][k] for ret in rets])
bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
bboxes = batch_input_metas[i]['box_type_3d'](
bboxes, self.bbox_code_size)
elif k == 'labels':
flag = 0
for j, num_class in enumerate(self.num_classes):
rets[j][i][k] += flag
flag += num_class
labels = torch.cat([ret[i][k] for ret in rets])
elif k == 'scores':
scores = torch.cat([ret[i][k] for ret in rets])
temp_instances.bboxes_3d = bboxes
temp_instances.scores_3d = scores
temp_instances.labels_3d = labels
ret_list.append(temp_instances)
return ret_list
def post_processing(
self,
img_metas,
batch_box_preds,
batch_score,
batch_label,
test_cfg,
post_center_range,
task_id,
batch_mask,
batch_iou,
):
batch_size = len(batch_score)
prediction_dicts = []
for i in range(batch_size):
box_preds = batch_box_preds[i]
scores = batch_score[i]
labels = batch_label[i]
mask = batch_mask[i]
distance_mask = (box_preds[..., :3] >= post_center_range[:3]).all(
1) & (box_preds[..., :3] <= post_center_range[3:]).all(1)
mask = mask & distance_mask
box_preds = box_preds[mask]
scores = scores[mask]
labels = labels[mask]
if self.use_iou_loss:
iou_factor = torch.LongTensor(self.iou_factor).to(labels)
ious = batch_iou[i][mask]
ious = torch.pow(ious, iou_factor[labels])
scores = scores * ious
boxes_for_nms = box_preds[:, [0, 1, 2, 3, 4, 5, -1]]
if test_cfg.get('circular_nms', False):
centers = boxes_for_nms[:, [0, 1]]
boxes = torch.cat([centers, scores.view(-1, 1)], dim=1)
selected = _circle_nms(
boxes,
min_radius=test_cfg.min_radius[task_id],
post_max_size=test_cfg.nms.nms_post_max_size,
)
elif test_cfg.nms.get('use_multi_class_nms', False):
# multi class nms
selected = []
for c in range(3):
class_mask = labels == c
if class_mask.sum() > 0:
class_idx = class_mask.nonzero()
select = nms_iou3d(
boxes_for_nms[class_mask].float(),
scores[class_mask].float(),
thresh=test_cfg.nms.nms_iou_threshold[c],
pre_maxsize=test_cfg.nms.nms_pre_max_size[c],
post_max_size=test_cfg.nms.nms_post_max_size[c],
)
selected.append(class_idx[select, 0])
if len(selected) > 0:
selected = torch.cat(selected, dim=0)
else:
selected = nms_bev(
boxes_for_nms.float(),
scores.float(),
thresh=test_cfg.nms.nms_iou_threshold,
pre_max_size=test_cfg.nms.nms_pre_max_size,
post_max_size=test_cfg.nms.nms_post_max_size,
)
selected_boxes = box_preds[selected]
selected_scores = scores[selected]
selected_labels = labels[selected]
prediction_dict = {
'bboxes': selected_boxes,
'scores': selected_scores,
'labels': selected_labels,
}
prediction_dicts.append(prediction_dict)
return prediction_dicts
def _circle_nms(boxes, min_radius, post_max_size=83):
"""NMS according to center distance."""
keep = np.array(circle_nms(boxes.cpu().numpy(),
thresh=min_radius))[:post_max_size]
keep = torch.from_numpy(keep).long().to(boxes.device)
return keep
def get_box(pred_boxs, order, test_cfg, H, W):
batch = pred_boxs.shape[0]
obj_num = order.shape[1]
ys, xs = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)])
ys = ys.view(1, H, W).repeat(batch, 1, 1).to(pred_boxs)
xs = xs.view(1, H, W).repeat(batch, 1, 1).to(pred_boxs)
batch_id = np.indices((batch, obj_num))[0]
batch_id = torch.from_numpy(batch_id).to(order)
xs = xs.view(batch, H * W)[batch_id, order].unsqueeze(1) + pred_boxs[:,
0:1]
ys = ys.view(batch, H * W)[batch_id, order].unsqueeze(1) + pred_boxs[:,
1:2]
xs = xs * test_cfg.out_size_factor * test_cfg.voxel_size[
0] + test_cfg.pc_range[0]
ys = ys * test_cfg.out_size_factor * test_cfg.voxel_size[
1] + test_cfg.pc_range[1]
rot = torch.atan2(pred_boxs[:, 6:7], pred_boxs[:, 7:8])
pred = torch.cat(
[xs, ys, pred_boxs[:, 2:3],
torch.exp(pred_boxs[:, 3:6]), rot], dim=1)
return torch.transpose(pred, 1, 2).contiguous() # B M 7
def get_box_gt(gt_boxs, order, test_cfg, H, W):
batch = gt_boxs.shape[0]
obj_num = order.shape[1]
ys, xs = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)])
ys = ys.view(1, H, W).repeat(batch, 1, 1).to(gt_boxs)
xs = xs.view(1, H, W).repeat(batch, 1, 1).to(gt_boxs)
batch_id = np.indices((batch, obj_num))[0]
batch_id = torch.from_numpy(batch_id).to(order)
batch_gt_dim = torch.exp(gt_boxs[..., 3:6])
batch_gt_hei = gt_boxs[..., 2:3]
batch_gt_rot = torch.atan2(gt_boxs[..., -2:-1], gt_boxs[..., -1:])
xs = xs.view(batch, H * W)[batch_id, order].unsqueeze(2) + gt_boxs[...,
0:1]
ys = ys.view(batch, H * W)[batch_id, order].unsqueeze(2) + gt_boxs[...,
1:2]
xs = xs * test_cfg.out_size_factor * test_cfg.voxel_size[
0] + test_cfg.pc_range[0]
ys = ys * test_cfg.out_size_factor * test_cfg.voxel_size[
1] + test_cfg.pc_range[1]
batch_box_targets = torch.cat(
[xs, ys, batch_gt_hei, batch_gt_dim, batch_gt_rot], dim=-1)
return batch_box_targets # B M 7
def get_corresponding_box(x_ind, y_ind, y_mask, y_cls, target_box):
# For each index in `x_ind`, find the ground truth in `y_ind` that shares
# the same (flattened) heatmap index and gather its box, mask and class.
select_target = torch.zeros(x_ind.shape[0], x_ind.shape[1],
target_box.shape[2]).to(target_box)
select_mask = torch.zeros_like(x_ind).to(y_mask)
select_cls = torch.zeros_like(x_ind).to(y_cls)
for i in range(x_ind.shape[0]):
idx = torch.arange(y_ind[i].shape[-1]).to(x_ind)
idx = idx[y_mask[i]]
box_cls = y_cls[i][y_mask[i]]
valid_y_ind = y_ind[i][y_mask[i]]
match = (x_ind[i].unsqueeze(1) == valid_y_ind.unsqueeze(0)).nonzero()
select_target[i, match[:, 0]] = target_box[i, idx[match[:, 1]]]
select_mask[i, match[:, 0]] = 1
select_cls[i, match[:, 0]] = box_cls[match[:, 1]]
return select_target, select_mask, select_cls
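# A small worked example (an illustration, not part of the original file):
# with x_ind[i] = [7, 3, 9] and valid y_ind[i] = [3, 9], position 0 of x
# has no ground-truth match, while positions 1 and 2 match the first and
# second valid ground truth respectively, so select_mask[i] = [0, 1, 1] and
# select_target[i, 1:] receives the two matched boxes.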
# Modified from https://github.com/TuSimple/centerformer/blob/master/det3d/models/losses/centernet_loss.py # noqa
import torch
from torch import nn
from mmdet3d.registry import MODELS
def _gather_feat(feat, ind, mask=None):
dim = feat.size(2)
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
feat = feat.gather(1, ind)
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
feat = feat[mask]
feat = feat.view(-1, dim)
return feat
def _transpose_and_gather_feat(feat, ind):
feat = feat.permute(0, 2, 3, 1).contiguous()
feat = feat.view(feat.size(0), -1, feat.size(3))
feat = _gather_feat(feat, ind)
return feat
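# For example, for `feat` of shape (B, C, H, W) and `ind` of shape (B, M)
# holding the flattened (y * W + x) positions of M peaks,
# `_transpose_and_gather_feat(feat, ind)` returns a (B, M, C) tensor with
# the per-channel values at each peak.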
@MODELS.register_module()
class FastFocalLoss(nn.Module):
"""Reimplemented focal loss, exactly the same as the CornerNet version.
Faster and costs much less memory.
"""
def __init__(self, focal_factor=2):
super(FastFocalLoss, self).__init__()
self.focal_factor = focal_factor
def forward(self, out, target, ind, mask, cat):
"""
Args:
    out, target: Predicted and target heatmaps of shape B x C x H x W.
    ind, mask: Flattened peak indices and validity mask of shape B x M.
    cat: Category id of each peak, of shape B x M.
"""
mask = mask.float()
gt = torch.pow(1 - target, 4)
neg_loss = torch.log(1 - out) * torch.pow(out, self.focal_factor) * gt
neg_loss = neg_loss.sum()
pos_pred_pix = _transpose_and_gather_feat(out, ind) # B x M x C
pos_pred = pos_pred_pix.gather(2, cat.unsqueeze(2)) # B x M
num_pos = mask.sum()
pos_loss = torch.log(pos_pred) * torch.pow(
1 - pos_pred, self.focal_factor) * mask.unsqueeze(2)
pos_loss = pos_loss.sum()
if num_pos == 0:
return -neg_loss
return -(pos_loss + neg_loss) / num_pos
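# A minimal usage sketch (not part of the original file); tensor shapes
# follow the docstring of ``forward`` and all values below are made up.
if __name__ == '__main__':
    loss_fn = FastFocalLoss()
    heatmap_pred = torch.rand(2, 3, 8, 8).clamp(1e-4, 1 - 1e-4)
    heatmap_gt = torch.zeros(2, 3, 8, 8)  # synthetic gaussian targets
    ind = torch.randint(0, 8 * 8, (2, 5))  # flattened peak positions
    mask = torch.ones(2, 5)  # all five objects are valid
    cat = torch.randint(0, 3, (2, 5))  # category id of each peak
    print(loss_fn(heatmap_pred, heatmap_gt, ind, mask, cat))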
# Modified from https://github.com/TuSimple/centerformer/blob/master/det3d/models/ops/modules/ms_deform_attn.py # noqa
import math
from typing import Optional
import torch
import torch.nn.functional as F
from mmcv.utils import ext_loader
from torch import Tensor, nn
from torch.autograd.function import Function, once_differentiable
from torch.nn.init import constant_, xavier_uniform_
ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
class MultiScaleDeformableAttnFunction(Function):
@staticmethod
def forward(ctx, value: torch.Tensor, value_spatial_shapes: torch.Tensor,
value_level_start_index: torch.Tensor,
sampling_locations: torch.Tensor,
attention_weights: torch.Tensor,
im2col_step: torch.Tensor) -> torch.Tensor:
"""GPU/MLU version of multi-scale deformable attention.
Args:
value (torch.Tensor): The value tensor with shape
    (bs, num_keys, num_heads, embed_dims//num_heads).
value_spatial_shapes (torch.Tensor): Spatial shape of each feature
    map, has shape (num_levels, 2), the last dimension 2
    represents (h, w).
value_level_start_index (torch.Tensor): The start index of each
    level, has shape (num_levels, ).
sampling_locations (torch.Tensor): The location of sampling points,
    has shape (bs, num_queries, num_heads, num_levels, num_points, 2),
    the last dimension 2 represents (x, y).
attention_weights (torch.Tensor): The weight of sampling points
    used when calculating the attention, has shape
    (bs, num_queries, num_heads, num_levels, num_points).
im2col_step (torch.Tensor): The step used in image-to-column.
Returns:
torch.Tensor: has shape (bs, num_queries, embed_dims)
"""
ctx.im2col_step = im2col_step
output = ext_module.ms_deform_attn_forward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
im2col_step=ctx.im2col_step)
ctx.save_for_backward(value, value_spatial_shapes,
value_level_start_index, sampling_locations,
attention_weights)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output: torch.Tensor) -> tuple:
"""GPU/MLU version of backward function.
Args:
grad_output (torch.Tensor): Gradient of output tensor of forward.
Returns:
tuple[Tensor]: Gradient of input tensors in forward.
"""
value, value_spatial_shapes, value_level_start_index,\
sampling_locations, attention_weights = ctx.saved_tensors
grad_value = torch.zeros_like(value)
grad_sampling_loc = torch.zeros_like(sampling_locations)
grad_attn_weight = torch.zeros_like(attention_weights)
ext_module.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output.contiguous(),
grad_value,
grad_sampling_loc,
grad_attn_weight,
im2col_step=ctx.im2col_step)
return grad_value, None, None, \
grad_sampling_loc, grad_attn_weight, None
class MSDeformAttn(nn.Module):
"""Multi-Scale Deformable Attention Module. Note that the difference
between this implementation and the implementation in MMCV is that the
dimension of input and hidden embedding in the multi-attention-head can be
specified respectively.
Args:
dim_model (int, optional): The input and output dimension in the model.
Defaults to 256.
dim_single_head (int, optional): hidden dimension in the single head.
Defaults to 64.
n_levels (int, optional): number of feature levels. Defaults to 4.
n_heads (int, optional): number of attention heads. Defaults to 8.
n_points (int, optional): number of sampling points per attention head
per feature level. Defaults to 4.
out_sample_loc (bool, optional): Whether to return the sampling
location. Defaults to False.
"""
def __init__(self,
dim_model=256,
dim_single_head=64,
n_levels=4,
n_heads=8,
n_points=4,
out_sample_loc=False):
super().__init__()
self.im2col_step = 64
self.dim_model = dim_model
self.dim_single_head = dim_single_head
self.n_levels = n_levels
self.n_heads = n_heads
self.n_points = n_points
self.out_sample_loc = out_sample_loc
self.sampling_offsets = nn.Linear(dim_model,
n_heads * n_levels * n_points * 2)
self.attention_weights = nn.Linear(dim_model,
n_heads * n_levels * n_points)
self.value_proj = nn.Linear(dim_model, dim_single_head * n_heads)
self.output_proj = nn.Linear(dim_single_head * n_heads, dim_model)
self._reset_parameters()
def _reset_parameters(self):
constant_(self.sampling_offsets.weight.data, 0.)
thetas = torch.arange(
self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init /
grid_init.abs().max(-1, keepdim=True)[0]).view(
self.n_heads, 1, 1, 2).repeat(1, self.n_levels,
self.n_points, 1)
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.)
constant_(self.attention_weights.bias.data, 0.)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
def forward(self,
query: Tensor,
reference_points: Tensor,
input_flatten: Tensor,
input_spatial_shapes: Tensor,
input_level_start_index: Tensor,
input_padding_mask: Optional[Tensor] = None):
"""Forward Function of MultiScaleDeformAttention.
Args:
query (Tensor): The query embeddings with shape (N, num_query, C).
reference_points (Tensor): The normalized reference points with
    shape (N, num_query, n_levels, 2), all elements in range
    [0, 1], top-left (0, 0), bottom-right (1, 1), including the
    padding area; or (N, num_query, n_levels, 4), where the two
    additional dimensions (w, h) form reference boxes.
input_flatten (Tensor): The flattened multi-level features with
    shape (N, sum(H_l * W_l), C).
input_spatial_shapes (Tensor): Spatial shape of features in
    different levels. With shape (num_levels, 2), the last
    dimension represents (h, w).
input_level_start_index (Tensor): The start index of each level.
    A tensor has shape ``(num_levels, )`` and can be represented
    as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
input_padding_mask (Optional[Tensor], optional): The padding mask
    for value. Defaults to None.
Returns:
    Tuple[Tensor, Tensor]: The output features with shape
    (N, num_query, C), and the sampling locations together with the
    attention weights if ``out_sample_loc`` is True, otherwise None.
"""
N, Len_q, _ = query.shape
N, Len_in, _ = input_flatten.shape
assert (input_spatial_shapes[:, 0] *
input_spatial_shapes[:, 1]).sum() == Len_in
value = self.value_proj(input_flatten)
if input_padding_mask is not None:
value = value.masked_fill(input_padding_mask[..., None], float(0))
value = value.view(N, Len_in, self.n_heads, self.dim_single_head)
sampling_offsets = self.sampling_offsets(query).view(
N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
attention_weights = self.attention_weights(query).view(
N, Len_q, self.n_heads, self.n_levels * self.n_points)
attention_weights = F.softmax(attention_weights,
-1).view(N, Len_q, self.n_heads,
self.n_levels, self.n_points)
# N, Len_q, n_heads, n_levels, n_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack(
[input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]],
-1).to(sampling_offsets)
sampling_locations = reference_points[:, :, None, :, None, :] + \
sampling_offsets / offset_normalizer[None, None, None, :, None, :] # noqa: E501
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 # noqa: E501
else:
raise ValueError(
'Last dim of reference_points must be 2 or 4, but get {} instead.' # noqa: E501
.format(reference_points.shape[-1]))
output = MultiScaleDeformableAttnFunction.apply(
value, input_spatial_shapes, input_level_start_index,
sampling_locations, attention_weights, self.im2col_step)
output = self.output_proj(output)
if self.out_sample_loc:
return output, torch.cat(
(sampling_locations, attention_weights[:, :, :, :, :, None]),
dim=-1)
else:
return output, None
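# A minimal usage sketch (not part of the original file). The underlying
# ``ms_deform_attn_forward`` op is a compiled mmcv extension, so this
# example only runs on a GPU/MLU build; all values below are made up.
if __name__ == '__main__' and torch.cuda.is_available():
    attn = MSDeformAttn(
        dim_model=256, dim_single_head=64, n_levels=2, n_heads=8,
        n_points=4).cuda()
    spatial_shapes = torch.as_tensor([[32, 32], [16, 16]]).cuda()
    level_start_index = torch.cat([
        spatial_shapes.new_zeros(1),
        spatial_shapes.prod(1).cumsum(0)[:-1]
    ])
    src = torch.rand(2, int(spatial_shapes.prod(1).sum()), 256).cuda()
    query = torch.rand(2, 100, 256).cuda()
    # normalized (x, y) reference points, one per level
    reference_points = torch.rand(2, 100, 2, 2).cuda()
    output, _ = attn(query, reference_points, src, spatial_shapes,
                     level_start_index)
    print(output.shape)  # torch.Size([2, 100, 256])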
# Modified from https://github.com/TuSimple/centerformer/blob/master/det3d/models/utils/transformer.py # noqa
import torch
from einops import rearrange
from mmcv.cnn.bricks.activation import GELU
from torch import einsum, nn
from .multi_scale_deform_attn import MSDeformAttn
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, y=None, **kwargs):
if y is not None:
return self.fn(self.norm(x), self.norm(y), **kwargs)
else:
return self.fn(self.norm(x), **kwargs)
class FFN(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.0):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
class SelfAttention(nn.Module):
def __init__(self,
dim,
n_heads=8,
dim_single_head=64,
dropout=0.0,
out_attention=False):
super().__init__()
inner_dim = dim_single_head * n_heads
project_out = not (n_heads == 1 and dim_single_head == dim)
self.n_heads = n_heads
self.scale = dim_single_head**-0.5
self.out_attention = out_attention
self.attend = nn.Softmax(dim=-1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = (
nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
if project_out else nn.Identity())
def forward(self, x):
_, _, _, h = *x.shape, self.n_heads
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
if self.out_attention:
return self.to_out(out), attn
else:
return self.to_out(out)
class DeformableCrossAttention(nn.Module):
def __init__(
self,
dim_model=256,
dim_single_head=64,
dropout=0.3,
n_levels=3,
n_heads=6,
n_points=9,
out_sample_loc=False,
):
super().__init__()
# cross attention
self.cross_attn = MSDeformAttn(
dim_model,
dim_single_head,
n_levels,
n_heads,
n_points,
out_sample_loc=out_sample_loc)
self.dropout = nn.Dropout(dropout)
self.out_sample_loc = out_sample_loc
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward(
self,
tgt,
src,
query_pos=None,
reference_points=None,
src_spatial_shapes=None,
level_start_index=None,
src_padding_mask=None,
):
# cross attention
tgt2, sampling_locations = self.cross_attn(
self.with_pos_embed(tgt, query_pos),
reference_points,
src,
src_spatial_shapes,
level_start_index,
src_padding_mask,
)
tgt = self.dropout(tgt2)
if self.out_sample_loc:
return tgt, sampling_locations
else:
return tgt
class DeformableTransformerDecoder(nn.Module):
"""Deformable transformer decoder.
Note that the ``DeformableDetrTransformerDecoder`` in MMDet has different
interfaces in multi-head-attention which is customized here. For example,
'embed_dims' is not a position argument in our customized multi-head-self-
attention, but is required in MMDet. Thus, we can not directly use the
``DeformableDetrTransformerDecoder`` in MMDET.
"""
def __init__(
self,
dim,
n_levels=3,
depth=2,
n_heads=4,
dim_single_head=32,
dim_ffn=256,
dropout=0.0,
out_attention=False,
n_points=9,
):
super().__init__()
self.out_attention = out_attention
self.layers = nn.ModuleList([])
self.depth = depth
self.n_levels = n_levels
self.n_points = n_points
for _ in range(depth):
self.layers.append(
nn.ModuleList([
PreNorm(
dim,
SelfAttention(
dim,
n_heads=n_heads,
dim_single_head=dim_single_head,
dropout=dropout,
out_attention=self.out_attention,
),
),
PreNorm(
dim,
DeformableCrossAttention(
dim,
dim_single_head,
n_levels=n_levels,
n_heads=n_heads,
dropout=dropout,
n_points=n_points,
out_sample_loc=self.out_attention,
),
),
PreNorm(dim, FFN(dim, dim_ffn, dropout=dropout)),
]))
def forward(self, x, pos_embedding, src, src_spatial_shapes,
level_start_index, center_pos):
if self.out_attention:
out_cross_attention_list = []
# `pos_embedding` may be None, in which case no positional encoding is
# added to the queries.
center_pos_embedding = None
if pos_embedding is not None:
    center_pos_embedding = pos_embedding(center_pos)
reference_points = center_pos[:, :, None, :].repeat(
    1, 1, self.n_levels, 1)
for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
if self.out_attention:
if center_pos_embedding is not None:
x_att, self_att = self_attn(x + center_pos_embedding)
x = x_att + x
x_att, cross_att = cross_attn(
x,
src,
query_pos=center_pos_embedding,
reference_points=reference_points,
src_spatial_shapes=src_spatial_shapes,
level_start_index=level_start_index,
)
else:
x_att, self_att = self_attn(x)
x = x_att + x
x_att, cross_att = cross_attn(
x,
src,
query_pos=None,
reference_points=reference_points,
src_spatial_shapes=src_spatial_shapes,
level_start_index=level_start_index,
)
out_cross_attention_list.append(cross_att)
else:
if center_pos_embedding is not None:
x_att = self_attn(x + center_pos_embedding)
x = x_att + x
x_att = cross_attn(
x,
src,
query_pos=center_pos_embedding,
reference_points=reference_points,
src_spatial_shapes=src_spatial_shapes,
level_start_index=level_start_index,
)
else:
x_att = self_attn(x)
x = x_att + x
x_att = cross_attn(
x,
src,
query_pos=None,
reference_points=reference_points,
src_spatial_shapes=src_spatial_shapes,
level_start_index=level_start_index,
)
x = x_att + x
x = ff(x) + x
out_dict = {'ct_feat': x}
if self.out_attention:
out_dict.update({
'out_attention':
torch.stack(out_cross_attention_list, dim=2)
})
return out_dict
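# A minimal shape sketch (not part of the original file; all values are
# made up). The decoder refines the center-proposal features with
# self-attention and multi-scale deformable cross-attention over the
# flattened BEV features; the cross-attention relies on the compiled mmcv
# op, so the lines below are illustrative only.
# decoder = DeformableTransformerDecoder(
#     dim=256, n_levels=3, depth=2, n_heads=6, dim_single_head=64,
#     dim_ffn=256, n_points=15)
# x = torch.rand(2, 500, 256)          # features of the center proposals
# center_pos = torch.rand(2, 500, 2)   # normalized (x, y) of the centers
# spatial_shapes = torch.as_tensor([[188, 188], [94, 94], [47, 47]])
# level_start_index = torch.cat([
#     spatial_shapes.new_zeros(1), spatial_shapes.prod(1).cumsum(0)[:-1]])
# src = torch.rand(2, int(spatial_shapes.prod(1).sum()), 256)
# out = decoder(x, None, src, spatial_shapes, level_start_index, center_pos)
# out['ct_feat']  # refined center features with the same shape as ``x``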
_base_ = ['mmdet3d::_base_/default_runtime.py']
custom_imports = dict(
imports=['projects.CenterFormer.centerformer'], allow_failed_imports=False)
# model settings
# Voxel size for voxel encoder
# Usually voxel size is changed consistently with the point cloud range
# If point cloud range is modified, do remember to change all related
# keys in the config.
voxel_size = [0.1, 0.1, 0.15]
point_cloud_range = [-75.2, -75.2, -2, 75.2, 75.2, 4]
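# A quick consistency check (illustrative only): the BEV grid covers
# (75.2 - (-75.2)) / 0.1 = 1504 voxels per axis and (4 - (-2)) / 0.15 = 40
# voxels in height, which matches `sparse_shape=[41, 1504, 1504]` (40 + 1
# in z) and `grid_size=[1504, 1504, 40]` below.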
class_names = ['Car', 'Pedestrian', 'Cyclist']
tasks = [dict(num_class=3, class_names=['car', 'pedestrian', 'cyclist'])]
metainfo = dict(classes=class_names)
input_modality = dict(use_lidar=True, use_camera=False)
file_client_args = dict(backend='disk')
model = dict(
type='CenterFormer',
data_preprocessor=dict(
type='Det3DDataPreprocessor',
voxel=True,
voxel_type='dynamic',
voxel_layer=dict(
max_num_points=-1,
point_cloud_range=point_cloud_range,
voxel_size=voxel_size,
max_voxels=(-1, -1))),
voxel_encoder=dict(
type='DynamicSimpleVFE',
point_cloud_range=point_cloud_range,
voxel_size=voxel_size),
middle_encoder=dict(
type='SparseEncoder',
in_channels=5,
sparse_shape=[41, 1504, 1504],
order=('conv', 'norm', 'act'),
norm_cfg=dict(type='naiveSyncBN1d', eps=0.001, momentum=0.01),
encoder_channels=((16, 16, 32), (32, 32, 64), (64, 64, 128), (128,
128)),
encoder_paddings=((1, 1, 1), (1, 1, 1), (1, 1, [0, 1, 1]), (1, 1)),
block_type='basicblock'),
backbone=dict(
type='DeformableDecoderRPN',
layer_nums=[5, 5, 1],
ds_num_filters=[256, 256, 128],
num_input_features=256,
tasks=tasks,
use_gt_training=True,
corner=True,
assign_label_window_size=1,
obj_num=500,
norm_cfg=dict(type='SyncBN', eps=1e-3, momentum=0.01),
transformer_config=dict(
depth=2,
n_heads=6,
dim_single_head=64,
dim_ffn=256,
dropout=0.3,
out_attn=False,
n_points=15,
),
),
bbox_head=dict(
type='CenterFormerBboxHead',
in_channels=256,
tasks=tasks,
dataset='waymo',
weight=2,
corner_loss=True,
iou_loss=True,
assign_label_window_size=1,
norm_cfg=dict(type='SyncBN', eps=1e-3, momentum=0.01),
code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
common_heads={
'reg': (2, 2),
'height': (1, 2),
'dim': (3, 2),
'rot': (2, 2),
'iou': (1, 2)
}, # (output_channel, num_conv)
),
train_cfg=dict(
grid_size=[1504, 1504, 40],
voxel_size=voxel_size,
out_size_factor=4,
dense_reg=1,
gaussian_overlap=0.1,
point_cloud_range=point_cloud_range,
max_objs=500,
min_radius=2,
code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]),
test_cfg=dict(
post_center_limit_range=[-80, -80, -10.0, 80, 80, 10.0],
nms=dict(
use_rotate_nms=False,
use_multi_class_nms=True,
nms_pre_max_size=[1600, 1600, 800],
nms_post_max_size=[200, 200, 100],
nms_iou_threshold=[0.8, 0.55, 0.55],
),
score_threshold=0.1,
pc_range=[-75.2, -75.2],
out_size_factor=4,
voxel_size=[0.1, 0.1],
obj_num=1000,
))
data_root = 'data/waymo/kitti_format/'
db_sampler = dict(
data_root=data_root,
info_path=data_root + 'waymo_dbinfos_train.pkl',
rate=1.0,
prepare=dict(
filter_by_difficulty=[-1],
filter_by_min_points=dict(Car=5, Pedestrian=5, Cyclist=5)),
classes=class_names,
sample_groups=dict(Car=15, Pedestrian=10, Cyclist=10),
points_loader=dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=6,
use_dim=[0, 1, 2, 3, 4]))
train_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=6,
use_dim=5,
norm_intensity=True),
# Add this if using `MultiFrameDeformableDecoderRPN`
# dict(
# type='LoadPointsFromMultiSweeps',
# sweeps_num=9,
# load_dim=6,
# use_dim=[0, 1, 2, 3, 4],
# pad_empty_sweeps=True,
# remove_close=True),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler),
dict(
type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816],
scale_ratio_range=[0.95, 1.05],
translation_std=[0.5, 0.5, 0]),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectNameFilter', classes=class_names),
dict(type='PointShuffle'),
dict(
type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=6,
use_dim=5,
norm_intensity=True,
file_client_args=file_client_args),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'),
dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range)
]),
dict(type='Pack3DDetInputs', keys=['points'])
]
dataset_type = 'WaymoDataset'
train_dataloader = dict(
batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='waymo_infos_train.pkl',
data_prefix=dict(pts='training/velodyne', sweeps='training/velodyne'),
pipeline=train_pipeline,
modality=input_modality,
test_mode=False,
metainfo=metainfo,
# we use box_type_3d='LiDAR' in kitti and nuscenes dataset
# and box_type_3d='Depth' in sunrgbd and scannet dataset.
box_type_3d='LiDAR',
# load one frame every five frames
load_interval=5,
file_client_args=file_client_args))
val_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(pts='training/velodyne', sweeps='training/velodyne'),
ann_file='waymo_infos_val.pkl',
pipeline=test_pipeline,
modality=input_modality,
test_mode=True,
metainfo=metainfo,
box_type_3d='LiDAR',
file_client_args=file_client_args))
test_dataloader = val_dataloader
val_evaluator = dict(
type='WaymoMetric',
ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
waymo_bin_file='./data/waymo/waymo_format/gt.bin',
data_root='./data/waymo/waymo_format',
file_client_args=file_client_args,
convert_kitti_format=False,
idx2metainfo='./data/waymo/waymo_format/idx2metainfo.pkl')
test_evaluator = val_evaluator
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')
# For the Waymo dataset, we usually evaluate the model at the end of
# training. Since the model is trained for 20 epochs by default, we set
# the evaluation interval to 20. Please change the interval accordingly if
# you do not use the default schedule.
# optimizer
lr = 3e-4
# This schedule is mainly used by models on nuScenes dataset
# max_norm=10 is better for SECOND
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=lr, weight_decay=0.01, betas=(0.9, 0.99)),
clip_grad=dict(max_norm=35, norm_type=2))
# learning rate
param_scheduler = [
# learning rate scheduler
# During the first 8 epochs, the learning rate increases from lr to
# lr * 10; during the next 12 epochs it decreases from lr * 10 to
# lr * 1e-4.
dict(
type='CosineAnnealingLR',
T_max=8,
eta_min=lr * 10,
begin=0,
end=8,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=12,
eta_min=lr * 1e-4,
begin=8,
end=20,
by_epoch=True,
convert_to_iter_based=True),
# momentum scheduler
# During the first 8 epochs, momentum anneals from its initial value to
# 0.85 / 0.95; during the next 12 epochs it increases from 0.85 / 0.95
# to 1.
dict(
type='CosineAnnealingMomentum',
T_max=8,
eta_min=0.85 / 0.95,
begin=0,
end=8,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='CosineAnnealingMomentum',
T_max=12,
eta_min=1,
begin=8,
end=20,
by_epoch=True,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=20)
val_cfg = dict()
test_cfg = dict()
# Default setting for scaling LR automatically
# - `enable` means whether to scale the LR automatically
#   (disabled by default).
# - `base_batch_size` = (4 GPUs) x (4 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=16)
default_hooks = dict(
logger=dict(type='LoggerHook', interval=50),
checkpoint=dict(type='CheckpointHook', interval=5))
custom_hooks = [dict(type='DisableObjectSampleHook', disable_after_epoch=15)]
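# The hook above implements the common "fade" strategy: the copy-paste
# `ObjectSample` augmentation is switched off after epoch 15, so the last
# five of the 20 training epochs run without GT-sampling augmentation.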
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock
from mmdet3d.datasets.transforms import ObjectSample
from mmdet3d.engine.hooks import DisableObjectSampleHook
class TestDisableObjectSampleHook(TestCase):
runner = Mock()
runner.train_dataloader = Mock()
runner.train_dataloader.dataset = Mock()
runner.train_dataloader.dataset.pipeline = Mock()
runner.train_dataloader._DataLoader__initialized = True
runner.train_dataloader.dataset.pipeline.transforms = [
ObjectSample(
db_sampler=dict(
data_root='tests/data/waymo/kitti_format',
info_path= # noqa
'tests/data/waymo/kitti_format/waymo_dbinfos_train.pkl',
rate=1.0,
prepare=dict(
filter_by_difficulty=[-1],
filter_by_min_points=dict(Car=5)),
classes=['Car'],
sample_groups=dict(Car=15),
))
]
def test_is_model_wrapper_and_persistent_workers_on(self):
self.runner.train_dataloader.dataset.pipeline.transforms[
0].disabled = False
self.runner.train_dataloader.persistent_workers = True
hook = DisableObjectSampleHook(disable_after_epoch=15)
self.runner.epoch = 14
hook.before_train_epoch(self.runner)
self.assertFalse(self.runner.train_dataloader.dataset.pipeline.
transforms[0].disabled) # noqa: E501
self.runner.epoch = 15
hook.before_train_epoch(self.runner)
self.assertTrue(self.runner.train_dataloader.dataset.pipeline.
transforms[0].disabled) # noqa: E501
self.assertTrue(hook._restart_dataloader)
self.assertFalse(self.runner.train_dataloader._DataLoader__initialized)
self.runner.epoch = 16
hook.before_train_epoch(self.runner)
self.assertTrue(self.runner.train_dataloader._DataLoader__initialized)
self.assertTrue(self.runner.train_dataloader.dataset.pipeline.
transforms[0].disabled) # noqa: E501
def test_not_model_wrapper_and_persistent_workers_off(self):
self.runner.train_dataloader.dataset.pipeline.transforms[
0].disabled = False
self.runner.train_dataloader.persistent_workers = False
hook = DisableObjectSampleHook(disable_after_epoch=15)
self.runner.epoch = 14
hook.before_train_epoch(self.runner)
self.assertFalse(self.runner.train_dataloader.dataset.pipeline.
transforms[0].disabled) # noqa: E501
self.runner.epoch = 15
hook.before_train_epoch(self.runner)
self.assertTrue(self.runner.train_dataloader.dataset.pipeline.
transforms[0].disabled) # noqa: E501
self.assertFalse(hook._restart_dataloader)
self.assertTrue(self.runner.train_dataloader._DataLoader__initialized)
self.runner.epoch = 16
hook.before_train_epoch(self.runner)
self.assertTrue(self.runner.train_dataloader._DataLoader__initialized)
self.assertTrue(self.runner.train_dataloader.dataset.pipeline.
transforms[0].disabled) # noqa: E501