Commit ff1e5b4e authored by ZCMax's avatar ZCMax Committed by ChaimZhu
Browse files

[Model] Refactor basedetector and singestagedetector and add Det3DDataPreprocessor

parent eca5a9f2
......@@ -2,9 +2,16 @@
from .array_converter import ArrayConverter, array_converter
from .gaussian import (draw_heatmap_gaussian, ellip_gaussian2D, gaussian_2d,
gaussian_radius, get_ellip_gaussian_2D)
from .typing import (ConfigType, ForwardResults, InstanceList, MultiConfig,
OptConfigType, OptInstanceList, OptMultiConfig,
OptSampleList, OptSamplingResultList, SampleList,
SamplingResultList)
# NOTE: the list below previously contained a stray duplicated
# 'get_ellip_gaussian_2D' entry with a missing comma, which Python parses as
# implicit string concatenation ('get_ellip_gaussian_2Dget_ellip_gaussian_2D')
# — the bogus name broke `from mmdet3d.core.utils import *`. Fixed here.
__all__ = [
    'gaussian_2d', 'gaussian_radius', 'draw_heatmap_gaussian',
    'ArrayConverter', 'array_converter', 'ellip_gaussian2D',
    'get_ellip_gaussian_2D', 'ConfigType', 'OptConfigType', 'MultiConfig',
    'OptMultiConfig', 'InstanceList', 'OptInstanceList', 'SampleList',
    'OptSampleList', 'SamplingResultList', 'ForwardResults',
    'OptSamplingResultList'
]
# Copyright (c) OpenMMLab. All rights reserved.
"""Collecting some commonly used type hint in MMDetection3D."""
from typing import Dict, List, Optional, Tuple, Union
import torch
from mmengine.config import ConfigDict
from mmengine.data import InstanceData
from ..bbox.samplers import SamplingResult
from ..data_structures import Det3DDataSample
# Type hint of config data
ConfigType = Union[ConfigDict, dict]
OptConfigType = Optional[ConfigType]

# Type hint of one or more config data
MultiConfig = Union[ConfigType, List[ConfigType]]
OptMultiConfig = Optional[MultiConfig]

# Type hint of a batch of per-sample instance annotations / predictions
InstanceList = List[InstanceData]
OptInstanceList = Optional[InstanceList]

# Type hint of a batch of 3D data samples
SampleList = List[Det3DDataSample]
OptSampleList = Optional[SampleList]

# Type hint of per-sample bbox sampling results
SamplingResultList = List[SamplingResult]
OptSamplingResultList = Optional[SamplingResultList]

# Type hint of a detector forward output: a loss dict, a list of data
# samples, or raw tensor(s), depending on the forward mode
ForwardResults = Union[Dict[str, torch.Tensor], List[Det3DDataSample],
                       Tuple[torch.Tensor], torch.Tensor]
# Copyright (c) OpenMMLab. All rights reserved.
from .data_preprocessor import Det3DDataPreprocessor
__all__ = ['Det3DDataPreprocessor']
# Copyright (c) OpenMMLab. All rights reserved.
from numbers import Number
from typing import Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
from mmengine.data import BaseDataElement
from mmengine.model import stack_batch
from mmdet3d.registry import MODELS
from mmdet.models import DetDataPreprocessor
@MODELS.register_module()
class Det3DDataPreprocessor(DetDataPreprocessor):
    """Points (Image) pre-processor for point clouds / multi-modality 3D
    detection tasks.

    It provides the data pre-processing as follows:

    - Collate and move data to the target device.
    - Pad images in inputs to the maximum size of current batch with defined
      ``pad_value``. The padding size can be divisible by a defined
      ``pad_size_divisor``.
    - Stack images in inputs to batch_imgs.
    - Convert images in inputs from bgr to rgb if the shape of input is
      (3, H, W).
    - Normalize images in inputs with defined std and mean.

    Args:
        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
            Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        pad_size_divisor (int): The size of padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (Number): The padded pixel value. Defaults to 0.
        pad_mask (bool): Whether to pad instance masks. Defaults to False.
        mask_pad_value (int): The padded pixel value for instance masks.
            Defaults to 0.
        pad_seg (bool): Whether to pad semantic segmentation maps.
            Defaults to False.
        seg_pad_value (int): The padded pixel value for semantic
            segmentation maps. Defaults to 255.
        bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
            Defaults to False.
        rgb_to_bgr (bool): Whether to convert image from RGB to BGR.
            Defaults to False.
        batch_augments (List[dict], optional): Batch-level augmentation
            configs. Defaults to None.
    """

    def __init__(self,
                 mean: Sequence[Number] = None,
                 std: Sequence[Number] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 pad_mask: bool = False,
                 mask_pad_value: int = 0,
                 pad_seg: bool = False,
                 seg_pad_value: int = 255,
                 bgr_to_rgb: bool = False,
                 rgb_to_bgr: bool = False,
                 batch_augments: Optional[List[dict]] = None):
        # All image-related options are handled by the parent
        # ``DetDataPreprocessor``; this subclass only adds handling of the
        # 'points' entry in ``forward``/``collate_data``.
        super().__init__(
            mean=mean,
            std=std,
            pad_size_divisor=pad_size_divisor,
            pad_value=pad_value,
            pad_mask=pad_mask,
            mask_pad_value=mask_pad_value,
            pad_seg=pad_seg,
            seg_pad_value=seg_pad_value,
            bgr_to_rgb=bgr_to_rgb,
            rgb_to_bgr=rgb_to_bgr,
            batch_augments=batch_augments)

    def forward(self,
                data: Sequence[dict],
                training: bool = False) -> Tuple[Dict, Optional[list]]:
        """Perform normalization, padding and bgr2rgb conversion based on
        ``BaseDataPreprocessor``.

        Args:
            data (Sequence[dict]): Data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.

        Returns:
            Tuple[Dict, Optional[list]]: Data in the same format as the
            model input.
        """
        inputs_dict, batch_data_samples = self.collate_data(data)
        # Point clouds are mandatory for 3D detection; images are optional.
        if 'points' in inputs_dict[0].keys():
            points = [input['points'] for input in inputs_dict]
        else:
            raise KeyError(
                "Model input dict needs to include the 'points' key.")
        if 'img' in inputs_dict[0].keys():
            imgs = [input['img'] for input in inputs_dict]
            # channel transform (BGR <-> RGB swap configured on the parent)
            if self.channel_conversion:
                imgs = [_img[[2, 1, 0], ...] for _img in imgs]
            # Normalization.
            if self._enable_normalize:
                imgs = [(_img - self.mean) / self.std for _img in imgs]
            # Pad and stack Tensor.
            batch_imgs = stack_batch(imgs, self.pad_size_divisor,
                                     self.pad_value)
            batch_pad_shape = self._get_pad_shape(data)
            if batch_data_samples is not None:
                # NOTE the batched image size information may be useful, e.g.
                # in image-conditioned heads; record it per sample.
                batch_input_shape = tuple(batch_imgs[0].size()[-2:])
                for data_samples, pad_shape in zip(batch_data_samples,
                                                   batch_pad_shape):
                    data_samples.set_metainfo({
                        'batch_input_shape': batch_input_shape,
                        'pad_shape': pad_shape
                    })
                if self.pad_mask:
                    self.pad_gt_masks(batch_data_samples)
                if self.pad_seg:
                    self.pad_gt_sem_seg(batch_data_samples)
            if training and self.batch_augments is not None:
                for batch_aug in self.batch_augments:
                    batch_imgs, batch_data_samples = batch_aug(
                        batch_imgs, batch_data_samples)
        else:
            imgs = None
        batch_inputs_dict = {
            'points': points,
            'imgs': batch_imgs if imgs is not None else None
        }
        return batch_inputs_dict, batch_data_samples

    def collate_data(
            self, data: Sequence[dict]) -> Tuple[List[dict], Optional[list]]:
        """Collating and copying data to the target device.

        Collates the data sampled from dataloader into a list of dict and
        list of labels, and then copies tensor to the target device.

        Args:
            data (Sequence[dict]): Data sampled from dataloader.

        Returns:
            Tuple[List[Dict], Optional[list]]: Unstacked list of input
            data dict and list of labels at target device.
        """
        # rewrite `collate_data` since the inputs is a dict instead of
        # image tensor.
        inputs_dict = [{
            k: v.to(self._device)
            for k, v in _data['inputs'].items()
        } for _data in data]
        batch_data_samples: List[BaseDataElement] = []
        # Model can get predictions without any data samples.
        for _data in data:
            if 'data_sample' in _data:
                batch_data_samples.append(_data['data_sample'])
        # Move data from CPU to corresponding device.
        batch_data_samples = [
            data_sample.to(self._device) for data_sample in batch_data_samples
        ]
        # Normalize "no samples at all" to None so callers can test for it.
        if not batch_data_samples:
            batch_data_samples = None  # type: ignore
        return inputs_dict, batch_data_samples

    def _get_pad_shape(self, data: Sequence[dict]) -> List[tuple]:
        """Get the pad_shape of each image based on data and
        pad_size_divisor."""
        # rewrite `_get_pad_shape` for obaining image inputs.
        ori_inputs = [_data['inputs']['img'] for _data in data]
        batch_pad_shape = []
        for ori_input in ori_inputs:
            # Round H and W up to the next multiple of pad_size_divisor.
            pad_h = int(np.ceil(ori_input.shape[1] /
                                self.pad_size_divisor)) * self.pad_size_divisor
            pad_w = int(np.ceil(ori_input.shape[2] /
                                self.pad_size_divisor)) * self.pad_size_divisor
            batch_pad_shape.append((pad_h, pad_w))
        return batch_pad_shape
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
import torch
from mmengine.data import InstanceData
from torch.optim import Optimizer
from mmdet3d.core import Det3DDataSample
from mmdet3d.core.utils import (ForwardResults, InstanceList, OptConfigType,
OptMultiConfig, OptSampleList, SampleList)
from mmdet3d.registry import MODELS
from mmdet.core.utils import stack_batch
from mmdet.models.detectors import BaseDetector
from mmdet.models import BaseDetector
@MODELS.register_module()
......@@ -16,191 +11,89 @@ class Base3DDetector(BaseDetector):
"""Base class for 3D detectors.
Args:
preprocess_cfg (dict, optional): Model preprocessing config
for processing the input data. it usually includes
``to_rgb``, ``pad_size_divisor``, ``pad_value``,
``mean`` and ``std``. Default to None.
init_cfg (dict, optional): the config to control the
initialization. Default to None.
data_preprocessor (dict or ConfigDict, optional): The pre-process
config of :class:`BaseDataPreprocessor`. it usually includes,
``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
init_cfg (dict or ConfigDict, optional): the config to control the
initialization. Defaults to None.
"""
def __init__(self,
preprocess_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None) -> None:
super(Base3DDetector, self).__init__(
preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
def forward_simple_test(self, batch_inputs_dict: Dict[List, torch.Tensor],
batch_data_samples: List[Det3DDataSample],
**kwargs) -> List[Det3DDataSample]:
"""
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'img' keys.
data_processor: OptConfigType = None,
init_cfg: OptMultiConfig = None) -> None:
super().__init__(data_preprocessor=data_processor, init_cfg=init_cfg)
- points (list[torch.Tensor]): Point cloud of each sample.
- imgs (torch.Tensor, optional): Image of each sample.
def forward(self,
batch_inputs_dict: dict,
batch_data_samples: OptSampleList = None,
mode: str = 'tensor',
**kwargs) -> ForwardResults:
"""The unified entry for a forward process in both training and test.
batch_data_samples (List[:obj:`DetDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
The method should accept three modes: "tensor", "predict" and "loss":
Returns:
list(obj:`Det3DDataSample`): Detection results of the
input images. Each DetDataSample usually contains
``pred_instances_3d`` or ``pred_panoptic_seg_3d`` or
``pred_sem_seg_3d``.
"""
batch_size = len(batch_data_samples)
batch_input_metas = []
if batch_size != len(batch_inputs_dict['points']):
raise ValueError(
'num of augmentations ({}) != num of image meta ({})'.format(
len(batch_inputs_dict['points']), len(batch_input_metas)))
for batch_index in range(batch_size):
metainfo = batch_data_samples[batch_index].metainfo
batch_input_metas.append(metainfo)
for var, name in [(batch_inputs_dict['points'], 'points'),
(batch_input_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError('{} must be a list, but got {}'.format(
name, type(var)))
if batch_size == 1:
return self.simple_test(
batch_inputs_dict, batch_input_metas, rescale=True, **kwargs)
else:
return self.aug_test(
batch_inputs_dict, batch_input_metas, rescale=True, **kwargs)
- "tensor": Forward the whole network and return tensor or tuple of
tensor without any post-processing, same as a common nn.Module.
- "predict": Forward and return the predictions, which are fully
processed to a list of :obj:`DetDataSample`.
- "loss": Forward and return a dict of losses according to the given
inputs and data samples.
def forward(self,
data: List[dict],
optimizer: Optional[Union[Optimizer, dict]] = None,
return_loss: bool = False,
**kwargs):
"""The iteration step during training and testing. This method defines
an iteration step during training and testing, except for the back
propagation and optimizer updating during training, which are done in
an optimizer scheduler.
Note that this method doesn't handle neither back propagation nor
optimizer updating, which are done in the :meth:`train_step`.
Args:
data (list[dict]): The output of dataloader.
optimizer (:obj:`torch.optim.Optimizer`, dict, Optional): The
optimizer of runner. This argument is unused and reserved.
Default to None.
return_loss (bool): Whether to return loss. In general,
it will be set to True during training and False
during testing. Default to False.
batch_inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
batch_data_samples (list[:obj:`DetDataSample`], optional): The
annotation data of every samples. Defaults to None.
mode (str): Return what kind of value. Defaults to 'tensor'.
Returns:
during training
dict: It should contain at least 3 keys: ``loss``,
``log_vars``, ``num_samples``.
- ``loss`` is a tensor for back propagation, which can be a
weighted sum of multiple losses.
- ``log_vars`` contains all the variables to be sent to the
logger.
- ``num_samples`` indicates the batch size (when the model
is DDP, it means the batch size on each GPU), which is
used for averaging the logs.
during testing
list(obj:`Det3DDataSample`): Detection results of the
input samples. Each DetDataSample usually contains
``pred_instances_3d`` or ``pred_panoptic_seg_3d`` or
``pred_sem_seg_3d``.
"""
The return type depends on ``mode``.
batch_inputs_dict, batch_data_samples = self.preprocess_data(data)
if return_loss:
losses = self.forward_train(batch_inputs_dict, batch_data_samples,
**kwargs)
loss, log_vars = self._parse_losses(losses)
outputs = dict(
loss=loss,
log_vars=log_vars,
num_samples=len(batch_data_samples))
return outputs
- If ``mode="tensor"``, return a tensor or a tuple of tensor.
- If ``mode="predict"``, return a list of :obj:`DetDataSample`.
- If ``mode="loss"``, return a dict of tensor.
"""
if mode == 'loss':
return self.loss(batch_inputs_dict, batch_data_samples, **kwargs)
elif mode == 'predict':
return self.predict(batch_inputs_dict, batch_data_samples,
**kwargs)
elif mode == 'tensor':
return self._forward(batch_inputs_dict, batch_data_samples,
**kwargs)
else:
return self.forward_simple_test(batch_inputs_dict,
batch_data_samples, **kwargs)
def preprocess_data(self, data: List[dict]) -> tuple:
""" Process input data during training and simple testing phases.
Args:
data (list[dict]): The data to be processed, which
comes from dataloader.
Returns:
tuple: It should contain 2 item.
raise RuntimeError(f'Invalid mode "{mode}". '
'Only supports loss, predict and tensor mode')
- batch_inputs_dict (dict): The model input dict which include
'points', 'img' keys.
def convert_to_datasample(self, results_list: InstanceList) -> SampleList:
"""Convert results list to `Det3DDataSample`.
- points (list[torch.Tensor]): Point cloud of each sample.
- imgs (torch.Tensor, optional): Image of each sample.
Subclasses could override it to be compatible for some multi-modality
3D detectors.
- batch_data_samples (list[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d` , `gt_instances`.
"""
batch_data_samples = [
data_['data_sample'].to(self.device) for data_ in data
]
if 'points' in data[0]['inputs'].keys():
points = [
data_['inputs']['points'].to(self.device) for data_ in data
]
else:
raise KeyError(
"Model input dict needs to include the 'points' key.")
if 'img' in data[0]['inputs'].keys():
imgs = [data_['inputs']['img'].to(self.device) for data_ in data]
else:
imgs = None
if self.preprocess_cfg is None:
batch_inputs_dict = {
'points': points,
'imgs': stack_batch(imgs).float() if imgs is not None else None
}
return batch_inputs_dict, batch_data_samples
if self.to_rgb and imgs[0].size(0) == 3:
imgs = [_img[[2, 1, 0], ...] for _img in imgs]
imgs = [(_img - self.pixel_mean) / self.pixel_std for _img in imgs]
batch_img = stack_batch(imgs, self.pad_size_divisor, self.pad_value)
batch_inputs_dict = {'points': points, 'imgs': batch_img}
return batch_inputs_dict, batch_data_samples
def postprocess_result(self, results_list: List[InstanceData]) \
-> List[Det3DDataSample]:
""" Convert results list to `Det3DDataSample`.
Args:
results_list (list[:obj:`InstanceData`]): Detection results of
each sample.
Returns:
list[:obj:`Det3DDataSample`]: Detection results of the
input sample. Each Det3DDataSample usually contain
'pred_instances_3d'. And the ``pred_instances_3dd`` usually
input. Each Det3DDataSample usually contains
'pred_instances_3d'. And the ``pred_instances_3d`` usually
contains following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instances, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instance, )
- labels_3d (Tensor): Labels of 3D bboxes, has a shape
(num_instances, ).
- bboxes_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes,
contains a tensor with shape (num_instances, 7).
"""
- bboxes_3d (Tensor): Contains a tensor with shape
(num_instances, C) where C >=7.
"""
out_results_list = []
for i in range(len(results_list)):
result = Det3DDataSample()
result.pred_instances_3d = results_list[i]
results_list[i] = result
return results_list
def show_results(self, data, result, out_dir, show=False, score_thr=None):
# TODO
pass
out_results_list.append(result)
return out_results_list
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
from typing import List, Tuple, Union
import torch
from mmdet3d.core.utils import (ConfigType, OptConfigType, OptMultiConfig,
OptSampleList, SampleList)
from mmdet3d.registry import MODELS
from .base import Base3DDetector
......@@ -11,7 +13,10 @@ from .base import Base3DDetector
class SingleStage3DDetector(Base3DDetector):
"""SingleStage3DDetector.
This class serves as a base class for single-stage 3D detectors.
This class serves as a base class for single-stage 3D detectors which
directly and densely predict 3D bounding boxes on the output features
of the backbone+neck.
Args:
backbone (dict): Config dict of detector's backbone.
......@@ -21,21 +26,22 @@ class SingleStage3DDetector(Base3DDetector):
Defaults to None.
test_cfg (dict, optional): Config dict of test hyper-parameters.
Defaults to None.
pretrained (str, optional): Path of pretrained models.
Defaults to None.
data_preprocessor (dict or ConfigDict, optional): The pre-process
config of :class:`BaseDataPreprocessor`. it usually includes,
``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
init_cfg (dict or ConfigDict, optional): the config to control the
initialization. Defaults to None.
"""
def __init__(self,
backbone,
neck: Optional[dict] = None,
bbox_head: Optional[dict] = None,
train_cfg: Optional[dict] = None,
test_cfg: Optional[dict] = None,
preprocess_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None,
pretrained: Optional[str] = None) -> None:
super(SingleStage3DDetector, self).__init__(
preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
backbone: ConfigType,
neck: OptConfigType = None,
bbox_head: OptConfigType = None,
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
data_preprocessor: OptConfigType = None,
init_cfg: OptMultiConfig = None) -> None:
super().__init__(data_processor=data_preprocessor, init_cfg=init_cfg)
self.backbone = MODELS.build(backbone)
if neck is not None:
self.neck = MODELS.build(neck)
......@@ -45,33 +51,99 @@ class SingleStage3DDetector(Base3DDetector):
self.train_cfg = train_cfg
self.test_cfg = test_cfg
def forward_dummy(self, batch_inputs: dict) -> tuple:
"""Used for computing network flops.
def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
**kwargs) -> Union[dict, list]:
"""Calculate losses from a batch of inputs dict and data samples.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'img' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- imgs (torch.Tensor, optional): Image of each sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
Returns:
dict: A dictionary of loss components.
"""
x = self.extract_feat(batch_inputs_dict)
losses = self.bbox_head.loss(x, batch_data_samples, **kwargs)
return losses
def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
**kwargs) -> SampleList:
"""Predict results from a batch of inputs and data samples with post-
processing.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'img' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- imgs (torch.Tensor, optional): Image of each sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
rescale (bool): Whether to rescale the results.
Defaults to True.
Returns:
list[:obj:`Det3DDataSample`]: Detection results of the
input images. Each Det3DDataSample usually contain
'pred_instances_3d'. And the ``pred_instances_3d`` usually
contains following keys.
See `mmdetection/tools/analysis_tools/get_flops.py`
- scores_3d (Tensor): Classification scores, has a shape
(num_instance, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes_3d (Tensor): Contains a tensor with shape
(num_instances, C) where C >=7.
"""
x = self.extract_feat(batch_inputs['points'])
try:
sample_mod = self.train_cfg.sample_mod
outs = self.bbox_head(x, sample_mod)
except AttributeError:
outs = self.bbox_head(x)
return outs
def extract_feat(self, points: List[torch.Tensor]) -> list:
x = self.extract_feat(batch_inputs_dict)
results_list = self.bbox_head.predict(x, batch_data_samples, **kwargs)
predictions = self.convert_to_datasample(results_list)
return predictions
def _forward(self,
             batch_inputs_dict: dict,
             data_samples: OptSampleList = None,
             **kwargs) -> Tuple[List[torch.Tensor]]:
    """Network forward process: backbone, neck and head forward without
    any post-processing.

    Args:
        batch_inputs_dict (dict): The model input dict which includes
            'points', 'img' keys.

            - points (list[torch.Tensor]): Point cloud of each sample.
            - imgs (torch.Tensor, optional): Image of each sample.
        data_samples (List[:obj:`Det3DDataSample`], optional): The data
            samples, e.g. `gt_instance_3d`, `gt_panoptic_seg_3d` and
            `gt_sem_seg_3d`. Defaults to None.

    Returns:
        tuple[list]: A tuple of features from ``bbox_head`` forward.
    """
    feats = self.extract_feat(batch_inputs_dict)
    return self.bbox_head.forward(feats)
def extract_feat(self,
batch_inputs_dict: torch.Tensor) -> Tuple[torch.Tensor]:
"""Directly extract features from the backbone+neck.
Args:
points (List[torch.Tensor]): Input points.
points (torch.Tensor): Input points.
"""
x = self.backbone(points[0])
points = batch_inputs_dict['points']
stack_points = torch.stack(points)
x = self.backbone(stack_points)
if self.with_neck:
x = self.neck(x)
return x
def extract_feats(self, batch_inputs_dict: dict) -> list:
    """Extract features of multiple samples, one sample at a time.

    Args:
        batch_inputs_dict (dict): The model input dict which includes the
            'points' key.

    Returns:
        list: Per-sample feature outputs of :meth:`extract_feat`.
    """
    # NOTE(review): each point tensor is wrapped in a one-element list
    # before the call, which matches the old list-based ``extract_feat``
    # signature rather than the dict-based one — confirm against callers.
    per_sample_feats = []
    for points in batch_inputs_dict['points']:
        per_sample_feats.append(self.extract_feat([points]))
    return per_sample_feats
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmdet3d.core import Det3DDataSample
from mmdet3d.models.data_preprocessors import Det3DDataPreprocessor
class TestDet3DDataPreprocessor(TestCase):
    """Unit tests for :class:`Det3DDataPreprocessor`."""

    def test_init(self):
        """Constructor argument validation and normalization setup."""
        # test mean is None
        processor = Det3DDataPreprocessor()
        self.assertTrue(not hasattr(processor, 'mean'))
        self.assertTrue(processor._enable_normalize is False)

        # test mean is not None
        processor = Det3DDataPreprocessor(mean=[0, 0, 0], std=[1, 1, 1])
        self.assertTrue(hasattr(processor, 'mean'))
        self.assertTrue(hasattr(processor, 'std'))
        self.assertTrue(processor._enable_normalize)

        # please specify both mean and std
        with self.assertRaises(AssertionError):
            Det3DDataPreprocessor(mean=[0, 0, 0])

        # bgr2rgb and rgb2bgr cannot be set to True at the same time
        with self.assertRaises(AssertionError):
            Det3DDataPreprocessor(bgr_to_rgb=True, rgb_to_bgr=True)

    def test_forward(self):
        """``forward`` collates points/images, converts channels and pads."""
        processor = Det3DDataPreprocessor(mean=[0, 0, 0], std=[1, 1, 1])

        points = torch.randn((5000, 3))
        image = torch.randint(0, 256, (3, 11, 10))
        inputs_dict = dict(points=points, img=image)
        data = [{'inputs': inputs_dict, 'data_sample': Det3DDataSample()}]

        # single sample: batched along a new leading dim, no padding needed
        inputs, data_samples = processor(data)
        self.assertEqual(inputs['imgs'].shape, (1, 3, 11, 10))
        self.assertEqual(len(inputs['points']), 1)
        self.assertEqual(len(data_samples), 1)

        # test image channel_conversion
        processor = Det3DDataPreprocessor(
            mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
        inputs, data_samples = processor(data)
        self.assertEqual(inputs['imgs'].shape, (1, 3, 11, 10))
        self.assertEqual(len(data_samples), 1)

        # test image padding: batch is padded to the per-dim maximum
        data = [{
            'inputs': {
                'points': torch.randn((5000, 3)),
                'img': torch.randint(0, 256, (3, 10, 11))
            }
        }, {
            'inputs': {
                'points': torch.randn((5000, 3)),
                'img': torch.randint(0, 256, (3, 9, 14))
            }
        }]
        processor = Det3DDataPreprocessor(
            mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
        inputs, data_samples = processor(data)
        self.assertEqual(inputs['imgs'].shape, (2, 3, 10, 14))
        # no 'data_sample' entries in the batch -> samples collapse to None
        self.assertIsNone(data_samples)

        # test pad_size_divisor: H/W rounded up to a multiple of 5
        data = [{
            'inputs': {
                'points': torch.randn((5000, 3)),
                'img': torch.randint(0, 256, (3, 10, 11))
            },
            'data_sample': Det3DDataSample()
        }, {
            'inputs': {
                'points': torch.randn((5000, 3)),
                'img': torch.randint(0, 256, (3, 9, 24))
            },
            'data_sample': Det3DDataSample()
        }]
        processor = Det3DDataPreprocessor(
            mean=[0., 0., 0.], std=[1., 1., 1.], pad_size_divisor=5)
        inputs, data_samples = processor(data)
        self.assertEqual(inputs['imgs'].shape, (2, 3, 10, 25))
        self.assertEqual(len(data_samples), 2)
        # each sample records its own pre-batch padded shape
        for data_sample, expected_shape in zip(data_samples, [(10, 15),
                                                              (10, 25)]):
            self.assertEqual(data_sample.pad_shape, expected_shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment