Unverified Commit ac0302ed authored by Jingwei Zhang, committed by GitHub

Support CenterFormer in `projects` (#2173)

* init centerformer in projects

* add readme and disable_tf32_switch

* using our iou3d and nms3d, using basicblock in mmdet, simplify code in projects

* remove attention.py and sparse_block.py

* only using single fold

* polish code

* polish code and add docstring

* add ut for disable_object_sample_hook

* modify data_root

* add ut

* update readme

* polish code

* fix docstring

* resolve comments

* modify project names

* modify project names and add _forward

* fix docstring

* remove disable_tf32
parent edc468bf
......@@ -67,11 +67,10 @@ def init_model(config: Union[str, Path, Config],
if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
dataset_meta = checkpoint['meta'].get('dataset_meta', None)
# save the dataset_meta in the model for convenience
if 'dataset_meta' in checkpoint.get('meta', {}):
# mmdet3d 1.x
model.dataset_meta = dataset_meta
model.dataset_meta = checkpoint['meta']['dataset_meta']
elif 'CLASSES' in checkpoint.get('meta', {}):
# < mmdet3d 1.x
classes = checkpoint['meta']['CLASSES']
......
......@@ -532,6 +532,8 @@ class LoadPointsFromFile(BaseTransform):
or use_dim=[0, 1, 2, 3] to use the intensity dimension.
shift_height (bool): Whether to use shifted height. Defaults to False.
use_color (bool): Whether to use color features. Defaults to False.
norm_intensity (bool): Whether to normalize the intensity. Defaults to
False.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmengine.fileio.FileClient` for details.
Defaults to dict(backend='disk').
......@@ -544,6 +546,7 @@ class LoadPointsFromFile(BaseTransform):
use_dim: Union[int, List[int]] = [0, 1, 2],
shift_height: bool = False,
use_color: bool = False,
norm_intensity: bool = False,
file_client_args: dict = dict(backend='disk')
) -> None:
self.shift_height = shift_height
......@@ -557,6 +560,7 @@ class LoadPointsFromFile(BaseTransform):
self.coord_type = coord_type
self.load_dim = load_dim
self.use_dim = use_dim
self.norm_intensity = norm_intensity
self.file_client_args = file_client_args.copy()
self.file_client = None
......@@ -599,6 +603,10 @@ class LoadPointsFromFile(BaseTransform):
points = self._load_points(pts_file_path)
points = points.reshape(-1, self.load_dim)
points = points[:, self.use_dim]
if self.norm_intensity:
assert len(self.use_dim) >= 4, \
f'When using intensity norm, expect used dimensions >= 4, got {len(self.use_dim)}' # noqa: E501
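# np.tanh maps the raw intensity values into (-1, 1), so unusually large
# sensor readings no longer dominate the feature scale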
points[:, 3] = np.tanh(points[:, 3])
attribute_dims = None
if self.shift_height:
......
......@@ -359,6 +359,7 @@ class ObjectSample(BaseTransform):
db_sampler['type'] = 'DataBaseSampler'
self.db_sampler = TRANSFORMS.build(db_sampler)
self.use_ground_plane = use_ground_plane
self.disabled = False
@staticmethod
def remove_points_in_boxes(points: BasePoints,
......@@ -387,6 +388,9 @@ class ObjectSample(BaseTransform):
'points', 'gt_bboxes_3d', 'gt_labels_3d' keys are updated
in the result dict.
"""
if self.disabled:
return input_dict
gt_bboxes_3d = input_dict['gt_bboxes_3d']
gt_labels_3d = input_dict['gt_labels_3d']
......
# Copyright (c) OpenMMLab. All rights reserved.
from .benchmark_hook import BenchmarkHook
from .disable_object_sample_hook import DisableObjectSampleHook
from .visualization_hook import Det3DVisualizationHook
__all__ = ['Det3DVisualizationHook', 'BenchmarkHook']
__all__ = [
'Det3DVisualizationHook', 'BenchmarkHook', 'DisableObjectSampleHook'
]
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.runner import Runner
from mmdet3d.datasets.transforms import ObjectSample
from mmdet3d.registry import HOOKS
@HOOKS.register_module()
class DisableObjectSampleHook(Hook):
"""The hook of disabling augmentations during training.
Args:
disable_after_epoch (int): The number of epochs after which
the ``ObjectSample`` will be closed in the training.
Defaults to 15.
"""
def __init__(self, disable_after_epoch: int = 15):
self.disable_after_epoch = disable_after_epoch
self._restart_dataloader = False
def before_train_epoch(self, runner: Runner):
"""Close augmentation.
Args:
runner (Runner): The runner.
"""
epoch = runner.epoch
train_loader = runner.train_dataloader
model = runner.model
# TODO: refactor after mmengine using model wrapper
if is_model_wrapper(model):
model = model.module
if epoch == self.disable_after_epoch:
runner.logger.info('Disable ObjectSample')
for transform in runner.train_dataloader.dataset.pipeline.transforms: # noqa: E501
if isinstance(transform, ObjectSample):
assert hasattr(transform, 'disabled')
transform.disabled = True
# The dataset pipeline cannot be updated when persistent_workers
# is True, so we need to force the dataloader's multi-process
# restart. This is a very hacky approach.
if hasattr(train_loader, 'persistent_workers'
) and train_loader.persistent_workers is True:
train_loader._DataLoader__initialized = False
train_loader._iterator = None
self._restart_dataloader = True
else:
# Once the restart is complete, we need to restore
# the initialization flag.
if self._restart_dataloader:
train_loader._DataLoader__initialized = True
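# Example usage (a config sketch; it mirrors the ``custom_hooks`` entry in the
# CenterFormer config added later in this commit):
#
#     custom_hooks = [
#         dict(type='DisableObjectSampleHook', disable_after_epoch=15)
#     ]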
......@@ -6,9 +6,7 @@ try:
except ImportError:
IS_SPCONV2_AVAILABLE = False
else:
if hasattr(spconv,
'__version__') and spconv.__version__ >= '2.0.0' and hasattr(
spconv, 'pytorch'):
if hasattr(spconv, '__version__') and spconv.__version__ >= '2.0.0':
IS_SPCONV2_AVAILABLE = register_spconv2()
else:
IS_SPCONV2_AVAILABLE = False
......
# CenterFormer: Center-based Transformer for 3D Object Detection
> [CenterFormer: Center-based Transformer for 3D Object Detection](https://arxiv.org/abs/2209.05588)
<!-- [ALGORITHM] -->
## Abstract
Query-based transformer has shown great potential in constructing long-range attention in many image-domain tasks, but has rarely been considered in LiDAR-based 3D object detection due to the overwhelming size of the point cloud data. In this paper, we propose CenterFormer, a center-based transformer network for 3D object detection. CenterFormer first uses a center heatmap to select center candidates on top of a standard voxel-based point cloud encoder. It then uses the feature of the center candidate as the query embedding in the transformer. To further aggregate features from multiple frames, we design an approach to fuse features through cross-attention. Lastly, regression heads are added to predict the bounding box on the output center feature representation. Our design reduces the convergence difficulty and computational complexity of the transformer structure. The results show significant improvements over the strong baseline of anchor-free object detection networks. CenterFormer achieves state-of-the-art performance for a single model on the Waymo Open Dataset, with 73.7% mAPH on the validation set and 75.6% mAPH on the test set, significantly outperforming all previously published CNN and transformer-based methods. Our code is publicly available at https://github.com/TuSimple/centerformer.
<div align=center>
<img src="https://user-images.githubusercontent.com/34888372/209500088-b707d7cd-d4d5-4f20-8fdf-a2c7ad15df34.png" width="800"/>
</div>
## Introduction
We implement CenterFormer and provide its results and checkpoints on the Waymo dataset.
We follow the below style to name config files. Contributors are advised to follow the same style. `{xxx}` is a required field and `[yyy]` is optional.

- `{model}`: model type like `centerformer`.
- `{model setting}`: voxel size and voxel type like `01voxel`, `02pillar`.
- `{backbone}`: backbone type like `second`.
- `{neck}`: neck type like `secfpn`.
- `[batch_per_gpu x gpu]`: GPUs and samples per GPU, `4x8` is used by default.
- `{schedule}`: training schedule, options are `1x`, `2x`, `20e`, etc. `1x` and `2x` mean 12 epochs and 24 epochs respectively. `20e` is adopted in cascade models, which denotes 20 epochs. For `1x`/`2x`, the initial learning rate decays by a factor of 10 at the 8/16th and 11/22th epochs. For `20e`, the initial learning rate decays by a factor of 10 at the 16th and 19th epochs.
- `{dataset}`: dataset like `nus-3d`, `kitti-3d`, `lyft-3d`, `scannet-3d`, `sunrgbd-3d`. We also indicate the number of classes we are using if there exist multiple settings, e.g., `kitti-3d-3class` and `kitti-3d-car` mean training on the KITTI dataset with 3 classes and a single class, respectively.
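For example, the config used in this project, `centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py`, can be read as: `centerformer` model, `voxel01` voxel setting, `second-attn` backbone, `secfpn-attn` neck, `4xb4` batch setting (4 GPUs x 4 samples per GPU), `cyclic-20e` schedule and `waymoD5-3d-class` dataset.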
## Usage
<!-- For a typical model, this section should contain the commands for training and testing. You are also suggested to dump your environment specification to env.yml by `conda env export > env.yml`. -->
### Training commands
In MMDetection3D's root directory, run the following command to train the model:
```bash
python tools/train.py projects/CenterFormer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py
```
For multi-GPU training, run:
```bash
python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=${NUM_GPUS} --master_port=29506 --master_addr="127.0.0.1" tools/train.py projects/CenterFormer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py
```
### Testing commands
In MMDetection3D's root directory, run the following command to test the model:
```bash
python tools/test.py projects/CenterFormer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py ${CHECKPOINT_PATH}
```
## Results and models
### Waymo
| Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | Mem (GB) | Inf time (fps) | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
| :----------------------------------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :------: | :------------: | :----: | :-----: | :----: | :---------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| [SECFPN_WithAttention](./configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py) | 5 | voxel (0.1) | ✓ | × | 14.8 | | 72.2 | 69.5 | 65.9 | 63.3 | [log](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/centerformer/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class_20221227_205613-70c9ad37.json) |
**Note** that `SECFPN_WithAttention` denotes both SECOND and SECONDFPN with ChannelAttention and SpatialAttention.
## Citation
```latex
@InProceedings{Zhou_centerformer,
title = {CenterFormer: Center-based Transformer for 3D Object Detection},
author = {Zhou, Zixiang and Zhao, Xiangchen and Wang, Yu and Wang, Panqu and Foroosh, Hassan},
booktitle = {ECCV},
year = {2022}
}
```
from .bbox_ops import nms_iou3d
from .centerformer import CenterFormer
from .centerformer_backbone import (DeformableDecoderRPN,
MultiFrameDeformableDecoderRPN)
from .centerformer_head import CenterFormerBboxHead
from .losses import FastFocalLoss
__all__ = [
'CenterFormer', 'DeformableDecoderRPN', 'CenterFormerBboxHead',
'FastFocalLoss', 'nms_iou3d', 'MultiFrameDeformableDecoderRPN'
]
import torch
from mmcv.utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['iou3d_nms3d_forward'])
def nms_iou3d(boxes, scores, thresh, pre_maxsize=None, post_max_size=None):
"""NMS function GPU implementation (using IoU3D). The difference between
this implementation and nms3d in MMCV is that we add `pre_maxsize` and
`post_max_size` before and after NMS respectively.
Args:
boxes (Tensor): Input boxes with the shape of [N, 7]
([cx, cy, cz, l, w, h, theta]).
scores (Tensor): Scores of boxes with the shape of [N].
thresh (float): Overlap threshold of NMS.
pre_max_size (int, optional): Max size of boxes before NMS.
Defaults to None.
post_max_size (int, optional): Max size of boxes after NMS.
Defaults to None.
Returns:
Tensor: Indexes after NMS.
"""
# TODO: directly refactor ``nms3d`` in MMCV
assert boxes.size(1) == 7, 'Input boxes shape should be (N, 7)'
order = scores.sort(0, descending=True)[1]
if pre_maxsize is not None:
order = order[:pre_maxsize]
boxes = boxes[order].contiguous()
keep = boxes.new_zeros(boxes.size(0), dtype=torch.long)
num_out = boxes.new_zeros(size=(), dtype=torch.long)
ext_module.iou3d_nms3d_forward(
boxes, keep, num_out, nms_overlap_thresh=thresh)
keep = order[keep[:num_out].to(boxes.device)].contiguous()
if post_max_size is not None:
keep = keep[:post_max_size]
return keep
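# Example usage (an illustrative sketch; ``boxes`` and ``scores`` are assumed
# to be CUDA tensors of shape (N, 7) and (N,), with thresholds taken from the
# test_cfg of the CenterFormer config in this commit):
#
#     keep = nms_iou3d(boxes, scores, thresh=0.8, pre_maxsize=1600,
#                      post_max_size=200)
#     top_boxes = boxes[keep]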
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional
import torch
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm
from mmdet3d.models.detectors import Base3DDetector
from mmdet3d.registry import MODELS
from mmdet3d.structures import Det3DDataSample
@MODELS.register_module()
class CenterFormer(Base3DDetector):
"""Base class of center-based 3D detector.
Args:
voxel_encoder (dict, optional): Point voxelization
encoder layer. Defaults to None.
middle_encoder (dict, optional): Middle encoder layer
of points cloud modality. Defaults to None.
pts_fusion_layer (dict, optional): Fusion layer.
Defaults to None.
backbone (dict, optional): Backbone of extracting
points features. Defaults to None.
neck (dict, optional): Neck of extracting
points features. Defaults to None.
bbox_head (dict, optional): Bboxes head of
point cloud modality. Defaults to None.
train_cfg (dict, optional): Train config of model.
Defaults to None.
test_cfg (dict, optional): Train config of model.
Defaults to None.
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
data_preprocessor (dict or ConfigDict, optional): The pre-process
config of :class:`Det3DDataPreprocessor`. Defaults to None.
"""
def __init__(self,
voxel_encoder: Optional[dict] = None,
middle_encoder: Optional[dict] = None,
backbone: Optional[dict] = None,
neck: Optional[dict] = None,
bbox_head: Optional[dict] = None,
train_cfg: Optional[dict] = None,
test_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None,
data_preprocessor: Optional[dict] = None,
**kwargs):
super(CenterFormer, self).__init__(
init_cfg=init_cfg, data_preprocessor=data_preprocessor, **kwargs)
if voxel_encoder:
self.voxel_encoder = MODELS.build(voxel_encoder)
if middle_encoder:
self.middle_encoder = MODELS.build(middle_encoder)
if backbone:
backbone.update(train_cfg=train_cfg, test_cfg=test_cfg)
self.backbone = MODELS.build(backbone)
if neck is not None:
self.neck = MODELS.build(neck)
if bbox_head:
bbox_head.update(train_cfg=train_cfg, test_cfg=test_cfg)
self.bbox_head = MODELS.build(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
def init_weights(self):
for m in self.modules():
if isinstance(m, _BatchNorm):
torch.nn.init.uniform_(m.weight)
@property
def with_bbox(self):
"""bool: Whether the detector has a 3D box head."""
return hasattr(self, 'bbox_head') and self.bbox_head is not None
@property
def with_backbone(self):
"""bool: Whether the detector has a 3D backbone."""
return hasattr(self, 'backbone') and self.backbone is not None
@property
def with_voxel_encoder(self):
"""bool: Whether the detector has a voxel encoder."""
return hasattr(self,
'voxel_encoder') and self.voxel_encoder is not None
@property
def with_middle_encoder(self):
"""bool: Whether the detector has a middle encoder."""
return hasattr(self,
'middle_encoder') and self.middle_encoder is not None
def _forward(self):
pass
def extract_feat(self, batch_inputs_dict: dict,
batch_input_metas: List[dict]) -> tuple:
"""Extract features from images and points.
Args:
batch_inputs_dict (dict): Dict of batch inputs. It
contains
- points (List[tensor]): Point cloud of multiple inputs.
- imgs (tensor): Image tensor with shape (B, C, H, W).
batch_input_metas (list[dict]): Meta information of multiple inputs
in a batch.
Returns:
tuple: Two elements in tuple arrange as
image features and point cloud features.
"""
voxel_dict = batch_inputs_dict.get('voxels', None)
voxel_features, feature_coors = self.voxel_encoder(
voxel_dict['voxels'], voxel_dict['coors'])
batch_size = voxel_dict['coors'][-1, 0].item() + 1
x = self.middle_encoder(voxel_features, feature_coors, batch_size)
return x
def loss(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
batch_data_samples: List[Det3DDataSample],
**kwargs) -> Dict[str, Tensor]:
"""
Args:
batch_inputs_dict (dict): The model input dict which includes
'points' and 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict): Voxel features and coordinates produced
by the data preprocessor.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instances_3d`.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
batch_input_metas = [item.metainfo for item in batch_data_samples]
pts_feats = self.extract_feat(batch_inputs_dict, batch_input_metas)
preds, batch_targets = self.backbone(pts_feats, batch_data_samples)
preds = self.bbox_head(preds)
losses = dict()
losses.update(self.bbox_head.loss(preds, batch_targets))
return losses
def predict(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
batch_data_samples: List[Det3DDataSample],
**kwargs) -> List[Det3DDataSample]:
"""Forward of testing.
Args:
batch_inputs_dict (dict): The model input dict which include
'points' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`.
Returns:
list[:obj:`Det3DDataSample`]: Detection results of the
input sample. Each Det3DDataSample usually contain
'pred_instances_3d'. And the ``pred_instances_3d`` usually
contains following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instances, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bbox_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes,
contains a tensor with shape (num_instances, 7).
"""
batch_input_metas = [item.metainfo for item in batch_data_samples]
pts_feats = self.extract_feat(batch_inputs_dict, batch_input_metas)
preds, _ = self.backbone(pts_feats, batch_data_samples)
preds = self.bbox_head(preds)
results_list_3d = self.bbox_head.predict(preds, batch_input_metas)
detsamples = self.add_pred_to_datasample(batch_data_samples,
results_list_3d)
return detsamples
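# Inference sketch (assumed flow, matching ``predict`` above): the data
# preprocessor voxelizes the input points, ``extract_feat`` produces BEV
# features via the voxel and middle encoders, the backbone generates center
# proposals and transformer-refined features, and ``CenterFormerBboxHead``
# decodes the final boxes that are packed into the returned data samples.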
# modify from https://github.com/TuSimple/centerformer/blob/master/det3d/models/losses/centernet_loss.py # noqa
import torch
from torch import nn
from mmdet3d.registry import MODELS
def _gather_feat(feat, ind, mask=None):
dim = feat.size(2)
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
feat = feat.gather(1, ind)
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
feat = feat[mask]
feat = feat.view(-1, dim)
return feat
def _transpose_and_gather_feat(feat, ind):
feat = feat.permute(0, 2, 3, 1).contiguous()
feat = feat.view(feat.size(0), -1, feat.size(3))
feat = _gather_feat(feat, ind)
return feat
@MODELS.register_module()
class FastFocalLoss(nn.Module):
"""Reimplemented focal loss, exactly the same as the CornerNet version.
Faster and costs much less memory.
"""
def __init__(self, focal_factor=2):
super(FastFocalLoss, self).__init__()
self.focal_factor = focal_factor
def forward(self, out, target, ind, mask, cat):
'''
Args:
out, target: B x C x H x W
ind, mask: B x M
cat (category id for peaks): B x M
'''
mask = mask.float()
gt = torch.pow(1 - target, 4)
neg_loss = torch.log(1 - out) * torch.pow(out, self.focal_factor) * gt
neg_loss = neg_loss.sum()
pos_pred_pix = _transpose_and_gather_feat(out, ind) # B x M x C
pos_pred = pos_pred_pix.gather(2, cat.unsqueeze(2)) # B x M
num_pos = mask.sum()
pos_loss = torch.log(pos_pred) * torch.pow(
1 - pos_pred, self.focal_factor) * mask.unsqueeze(2)
pos_loss = pos_loss.sum()
if num_pos == 0:
return -neg_loss
return -(pos_loss + neg_loss) / num_pos
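# Illustrative shapes (an assumed example): ``out`` and ``target`` are heatmaps
# of shape (B, C, H, W); ``ind``, ``mask`` and ``cat`` have shape (B, M), where
# M is the maximum number of objects per sample (``max_objs`` in the train
# config). The returned loss is a single scalar normalized by the number of
# positive peaks.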
# modify from https://github.com/TuSimple/centerformer/blob/master/det3d/models/ops/modules/ms_deform_attn.py # noqa
import math
from typing import Optional
import torch
import torch.nn.functional as F
from mmcv.utils import ext_loader
from torch import Tensor, nn
from torch.autograd.function import Function, once_differentiable
from torch.nn.init import constant_, xavier_uniform_
ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
class MultiScaleDeformableAttnFunction(Function):
@staticmethod
def forward(ctx, value: torch.Tensor, value_spatial_shapes: torch.Tensor,
value_level_start_index: torch.Tensor,
sampling_locations: torch.Tensor,
attention_weights: torch.Tensor,
im2col_step: torch.Tensor) -> torch.Tensor:
"""GPU/MLU version of multi-scale deformable attention.
Args:
value (torch.Tensor): The value has shape
(bs, num_keys, num_heads, embed_dims//num_heads).
value_spatial_shapes (torch.Tensor): Spatial shape of
each feature map, has shape (num_levels, 2), the
last dimension 2 represents (h, w).
value_level_start_index (torch.Tensor): The start index of each
level, has shape (num_levels, ).
sampling_locations (torch.Tensor): The location of sampling points,
has shape
(bs, num_queries, num_heads, num_levels, num_points, 2),
the last dimension 2 represents (x, y).
attention_weights (torch.Tensor): The weight of sampling points
used when calculating the attention, has shape
(bs, num_queries, num_heads, num_levels, num_points).
im2col_step (torch.Tensor): The step used in image-to-column
conversion.
Returns:
torch.Tensor: Attention output with shape (bs, num_queries, embed_dims).
"""
ctx.im2col_step = im2col_step
output = ext_module.ms_deform_attn_forward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
im2col_step=ctx.im2col_step)
ctx.save_for_backward(value, value_spatial_shapes,
value_level_start_index, sampling_locations,
attention_weights)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output: torch.Tensor) -> tuple:
"""GPU/MLU version of backward function.
Args:
grad_output (torch.Tensor): Gradient of output tensor of forward.
Returns:
tuple[Tensor]: Gradient of input tensors in forward.
"""
value, value_spatial_shapes, value_level_start_index,\
sampling_locations, attention_weights = ctx.saved_tensors
grad_value = torch.zeros_like(value)
grad_sampling_loc = torch.zeros_like(sampling_locations)
grad_attn_weight = torch.zeros_like(attention_weights)
ext_module.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output.contiguous(),
grad_value,
grad_sampling_loc,
grad_attn_weight,
im2col_step=ctx.im2col_step)
return grad_value, None, None, \
grad_sampling_loc, grad_attn_weight, None
class MSDeformAttn(nn.Module):
"""Multi-Scale Deformable Attention Module. Note that the difference
between this implementation and the implementation in MMCV is that the
dimension of input and hidden embedding in the multi-attention-head can be
specified respectively.
Args:
dim_model (int, optional): The input and output dimension in the model.
Defaults to 256.
dim_single_head (int, optional): hidden dimension in the single head.
Defaults to 64.
n_levels (int, optional): number of feature levels. Defaults to 4.
n_heads (int, optional): number of attention heads. Defaults to 8.
n_points (int, optional): number of sampling points per attention head
per feature level. Defaults to 4.
out_sample_loc (bool, optional): Whether to return the sampling
location. Defaults to False.
"""
def __init__(self,
dim_model=256,
dim_single_head=64,
n_levels=4,
n_heads=8,
n_points=4,
out_sample_loc=False):
super().__init__()
self.im2col_step = 64
self.dim_model = dim_model
self.dim_single_head = dim_single_head
self.n_levels = n_levels
self.n_heads = n_heads
self.n_points = n_points
self.out_sample_loc = out_sample_loc
self.sampling_offsets = nn.Linear(dim_model,
n_heads * n_levels * n_points * 2)
self.attention_weights = nn.Linear(dim_model,
n_heads * n_levels * n_points)
self.value_proj = nn.Linear(dim_model, dim_single_head * n_heads)
self.output_proj = nn.Linear(dim_single_head * n_heads, dim_model)
self._reset_parameters()
def _reset_parameters(self):
constant_(self.sampling_offsets.weight.data, 0.)
thetas = torch.arange(
self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init /
grid_init.abs().max(-1, keepdim=True)[0]).view(
self.n_heads, 1, 1, 2).repeat(1, self.n_levels,
self.n_points, 1)
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.)
constant_(self.attention_weights.bias.data, 0.)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
def forward(self,
query: Tensor,
reference_points: Tensor,
input_flatten: Tensor,
input_spatial_shapes: Tensor,
input_level_start_index: Tensor,
input_padding_mask: Optional[Tensor] = None):
"""Forward Function of MultiScaleDeformAttention.
Args:
query (Tensor): (N, num_query, C)
reference_points (Tensor): (N, num_query, n_levels, 2). The
normalized reference points with shape
(bs, num_query, num_levels, 2),
all elements is range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area.
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
form reference boxes.
input_flatten (Tensor): _description_
input_spatial_shapes (Tensor): Spatial shape of features in
different levels. With shape (num_levels, 2),
last dimension represents (h, w).
input_level_start_index (Tensor): The start index of each level.
A tensor has shape ``(num_levels, )`` and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
input_padding_mask (Optional[Tensor], optional): The padding mask
for value. Defaults to None.
Returns:
Tuple[Tensor, Tensor]: forwarded results.
"""
N, Len_q, _ = query.shape
N, Len_in, _ = input_flatten.shape
assert (input_spatial_shapes[:, 0] *
input_spatial_shapes[:, 1]).sum() == Len_in
value = self.value_proj(input_flatten)
if input_padding_mask is not None:
value = value.masked_fill(input_padding_mask[..., None], float(0))
value = value.view(N, Len_in, self.n_heads, self.dim_single_head)
sampling_offsets = self.sampling_offsets(query).view(
N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
attention_weights = self.attention_weights(query).view(
N, Len_q, self.n_heads, self.n_levels * self.n_points)
attention_weights = F.softmax(attention_weights,
-1).view(N, Len_q, self.n_heads,
self.n_levels, self.n_points)
# N, Len_q, n_heads, n_levels, n_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack(
[input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]],
-1).to(sampling_offsets)
sampling_locations = reference_points[:, :, None, :, None, :] + \
sampling_offsets / offset_normalizer[None, None, None, :, None, :] # noqa: E501
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 # noqa: E501
else:
raise ValueError(
'Last dim of reference_points must be 2 or 4, but get {} instead.' # noqa: E501
.format(reference_points.shape[-1]))
output = MultiScaleDeformableAttnFunction.apply(
value, input_spatial_shapes, input_level_start_index,
sampling_locations, attention_weights, self.im2col_step)
output = self.output_proj(output)
if self.out_sample_loc:
return output, torch.cat(
(sampling_locations, attention_weights[:, :, :, :, :, None]),
dim=-1)
else:
return output, None
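# Shape sketch (an assumed example): ``query`` has shape
# (N, num_query, dim_model), ``input_flatten`` has shape
# (N, sum(H_l * W_l), dim_model) over the feature levels, and the returned
# attention output has shape (N, num_query, dim_model).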
# modify from https://github.com/TuSimple/centerformer/blob/master/det3d/models/utils/transformer.py # noqa
import torch
from einops import rearrange
from mmcv.cnn.bricks.activation import GELU
from torch import einsum, nn
from .multi_scale_deform_attn import MSDeformAttn
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, y=None, **kwargs):
if y is not None:
return self.fn(self.norm(x), self.norm(y), **kwargs)
else:
return self.fn(self.norm(x), **kwargs)
class FFN(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.0):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
class SelfAttention(nn.Module):
def __init__(self,
dim,
n_heads=8,
dim_single_head=64,
dropout=0.0,
out_attention=False):
super().__init__()
inner_dim = dim_single_head * n_heads
project_out = not (n_heads == 1 and dim_single_head == dim)
self.n_heads = n_heads
self.scale = dim_single_head**-0.5
self.out_attention = out_attention
self.attend = nn.Softmax(dim=-1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = (
nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
if project_out else nn.Identity())
def forward(self, x):
_, _, _, h = *x.shape, self.n_heads
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
if self.out_attention:
return self.to_out(out), attn
else:
return self.to_out(out)
class DeformableCrossAttention(nn.Module):
def __init__(
self,
dim_model=256,
dim_single_head=64,
dropout=0.3,
n_levels=3,
n_heads=6,
n_points=9,
out_sample_loc=False,
):
super().__init__()
# cross attention
self.cross_attn = MSDeformAttn(
dim_model,
dim_single_head,
n_levels,
n_heads,
n_points,
out_sample_loc=out_sample_loc)
self.dropout = nn.Dropout(dropout)
self.out_sample_loc = out_sample_loc
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward(
self,
tgt,
src,
query_pos=None,
reference_points=None,
src_spatial_shapes=None,
level_start_index=None,
src_padding_mask=None,
):
# cross attention
tgt2, sampling_locations = self.cross_attn(
self.with_pos_embed(tgt, query_pos),
reference_points,
src,
src_spatial_shapes,
level_start_index,
src_padding_mask,
)
tgt = self.dropout(tgt2)
if self.out_sample_loc:
return tgt, sampling_locations
else:
return tgt
class DeformableTransformerDecoder(nn.Module):
"""Deformable transformer decoder.
Note that the ``DeformableDetrTransformerDecoder`` in MMDet has different
interfaces in multi-head-attention which is customized here. For example,
'embed_dims' is not a position argument in our customized multi-head-self-
attention, but is required in MMDet. Thus, we can not directly use the
``DeformableDetrTransformerDecoder`` in MMDET.
"""
def __init__(
self,
dim,
n_levels=3,
depth=2,
n_heads=4,
dim_single_head=32,
dim_ffn=256,
dropout=0.0,
out_attention=False,
n_points=9,
):
super().__init__()
self.out_attention = out_attention
self.layers = nn.ModuleList([])
self.depth = depth
self.n_levels = n_levels
self.n_points = n_points
for _ in range(depth):
self.layers.append(
nn.ModuleList([
PreNorm(
dim,
SelfAttention(
dim,
n_heads=n_heads,
dim_single_head=dim_single_head,
dropout=dropout,
out_attention=self.out_attention,
),
),
PreNorm(
dim,
DeformableCrossAttention(
dim,
dim_single_head,
n_levels=n_levels,
n_heads=n_heads,
dropout=dropout,
n_points=n_points,
out_sample_loc=self.out_attention,
),
),
PreNorm(dim, FFN(dim, dim_ffn, dropout=dropout)),
]))
def forward(self, x, pos_embedding, src, src_spatial_shapes,
level_start_index, center_pos):
if self.out_attention:
out_cross_attention_list = []
if pos_embedding is not None:
center_pos_embedding = pos_embedding(center_pos)
reference_points = center_pos[:, :,
None, :].repeat(1, 1, self.n_levels, 1)
for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
if self.out_attention:
if center_pos_embedding is not None:
x_att, self_att = self_attn(x + center_pos_embedding)
x = x_att + x
x_att, cross_att = cross_attn(
x,
src,
query_pos=center_pos_embedding,
reference_points=reference_points,
src_spatial_shapes=src_spatial_shapes,
level_start_index=level_start_index,
)
else:
x_att, self_att = self_attn(x)
x = x_att + x
x_att, cross_att = cross_attn(
x,
src,
query_pos=None,
reference_points=reference_points,
src_spatial_shapes=src_spatial_shapes,
level_start_index=level_start_index,
)
out_cross_attention_list.append(cross_att)
else:
if center_pos_embedding is not None:
x_att = self_attn(x + center_pos_embedding)
x = x_att + x
x_att = cross_attn(
x,
src,
query_pos=center_pos_embedding,
reference_points=reference_points,
src_spatial_shapes=src_spatial_shapes,
level_start_index=level_start_index,
)
else:
x_att = self_attn(x)
x = x_att + x
x_att = cross_attn(
x,
src,
query_pos=None,
reference_points=reference_points,
src_spatial_shapes=src_spatial_shapes,
level_start_index=level_start_index,
)
x = x_att + x
x = ff(x) + x
out_dict = {'ct_feat': x}
if self.out_attention:
out_dict.update({
'out_attention':
torch.stack(out_cross_attention_list, dim=2)
})
return out_dict
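# Each decoder layer applies pre-normed self-attention over the center queries,
# pre-normed deformable cross-attention into the multi-scale feature maps, and
# a pre-normed FFN, each followed by a residual connection.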
_base_ = ['mmdet3d::_base_/default_runtime.py']
custom_imports = dict(
imports=['projects.CenterFormer.centerformer'], allow_failed_imports=False)
# model settings
# Voxel size for voxel encoder
# Usually voxel size is changed consistently with the point cloud range
# If point cloud range is modified, do remember to change all related
# keys in the config.
voxel_size = [0.1, 0.1, 0.15]
point_cloud_range = [-75.2, -75.2, -2, 75.2, 75.2, 4]
class_names = ['Car', 'Pedestrian', 'Cyclist']
tasks = [dict(num_class=3, class_names=['car', 'pedestrian', 'cyclist'])]
metainfo = dict(classes=class_names)
input_modality = dict(use_lidar=True, use_camera=False)
file_client_args = dict(backend='disk')
model = dict(
type='CenterFormer',
data_preprocessor=dict(
type='Det3DDataPreprocessor',
voxel=True,
voxel_type='dynamic',
voxel_layer=dict(
max_num_points=-1,
point_cloud_range=point_cloud_range,
voxel_size=voxel_size,
max_voxels=(-1, -1))),
voxel_encoder=dict(
type='DynamicSimpleVFE',
point_cloud_range=point_cloud_range,
voxel_size=voxel_size),
middle_encoder=dict(
type='SparseEncoder',
in_channels=5,
sparse_shape=[41, 1504, 1504],
order=('conv', 'norm', 'act'),
norm_cfg=dict(type='naiveSyncBN1d', eps=0.001, momentum=0.01),
encoder_channels=((16, 16, 32), (32, 32, 64), (64, 64, 128), (128,
128)),
encoder_paddings=((1, 1, 1), (1, 1, 1), (1, 1, [0, 1, 1]), (1, 1)),
block_type='basicblock'),
backbone=dict(
type='DeformableDecoderRPN',
layer_nums=[5, 5, 1],
ds_num_filters=[256, 256, 128],
num_input_features=256,
tasks=tasks,
use_gt_training=True,
corner=True,
assign_label_window_size=1,
obj_num=500,
norm_cfg=dict(type='SyncBN', eps=1e-3, momentum=0.01),
transformer_config=dict(
depth=2,
n_heads=6,
dim_single_head=64,
dim_ffn=256,
dropout=0.3,
out_attn=False,
n_points=15,
),
),
bbox_head=dict(
type='CenterFormerBboxHead',
in_channels=256,
tasks=tasks,
dataset='waymo',
weight=2,
corner_loss=True,
iou_loss=True,
assign_label_window_size=1,
norm_cfg=dict(type='SyncBN', eps=1e-3, momentum=0.01),
code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
common_heads={
'reg': (2, 2),
'height': (1, 2),
'dim': (3, 2),
'rot': (2, 2),
'iou': (1, 2)
}, # (output_channel, num_conv)
),
train_cfg=dict(
grid_size=[1504, 1504, 40],
voxel_size=voxel_size,
out_size_factor=4,
dense_reg=1,
gaussian_overlap=0.1,
point_cloud_range=point_cloud_range,
max_objs=500,
min_radius=2,
code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]),
test_cfg=dict(
post_center_limit_range=[-80, -80, -10.0, 80, 80, 10.0],
nms=dict(
use_rotate_nms=False,
use_multi_class_nms=True,
nms_pre_max_size=[1600, 1600, 800],
nms_post_max_size=[200, 200, 100],
nms_iou_threshold=[0.8, 0.55, 0.55],
),
score_threshold=0.1,
pc_range=[-75.2, -75.2],
out_size_factor=4,
voxel_size=[0.1, 0.1],
obj_num=1000,
))
data_root = 'data/waymo/kitti_format/'
db_sampler = dict(
data_root=data_root,
info_path=data_root + 'waymo_dbinfos_train.pkl',
rate=1.0,
prepare=dict(
filter_by_difficulty=[-1],
filter_by_min_points=dict(Car=5, Pedestrian=5, Cyclist=5)),
classes=class_names,
sample_groups=dict(Car=15, Pedestrian=10, Cyclist=10),
points_loader=dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=6,
use_dim=[0, 1, 2, 3, 4]))
train_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=6,
use_dim=5,
norm_intensity=True),
# Add this if using `MultiFrameDeformableDecoderRPN`
# dict(
# type='LoadPointsFromMultiSweeps',
# sweeps_num=9,
# load_dim=6,
# use_dim=[0, 1, 2, 3, 4],
# pad_empty_sweeps=True,
# remove_close=True),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler),
dict(
type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816],
scale_ratio_range=[0.95, 1.05],
translation_std=[0.5, 0.5, 0]),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectNameFilter', classes=class_names),
dict(type='PointShuffle'),
dict(
type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=6,
use_dim=5,
norm_intensity=True,
file_client_args=file_client_args),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'),
dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range)
]),
dict(type='Pack3DDetInputs', keys=['points'])
]
dataset_type = 'WaymoDataset'
train_dataloader = dict(
batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='waymo_infos_train.pkl',
data_prefix=dict(pts='training/velodyne', sweeps='training/velodyne'),
pipeline=train_pipeline,
modality=input_modality,
test_mode=False,
metainfo=metainfo,
# we use box_type_3d='LiDAR' in kitti and nuscenes dataset
# and box_type_3d='Depth' in sunrgbd and scannet dataset.
box_type_3d='LiDAR',
# load one frame every five frames
load_interval=5,
file_client_args=file_client_args))
val_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(pts='training/velodyne', sweeps='training/velodyne'),
ann_file='waymo_infos_val.pkl',
pipeline=test_pipeline,
modality=input_modality,
test_mode=True,
metainfo=metainfo,
box_type_3d='LiDAR',
file_client_args=file_client_args))
test_dataloader = val_dataloader
val_evaluator = dict(
type='WaymoMetric',
ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
waymo_bin_file='./data/waymo/waymo_format/gt.bin',
data_root='./data/waymo/waymo_format',
file_client_args=file_client_args,
convert_kitti_format=False,
idx2metainfo='./data/waymo/waymo_format/idx2metainfo.pkl')
test_evaluator = val_evaluator
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')
# For the Waymo dataset, we usually evaluate the model at the end of training.
# Since the model is trained for 20 epochs by default, we set the evaluation
# interval to 20. Please change the interval accordingly if you do not
# use the default schedule.
# optimizer
lr = 3e-4
# This schedule is mainly used by models on nuScenes dataset
# max_norm=10 is better for SECOND
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=lr, weight_decay=0.01, betas=(0.9, 0.99)),
clip_grad=dict(max_norm=35, norm_type=2))
# learning rate
param_scheduler = [
# learning rate scheduler
# During the first 8 epochs, the learning rate increases from lr to lr * 10;
# during the next 12 epochs, the learning rate decreases from lr * 10 to
# lr * 1e-4
dict(
type='CosineAnnealingLR',
T_max=8,
eta_min=lr * 10,
begin=0,
end=8,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=12,
eta_min=lr * 1e-4,
begin=8,
end=20,
by_epoch=True,
convert_to_iter_based=True),
# momentum scheduler
# During the first 8 epochs, momentum anneals from its initial value to
# 0.85 / 0.95; during the next 12 epochs, momentum increases from
# 0.85 / 0.95 to 1
dict(
type='CosineAnnealingMomentum',
T_max=8,
eta_min=0.85 / 0.95,
begin=0,
end=8,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='CosineAnnealingMomentum',
T_max=12,
eta_min=1,
begin=8,
end=20,
by_epoch=True,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=20)
val_cfg = dict()
test_cfg = dict()
# Default setting for scaling LR automatically
# - `enable` means enable scaling LR automatically
# or not by default.
# - `base_batch_size` = (4 GPUs) x (4 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=16)
default_hooks = dict(
logger=dict(type='LoggerHook', interval=50),
checkpoint=dict(type='CheckpointHook', interval=5))
custom_hooks = [dict(type='DisableObjectSampleHook', disable_after_epoch=15)]
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock
from mmdet3d.datasets.transforms import ObjectSample
from mmdet3d.engine.hooks import DisableObjectSampleHook
class TestDisableObjectSampleHook(TestCase):
runner = Mock()
runner.train_dataloader = Mock()
runner.train_dataloader.dataset = Mock()
runner.train_dataloader.dataset.pipeline = Mock()
runner.train_dataloader._DataLoader__initialized = True
runner.train_dataloader.dataset.pipeline.transforms = [
ObjectSample(
db_sampler=dict(
data_root='tests/data/waymo/kitti_format',
info_path= # noqa
'tests/data/waymo/kitti_format/waymo_dbinfos_train.pkl',
rate=1.0,
prepare=dict(
filter_by_difficulty=[-1],
filter_by_min_points=dict(Car=5)),
classes=['Car'],
sample_groups=dict(Car=15),
))
]
def test_is_model_wrapper_and_persistent_workers_on(self):
self.runner.train_dataloader.dataset.pipeline.transforms[
0].disabled = False
self.runner.train_dataloader.persistent_workers = True
hook = DisableObjectSampleHook(disable_after_epoch=15)
self.runner.epoch = 14
hook.before_train_epoch(self.runner)
self.assertFalse(self.runner.train_dataloader.dataset.pipeline.
transforms[0].disabled) # noqa: E501
self.runner.epoch = 15
hook.before_train_epoch(self.runner)
self.assertTrue(self.runner.train_dataloader.dataset.pipeline.
transforms[0].disabled) # noqa: E501
self.assertTrue(hook._restart_dataloader)
self.assertFalse(self.runner.train_dataloader._DataLoader__initialized)
self.runner.epoch = 16
hook.before_train_epoch(self.runner)
self.assertTrue(self.runner.train_dataloader._DataLoader__initialized)
self.assertTrue(self.runner.train_dataloader.dataset.pipeline.
transforms[0].disabled) # noqa: E501
def test_not_model_wrapper_and_persistent_workers_off(self):
self.runner.train_dataloader.dataset.pipeline.transforms[
0].disabled = False
self.runner.train_dataloader.persistent_workers = False
hook = DisableObjectSampleHook(disable_after_epoch=15)
self.runner.epoch = 14
hook.before_train_epoch(self.runner)
self.assertFalse(self.runner.train_dataloader.dataset.pipeline.
transforms[0].disabled) # noqa: E501
self.runner.epoch = 15
hook.before_train_epoch(self.runner)
self.assertTrue(self.runner.train_dataloader.dataset.pipeline.
transforms[0].disabled) # noqa: E501
self.assertFalse(hook._restart_dataloader)
self.assertTrue(self.runner.train_dataloader._DataLoader__initialized)
self.runner.epoch = 16
hook.before_train_epoch(self.runner)
self.assertTrue(self.runner.train_dataloader._DataLoader__initialized)
self.assertTrue(self.runner.train_dataloader.dataset.pipeline.
transforms[0].disabled) # noqa: E501