Unverified Commit bc849cc9 authored by Xiang Xu's avatar Xiang Xu Committed by GitHub
Browse files

[Feature] Support TTA for Segmentor (#2382)

* support TTA for segmentor

* fix UT

* fix UT

* support scale and rotate tta

* update with comments
parent e1155c0e
......@@ -73,25 +73,6 @@ test_pipeline = [
with_seg_3d=True,
backend_args=backend_args),
dict(type='NormalizePointsColor', color_mean=None),
dict(
# a wrapper in order to successfully call test function
# actually we don't perform test-time-aug
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.0,
flip_ratio_bev_vertical=0.0),
]),
dict(type='Pack3DDetInputs', keys=['points'])
]
# construct a pipeline for data and gt loading in show function
......@@ -109,6 +90,33 @@ eval_pipeline = [
dict(type='NormalizePointsColor', color_mean=None),
dict(type='Pack3DDetInputs', keys=['points'])
]
tta_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5],
backend_args=backend_args),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True,
backend_args=backend_args),
dict(type='NormalizePointsColor', color_mean=None),
dict(
type='TestTimeAug',
transforms=[[
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.,
flip_ratio_bev_vertical=0.)
], [dict(type='Pack3DDetInputs', keys=['points'])]])
]
# train on area 1, 2, 3, 4, 6
# test on area 5
......@@ -157,3 +165,5 @@ test_evaluator = val_evaluator
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')
tta_model = dict(type='Seg3DTTAModel')
......@@ -73,25 +73,6 @@ test_pipeline = [
with_seg_3d=True,
backend_args=backend_args),
dict(type='NormalizePointsColor', color_mean=None),
dict(
# a wrapper in order to successfully call test function
# actually we don't perform test-time-aug
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.0,
flip_ratio_bev_vertical=0.0),
]),
dict(type='Pack3DDetInputs', keys=['points'])
]
# construct a pipeline for data and gt loading in show function
......@@ -109,6 +90,33 @@ eval_pipeline = [
dict(type='NormalizePointsColor', color_mean=None),
dict(type='Pack3DDetInputs', keys=['points'])
]
tta_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5],
backend_args=backend_args),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True,
backend_args=backend_args),
dict(type='NormalizePointsColor', color_mean=None),
dict(
type='TestTimeAug',
transforms=[[
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.,
flip_ratio_bev_vertical=0.)
], [dict(type='Pack3DDetInputs', keys=['points'])]])
]
train_dataloader = dict(
batch_size=8,
......@@ -152,3 +160,5 @@ test_evaluator = val_evaluator
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')
tta_model = dict(type='Seg3DTTAModel')
......@@ -82,7 +82,7 @@ train_pipeline = [
seg_offset=2**16,
dataset_type='semantickitti',
backend_args=backend_args),
dict(type='PointSegClassMapping', ),
dict(type='PointSegClassMapping'),
dict(
type='RandomFlip3D',
sync_2d=False,
......@@ -112,12 +112,21 @@ test_pipeline = [
seg_offset=2**16,
dataset_type='semantickitti',
backend_args=backend_args),
dict(type='PointSegClassMapping', ),
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
dict(type='PointSegClassMapping'),
dict(type='Pack3DDetInputs', keys=['points'])
]
# construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client)
eval_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=4,
use_dim=4,
backend_args=backend_args),
dict(type='Pack3DDetInputs', keys=['points'])
]
tta_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
......@@ -133,46 +142,75 @@ eval_pipeline = [
seg_offset=2**16,
dataset_type='semantickitti',
backend_args=backend_args),
dict(type='PointSegClassMapping', ),
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
dict(type='PointSegClassMapping'),
dict(
type='TestTimeAug',
transforms=[[
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.,
flip_ratio_bev_vertical=0.),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.,
flip_ratio_bev_vertical=1.),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=1.,
flip_ratio_bev_vertical=0.),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=1.,
flip_ratio_bev_vertical=1.)
],
[
dict(
type='GlobalRotScaleTrans',
rot_range=[pcd_rotate_range, pcd_rotate_range],
scale_ratio_range=[
pcd_scale_factor, pcd_scale_factor
],
translation_std=[0, 0, 0])
for pcd_rotate_range in [-0.78539816, 0.0, 0.78539816]
for pcd_scale_factor in [0.95, 1.0, 1.05]
], [dict(type='Pack3DDetInputs', keys=['points'])]])
]
train_dataloader = dict(
batch_size=2,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='RepeatDataset',
times=1,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='semantickitti_infos_train.pkl',
pipeline=train_pipeline,
metainfo=metainfo,
modality=input_modality,
ignore_index=19,
backend_args=backend_args)),
)
type=dataset_type,
data_root=data_root,
ann_file='semantickitti_infos_train.pkl',
pipeline=train_pipeline,
metainfo=metainfo,
modality=input_modality,
ignore_index=19,
backend_args=backend_args))
test_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type='RepeatDataset',
times=1,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='semantickitti_infos_val.pkl',
pipeline=test_pipeline,
metainfo=metainfo,
modality=input_modality,
ignore_index=19,
test_mode=True,
backend_args=backend_args)),
)
type=dataset_type,
data_root=data_root,
ann_file='semantickitti_infos_val.pkl',
pipeline=test_pipeline,
metainfo=metainfo,
modality=input_modality,
ignore_index=19,
test_mode=True,
backend_args=backend_args))
val_dataloader = test_dataloader
......@@ -182,3 +220,5 @@ test_evaluator = val_evaluator
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')
tta_model = dict(type='Seg3DTTAModel')
......@@ -24,7 +24,7 @@ train_pipeline = [
]
train_dataloader = dict(
sampler=dict(seed=0), dataset=dict(dataset=dict(pipeline=train_pipeline)))
sampler=dict(seed=0), dataset=dict(pipeline=train_pipeline))
lr = 0.24
optim_wrapper = dict(
......
......@@ -24,7 +24,7 @@ train_pipeline = [
]
train_dataloader = dict(
sampler=dict(seed=0), dataset=dict(dataset=dict(pipeline=train_pipeline)))
sampler=dict(seed=0), dataset=dict(pipeline=train_pipeline))
lr = 0.24
optim_wrapper = dict(
......
# Copyright (c) OpenMMLab. All rights reserved.
from .dbsampler import DataBaseSampler
from .formating import Pack3DDetInputs
from .loading import (LoadAnnotations3D, LoadImageFromFileMono3D,
LoadMultiViewImageFromFiles, LoadPointsFromDict,
LoadPointsFromFile, LoadPointsFromMultiSweeps,
MonoDet3DInferencerLoader,
from .loading import (LidarDet3DInferencerLoader, LoadAnnotations3D,
LoadImageFromFileMono3D, LoadMultiViewImageFromFiles,
LoadPointsFromDict, LoadPointsFromFile,
LoadPointsFromMultiSweeps, MonoDet3DInferencerLoader,
MultiModalityDet3DInferencerLoader, NormalizePointsColor,
PointSegClassMapping)
from .test_time_aug import MultiScaleFlipAug3D
......
......@@ -3,5 +3,9 @@ from .base import Base3DSegmentor
from .cylinder3d import Cylinder3D
from .encoder_decoder import EncoderDecoder3D
from .minkunet import MinkUNet
from .seg3d_tta import Seg3DTTAModel
__all__ = ['Base3DSegmentor', 'EncoderDecoder3D', 'Cylinder3D', 'MinkUNet']
__all__ = [
'Base3DSegmentor', 'EncoderDecoder3D', 'Cylinder3D', 'MinkUNet',
'Seg3DTTAModel'
]
......@@ -132,17 +132,12 @@ class Base3DSegmentor(BaseModel, metaclass=ABCMeta):
"""
pass
@abstractmethod
def aug_test(self, batch_inputs, batch_data_samples):
"""Placeholder for augmentation test."""
pass
def postprocess_result(self, seg_pred_list: List[dict],
def postprocess_result(self, seg_logits_list: List[Tensor],
batch_data_samples: SampleList) -> SampleList:
"""Convert results list to `Det3DDataSample`.
Args:
seg_logits_list (List[dict]): List of segmentation results,
seg_logits_list (List[Tensor]): List of segmentation results,
seg_logits from model of each input point clouds sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data
samples. It usually includes information such as `metainfo` and
......@@ -152,12 +147,19 @@ class Base3DSegmentor(BaseModel, metaclass=ABCMeta):
List[:obj:`Det3DDataSample`]: Segmentation results of the input
points. Each Det3DDataSample usually contains:
- ``pred_pts_seg`` (PixelData): Prediction of 3D semantic
- ``pred_pts_seg`` (PointData): Prediction of 3D semantic
segmentation.
- ``pts_seg_logits`` (PointData): Predicted logits of 3D semantic
segmentation before normalization.
"""
for i in range(len(seg_pred_list)):
seg_pred = seg_pred_list[i]
batch_data_samples[i].set_data(
{'pred_pts_seg': PointData(**{'pts_semantic_mask': seg_pred})})
for i in range(len(seg_logits_list)):
seg_logits = seg_logits_list[i]
seg_pred = seg_logits.argmax(dim=0)
batch_data_samples[i].set_data({
'pts_seg_logits':
PointData(**{'pts_seg_logits': seg_logits}),
'pred_pts_seg':
PointData(**{'pts_semantic_mask': seg_pred})
})
return batch_data_samples
......@@ -127,16 +127,18 @@ class Cylinder3D(EncoderDecoder3D):
List[:obj:`Det3DDataSample`]: Segmentation results of the input
points. Each Det3DDataSample usually contains:
- ``pred_pts_seg`` (PixelData): Prediction of 3D semantic
- ``pred_pts_seg`` (PointData): Prediction of 3D semantic
segmentation.
- ``pts_seg_logits`` (PointData): Predicted logits of 3D semantic
segmentation before normalization.
"""
# 3D segmentation requires per-point prediction, so it's impossible
# to use down-sampling to get a batch of scenes with same num_points
# therefore, we only support testing one scene every time
x = self.extract_feat(batch_inputs_dict)
seg_pred_list = self.decode_head.predict(x, batch_inputs_dict,
batch_data_samples)
for i in range(len(seg_pred_list)):
seg_pred_list[i] = seg_pred_list[i].argmax(1).cpu()
seg_logits_list = self.decode_head.predict(x, batch_inputs_dict,
batch_data_samples)
for i in range(len(seg_logits_list)):
seg_logits_list[i] = seg_logits_list[i].transpose(0, 1)
return self.postprocess_result(seg_pred_list, batch_data_samples)
return self.postprocess_result(seg_logits_list, batch_data_samples)
......@@ -5,7 +5,6 @@ import numpy as np
import torch
from torch import Tensor
from torch import nn as nn
from torch.nn import functional as F
from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType, OptConfigType, OptMultiConfig
......@@ -477,8 +476,7 @@ class EncoderDecoder3D(Base3DSegmentor):
else:
seg_logit = self.whole_inference(points, batch_input_metas,
rescale)
output = F.softmax(seg_logit, dim=1)
return output
return seg_logit
def predict(self,
batch_inputs_dict: dict,
......@@ -503,27 +501,26 @@ class EncoderDecoder3D(Base3DSegmentor):
List[:obj:`Det3DDataSample`]: Segmentation results of the input
points. Each Det3DDataSample usually contains:
- ``pred_pts_seg`` (PixelData): Prediction of 3D semantic
- ``pred_pts_seg`` (PointData): Prediction of 3D semantic
segmentation.
- ``pts_seg_logits`` (PointData): Predicted logits of 3D semantic
segmentation before normalization.
"""
# 3D segmentation requires per-point prediction, so it's impossible
# to use down-sampling to get a batch of scenes with same num_points
# therefore, we only support testing one scene every time
seg_pred_list = []
seg_logits_list = []
batch_input_metas = []
for data_sample in batch_data_samples:
batch_input_metas.append(data_sample.metainfo)
points = batch_inputs_dict['points']
for point, input_meta in zip(points, batch_input_metas):
seg_prob = self.inference(
seg_logits = self.inference(
point.unsqueeze(0), [input_meta], rescale)[0]
seg_map = seg_prob.argmax(0) # [N]
# to cpu tensor for consistency with det3d
seg_map = seg_map.cpu()
seg_pred_list.append(seg_map)
seg_logits_list.append(seg_logits)
return self.postprocess_result(seg_pred_list, batch_data_samples)
return self.postprocess_result(seg_logits_list, batch_data_samples)
def _forward(self,
batch_inputs_dict: dict,
......@@ -546,7 +543,3 @@ class EncoderDecoder3D(Base3DSegmentor):
points = torch.stack(batch_inputs_dict['points'])
x = self.extract_feat(points)
return self.decode_head.forward(x)
def aug_test(self, batch_inputs, batch_img_metas):
"""Placeholder for augmentation test."""
pass
......@@ -50,7 +50,8 @@ class MinkUNet(EncoderDecoder3D):
losses = self.decode_head.loss(x, data_samples, self.train_cfg)
return losses
def predict(self, inputs: dict, data_samples: SampleList) -> SampleList:
def predict(self, inputs: dict,
batch_data_samples: SampleList) -> SampleList:
"""Simple test with single scene.
Args:
......@@ -67,14 +68,17 @@ class MinkUNet(EncoderDecoder3D):
List[:obj:`Det3DDataSample`]: Segmentation results of the input
points. Each Det3DDataSample usually contains:
- ``pred_pts_seg`` (PixelData): Prediction of 3D semantic
- ``pred_pts_seg`` (PointData): Prediction of 3D semantic
segmentation.
- ``pts_seg_logits`` (PointData): Predicted logits of 3D semantic
segmentation before normalization.
"""
x = self.extract_feat(inputs)
seg_logits = self.decode_head.predict(x, data_samples)
seg_preds = [seg_logit.argmax(dim=1) for seg_logit in seg_logits]
seg_logits_list = self.decode_head.predict(x, batch_data_samples)
for i in range(len(seg_logits_list)):
seg_logits_list[i] = seg_logits_list[i].transpose(0, 1)
return self.postprocess_result(seg_preds, data_samples)
return self.postprocess_result(seg_logits_list, batch_data_samples)
def _forward(self,
batch_inputs_dict: dict,
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
from mmengine.model import BaseTTAModel
from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
@MODELS.register_module()
class Seg3DTTAModel(BaseTTAModel):
    """Test-time augmentation wrapper for 3D semantic segmentors.

    Merges the per-point logits produced for every augmented view of a
    scene by averaging their softmax probabilities and taking the argmax
    as the final per-point label.
    """

    def merge_preds(self, data_samples_list: List[SampleList]) -> SampleList:
        """Merge predictions of enhanced data to one prediction.

        Args:
            data_samples_list (List[List[:obj:`Det3DDataSample`]]): List of
                predictions of all enhanced data. The outer list iterates
                over the original scenes; the inner list holds one
                prediction per augmented view of that scene.

        Returns:
            List[:obj:`Det3DDataSample`]: Merged prediction, one sample per
            original scene.
        """
        predictions = []
        for data_samples in data_samples_list:
            seg_logits = data_samples[0].pts_seg_logits.pts_seg_logits
            # Accumulator with the same shape/dtype/device as the logits.
            logits = torch.zeros_like(seg_logits)
            for data_sample in data_samples:
                seg_logit = data_sample.pts_seg_logits.pts_seg_logits
                # Average class probabilities (not raw logits) across views.
                logits += seg_logit.softmax(dim=0)
            # Dividing does not change the argmax; it keeps the merged
            # scores in a probability-like range.
            logits /= len(data_samples)
            seg_pred = logits.argmax(dim=0)
            # Reuse the first view's data sample as the carrier of the
            # merged prediction.
            data_samples[0].pred_pts_seg.pts_semantic_mask = seg_pred
            predictions.append(data_samples[0])
        return predictions
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
import pytest
from mmengine import DefaultScope
from mmdet3d.datasets.transforms import * # noqa
from mmdet3d.registry import TRANSFORMS
from mmdet3d.structures.points import LiDARPoints
DefaultScope.get_instance('test_multi_scale_flip_aug_3d', scope_name='mmdet3d')
class TestMuitiScaleFlipAug3D(TestCase):
    """Tests for building and applying ``TestTimeAug`` pipelines on 3D
    point clouds (deterministic flip / rotation / scaling TTA).

    NOTE(review): the class name carries a typo (``Muiti`` -> ``Multi``);
    kept as-is so the test's discovered name does not change.
    """

    def test_exception(self):
        # ``TestTimeAug`` expects ``transforms`` to be a list of lists
        # (one sub-list per augmentation stage); a flat list of dicts
        # must raise ``TypeError`` at build time.
        with pytest.raises(TypeError):
            tta_transform = dict(
                type='TestTimeAug',
                transforms=[
                    dict(
                        type='RandomFlip3D',
                        flip_ratio_bev_horizontal=0.0,
                        flip_ratio_bev_vertical=0.0)
                ])
            TRANSFORMS.build(tta_transform)

    def test_multi_scale_flip_aug(self):
        # Case 1: four flip variants. Flip ratios of 0.0/1.0 make each
        # ``RandomFlip3D`` deterministic, so the resulting metainfo order
        # can be asserted exactly.
        tta_transform = dict(
            type='TestTimeAug',
            transforms=[[
                dict(
                    type='RandomFlip3D',
                    flip_ratio_bev_horizontal=0.0,
                    flip_ratio_bev_vertical=0.0),
                dict(
                    type='RandomFlip3D',
                    flip_ratio_bev_horizontal=0.0,
                    flip_ratio_bev_vertical=1.0),
                dict(
                    type='RandomFlip3D',
                    flip_ratio_bev_horizontal=1.0,
                    flip_ratio_bev_vertical=0.0),
                dict(
                    type='RandomFlip3D',
                    flip_ratio_bev_horizontal=1.0,
                    flip_ratio_bev_vertical=1.0)
            ], [dict(type='Pack3DDetInputs', keys=['points'])]])
        tta_module = TRANSFORMS.build(tta_transform)
        results = dict()
        # 100 random LiDAR points with 4 dims (x, y, z, intensity).
        points = LiDARPoints(np.random.random((100, 4)), 4)
        results['points'] = points
        tta_results = tta_module(results.copy())
        # One output data sample per flip variant, in declaration order.
        assert [
            data_sample.metainfo['pcd_horizontal_flip']
            for data_sample in tta_results['data_samples']
        ] == [False, False, True, True]
        assert [
            data_sample.metainfo['pcd_vertical_flip']
            for data_sample in tta_results['data_samples']
        ] == [False, True, False, True]

        # Case 2: three deterministic rotations (degenerate rot_range
        # [a, a] pins the sampled angle), fixed scale of 1.0.
        tta_transform = dict(
            type='TestTimeAug',
            transforms=[[
                dict(
                    type='GlobalRotScaleTrans',
                    rot_range=[-0.78539816, -0.78539816],
                    scale_ratio_range=[1.0, 1.0],
                    translation_std=[0, 0, 0]),
                dict(
                    type='GlobalRotScaleTrans',
                    rot_range=[0, 0],
                    scale_ratio_range=[1.0, 1.0],
                    translation_std=[0, 0, 0]),
                dict(
                    type='GlobalRotScaleTrans',
                    rot_range=[0.78539816, 0.78539816],
                    scale_ratio_range=[1.0, 1.0],
                    translation_std=[0, 0, 0])
            ], [dict(type='Pack3DDetInputs', keys=['points'])]])
        tta_module = TRANSFORMS.build(tta_transform)
        results = dict()
        points = LiDARPoints(np.random.random((100, 4)), 4)
        results['points'] = points
        tta_results = tta_module(results.copy())
        assert [
            data_sample.metainfo['pcd_rotation_angle']
            for data_sample in tta_results['data_samples']
        ] == [-0.78539816, 0, 0.78539816]
        assert [
            data_sample.metainfo['pcd_scale_factor']
            for data_sample in tta_results['data_samples']
        ] == [1.0, 1.0, 1.0]

        # Case 3: three deterministic scale factors, no rotation.
        tta_transform = dict(
            type='TestTimeAug',
            transforms=[[
                dict(
                    type='GlobalRotScaleTrans',
                    rot_range=[0, 0],
                    scale_ratio_range=[0.95, 0.95],
                    translation_std=[0, 0, 0]),
                dict(
                    type='GlobalRotScaleTrans',
                    rot_range=[0, 0],
                    scale_ratio_range=[1.0, 1.0],
                    translation_std=[0, 0, 0]),
                dict(
                    type='GlobalRotScaleTrans',
                    rot_range=[0, 0],
                    scale_ratio_range=[1.05, 1.05],
                    translation_std=[0, 0, 0])
            ], [dict(type='Pack3DDetInputs', keys=['points'])]])
        tta_module = TRANSFORMS.build(tta_transform)
        results = dict()
        points = LiDARPoints(np.random.random((100, 4)), 4)
        results['points'] = points
        tta_results = tta_module(results.copy())
        assert [
            data_sample.metainfo['pcd_rotation_angle']
            for data_sample in tta_results['data_samples']
        ] == [0, 0, 0]
        assert [
            data_sample.metainfo['pcd_scale_factor']
            for data_sample in tta_results['data_samples']
        ] == [0.95, 1, 1.05]

        # Case 4: combined flip x (rotation x scale) stages. The rotate/
        # scale sub-list is generated by a nested comprehension: scale
        # varies fastest, then rotation; ``TestTimeAug`` then takes the
        # cartesian product with the 4 flips, giving 4 * 9 = 36 samples.
        tta_transform = dict(
            type='TestTimeAug',
            transforms=[
                [
                    dict(
                        type='RandomFlip3D',
                        flip_ratio_bev_horizontal=0.0,
                        flip_ratio_bev_vertical=0.0),
                    dict(
                        type='RandomFlip3D',
                        flip_ratio_bev_horizontal=0.0,
                        flip_ratio_bev_vertical=1.0),
                    dict(
                        type='RandomFlip3D',
                        flip_ratio_bev_horizontal=1.0,
                        flip_ratio_bev_vertical=0.0),
                    dict(
                        type='RandomFlip3D',
                        flip_ratio_bev_horizontal=1.0,
                        flip_ratio_bev_vertical=1.0)
                ],
                [
                    dict(
                        type='GlobalRotScaleTrans',
                        rot_range=[pcd_rotate_range, pcd_rotate_range],
                        scale_ratio_range=[pcd_scale_factor, pcd_scale_factor],
                        translation_std=[0, 0, 0])
                    for pcd_rotate_range in [-0.78539816, 0.0, 0.78539816]
                    for pcd_scale_factor in [0.95, 1.0, 1.05]
                ], [dict(type='Pack3DDetInputs', keys=['points'])]
            ])
        tta_module = TRANSFORMS.build(tta_transform)
        results = dict()
        points = LiDARPoints(np.random.random((100, 4)), 4)
        results['points'] = points
        tta_results = tta_module(results.copy())
        # Horizontal flip: first 18 samples False, last 18 True
        # (flip variant is the slowest-varying axis of the product).
        assert [
            data_sample.metainfo['pcd_horizontal_flip']
            for data_sample in tta_results['data_samples']
        ] == [
            False, False, False, False, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False,
            True, True, True, True, True, True, True, True, True, True, True,
            True, True, True, True, True, True, True
        ]
        # Vertical flip alternates in blocks of 9, matching the four
        # (h, v) flip variants declared above.
        assert [
            data_sample.metainfo['pcd_vertical_flip']
            for data_sample in tta_results['data_samples']
        ] == [
            False, False, False, False, False, False, False, False, False,
            True, True, True, True, True, True, True, True, True, False, False,
            False, False, False, False, False, False, False, True, True, True,
            True, True, True, True, True, True
        ]
        # Within each flip block: rotation varies in blocks of 3.
        assert [
            data_sample.metainfo['pcd_rotation_angle']
            for data_sample in tta_results['data_samples']
        ] == [
            -0.78539816, -0.78539816, -0.78539816, 0.0, 0.0, 0.0, 0.78539816,
            0.78539816, 0.78539816, -0.78539816, -0.78539816, -0.78539816, 0.0,
            0.0, 0.0, 0.78539816, 0.78539816, 0.78539816, -0.78539816,
            -0.78539816, -0.78539816, 0.0, 0.0, 0.0, 0.78539816, 0.78539816,
            0.78539816, -0.78539816, -0.78539816, -0.78539816, 0.0, 0.0, 0.0,
            0.78539816, 0.78539816, 0.78539816
        ]
        # Scale is the fastest-varying axis: 0.95, 1.0, 1.05 repeating.
        assert [
            data_sample.metainfo['pcd_scale_factor']
            for data_sample in tta_results['data_samples']
        ] == [
            0.95, 1.0, 1.05, 0.95, 1.0, 1.05, 0.95, 1.0, 1.05, 0.95, 1.0, 1.05,
            0.95, 1.0, 1.05, 0.95, 1.0, 1.05, 0.95, 1.0, 1.05, 0.95, 1.0, 1.05,
            0.95, 1.0, 1.05, 0.95, 1.0, 1.05, 0.95, 1.0, 1.05, 0.95, 1.0, 1.05
        ]
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import torch
......
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmengine import ConfigDict, DefaultScope
from mmdet3d.models import Seg3DTTAModel
from mmdet3d.registry import MODELS
from mmdet3d.structures import Det3DDataSample
from mmdet3d.testing import get_detector_cfg
class TestSeg3DTTAModel(TestCase):
    """Smoke test for ``Seg3DTTAModel`` wrapping a Cylinder3D segmentor."""

    def test_seg3d_tta_model(self):
        import mmdet3d.models

        # Importing ``mmdet3d.models`` populates the model registry;
        # verify the wrapped segmentor class is actually exported.
        assert hasattr(mmdet3d.models, 'Cylinder3D')
        DefaultScope.get_instance('test_cylinder3d', scope_name='mmdet3d')
        segmentor3d_cfg = get_detector_cfg(
            'cylinder3d/cylinder3d_4xb4_3x_semantickitti.py')
        # Wrap the plain segmentor config in the TTA model config.
        cfg = ConfigDict(type='Seg3DTTAModel', module=segmentor3d_cfg)
        model: Seg3DTTAModel = MODELS.build(cfg)
        points = []
        data_samples = []
        # Four augmented views of one scene: the deterministic flip
        # combinations (h, v) = (F,F), (F,T), (T,F), (T,T).
        pcd_horizontal_flip_list = [False, False, True, True]
        pcd_vertical_flip_list = [False, True, False, True]
        for i in range(4):
            # One random 200-point cloud (4 dims) per augmented view.
            points.append({'points': [torch.randn(200, 4)]})
            data_samples.append([
                Det3DDataSample(
                    metainfo=dict(
                        pcd_horizontal_flip=pcd_horizontal_flip_list[i],
                        pcd_vertical_flip=pcd_vertical_flip_list[i]))
            ])
        # Cylinder3D needs sparse-conv ops, so the forward pass is only
        # exercised when a GPU is available; otherwise only the build
        # path above is tested.
        if torch.cuda.is_available():
            model.eval()
            model.test_step(dict(inputs=points, data_samples=data_samples))
......@@ -3,7 +3,7 @@ import argparse
import os
import os.path as osp
from mmengine.config import Config, DictAction
from mmengine.config import Config, ConfigDict, DictAction
from mmengine.registry import RUNNERS
from mmengine.runner import Runner
......@@ -53,6 +53,8 @@ def parse_args():
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument(
'--tta', action='store_true', help='Test time augmentation')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
......@@ -109,6 +111,14 @@ def main():
if args.show or args.show_dir:
cfg = trigger_visualization_hook(cfg, args)
if args.tta:
# Currently, we only support tta for 3D segmentation
# TODO: Support tta for 3D detection
assert 'tta_model' in cfg, 'Cannot find ``tta_model`` in config.'
assert 'tta_pipeline' in cfg, 'Cannot find ``tta_pipeline`` in config.'
cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model)
# build the runner from config
if 'runner_type' not in cfg:
# build the default runner
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment