Unverified Commit 9d852f17 authored by Ziyi Wu, committed by GitHub

[Feature] Support PointNet++ Segmentor (#528)

* build BaseSegmentor for point sem seg

* add encoder-decoder segmentor

* update mmseg dependency

* fix linting errors

* wrap predicted seg_mask in dict

* add unit test

* use build_model to wrap detector and segmentor

* fix compatibility with mmseg

* faster sliding inference

* merge master

* configs for training on ScanNet

* fix CI errors

* add comments & fix typos

* hard-code class_weight into configs

* fix logger bugs

* update segmentor unit test

* logger use mmdet3d

* use eps to replace hard-coded 1e-3

* add comments

* replace np operation with torch code

* add comments for class_weight

* add comment for BaseSegmentor.simple_test

* rewrite EncoderDecoder3D to avoid inheriting from mmseg
parent 43d79534
@@ -31,7 +31,7 @@ train_pipeline = [
        block_size=1.5,
        sample_rate=1.0,
        ignore_index=len(class_names),
-        use_normalized_coord=True),
+        use_normalized_coord=False),
    dict(type='NormalizePointsColor', color_mean=None),
    dict(type='DefaultFormatBundle3D', class_names=class_names),
    dict(type='Collect3D', keys=['points', 'pts_semantic_mask'])
......
_base_ = './pointnet2_ssg.py'
# model settings
model = dict(
backbone=dict(
_delete_=True,
type='PointNet2SAMSG',
in_channels=6, # [xyz, rgb], should be modified with dataset
num_points=(1024, 256, 64, 16),
radii=((0.05, 0.1), (0.1, 0.2), (0.2, 0.4), (0.4, 0.8)),
num_samples=((16, 32), (16, 32), (16, 32), (16, 32)),
sa_channels=(((16, 16, 32), (32, 32, 64)), ((64, 64, 128), (64, 96,
128)),
((128, 196, 256), (128, 196, 256)), ((256, 256, 512),
(256, 384, 512))),
aggregation_channels=(None, None, None, None),
fps_mods=(('D-FPS'), ('D-FPS'), ('D-FPS'), ('D-FPS')),
fps_sample_range_lists=((-1), (-1), (-1), (-1)),
dilated_group=(False, False, False, False),
out_indices=(0, 1, 2, 3),
sa_cfg=dict(
type='PointSAModuleMSG',
pool_mod='max',
use_xyz=True,
normalize_xyz=False)),
decode_head=dict(
fp_channels=((1536, 256, 256), (512, 256, 256), (352, 256, 128),
(128, 128, 128, 128))))
# model settings
model = dict(
type='EncoderDecoder3D',
backbone=dict(
type='PointNet2SASSG',
in_channels=6, # [xyz, rgb], should be modified with dataset
num_points=(1024, 256, 64, 16),
radius=(0.1, 0.2, 0.4, 0.8),
num_samples=(32, 32, 32, 32),
sa_channels=((32, 32, 64), (64, 64, 128), (128, 128, 256), (256, 256,
512)),
fp_channels=(),
norm_cfg=dict(type='BN2d'),
sa_cfg=dict(
type='PointSAModule',
pool_mod='max',
use_xyz=True,
normalize_xyz=False)),
decode_head=dict(
type='PointNet2Head',
fp_channels=((768, 256, 256), (384, 256, 256), (320, 256, 128),
(128, 128, 128, 128)),
channels=128,
dropout_ratio=0.5,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
act_cfg=dict(type='ReLU'),
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=None, # should be modified with dataset
loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='slide'))
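The backbone's `in_channels=6` above assumes the data pipeline feeds xyz + RGB only. If the patch sampling transform is switched to `use_normalized_coord=True` (the opposite of the pipeline change at the top of this diff), normalized xyz is appended as three extra per-point features, so a derived config has to raise the channel count accordingly. A minimal, hedged override sketch (the value 9 is an assumption: 3 xyz + 3 RGB + 3 normalized xyz):

```python
# hypothetical dataset-specific override, assuming the train/test pipelines
# set use_normalized_coord=True so each point carries 3 extra features
model = dict(backbone=dict(in_channels=9))  # xyz + rgb + normalized xyz
```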
_base_ = [
'../_base_/datasets/scannet_seg-3d-20class.py',
'../_base_/models/pointnet2_msg.py', '../_base_/default_runtime.py'
]
# data settings
data = dict(samples_per_gpu=16)
evaluation = dict(interval=5)
# model settings
model = dict(
decode_head=dict(
num_classes=20,
ignore_index=20,
# `class_weight` is generated in data pre-processing, saved in
# `data/scannet/seg_info/train_label_weight.npy`
# you can copy paste the values here, or input the file path as
# `class_weight=data/scannet/seg_info/train_label_weight.npy`
loss_decode=dict(class_weight=[
2.389689, 2.7215734, 4.5944676, 4.8543367, 4.096086, 4.907941,
4.690836, 4.512031, 4.623311, 4.9242644, 5.358117, 5.360071,
5.019636, 4.967126, 5.3502126, 5.4023647, 5.4027233, 5.4169416,
5.3954206, 4.6971426
])),
test_cfg=dict(
num_points=8192,
block_size=1.5,
sample_rate=0.5,
use_normalized_coord=False,
batch_size=24))
# optimizer
lr = 0.001 # max learning rate
optimizer = dict(type='Adam', lr=lr, weight_decay=1e-4)
optimizer_config = dict(grad_clip=None)
lr_config = dict(policy='CosineAnnealing', warmup=None, min_lr=1e-5)
# runtime settings
checkpoint_config = dict(interval=5)
runner = dict(type='EpochBasedRunner', max_epochs=150)
_base_ = [
'../_base_/datasets/scannet_seg-3d-20class.py',
'../_base_/models/pointnet2_ssg.py', '../_base_/default_runtime.py'
]
# data settings
data = dict(samples_per_gpu=16)
evaluation = dict(interval=5)
# model settings
model = dict(
decode_head=dict(
num_classes=20,
ignore_index=20,
# `class_weight` is generated in data pre-processing, saved in
# `data/scannet/seg_info/train_label_weight.npy`
# you can copy paste the values here, or input the file path as
# `class_weight=data/scannet/seg_info/train_label_weight.npy`
loss_decode=dict(class_weight=[
2.389689, 2.7215734, 4.5944676, 4.8543367, 4.096086, 4.907941,
4.690836, 4.512031, 4.623311, 4.9242644, 5.358117, 5.360071,
5.019636, 4.967126, 5.3502126, 5.4023647, 5.4027233, 5.4169416,
5.3954206, 4.6971426
])),
test_cfg=dict(
num_points=8192,
block_size=1.5,
sample_rate=0.5,
use_normalized_coord=False,
batch_size=24))
# optimizer
lr = 0.001 # max learning rate
optimizer = dict(type='Adam', lr=lr, weight_decay=1e-4)
optimizer_config = dict(grad_clip=None)
lr_config = dict(policy='CosineAnnealing', warmup=None, min_lr=1e-5)
# runtime settings
checkpoint_config = dict(interval=5)
runner = dict(type='EpochBasedRunner', max_epochs=150)
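As the comments in the two ScanNet configs above note, `class_weight` is produced during data pre-processing and stored in `data/scannet/seg_info/train_label_weight.npy`; the hard-coded list is simply a copy of those values. A minimal sketch of regenerating the list from that file (assuming the pre-processing step has already been run):

```python
import numpy as np

# load the per-class loss weights produced by the ScanNet pre-processing
class_weight = np.load('data/scannet/seg_info/train_label_weight.npy').tolist()

# either paste the values into the config, or pass the file path directly as
# class_weight='data/scannet/seg_info/train_label_weight.npy'
model = dict(
    decode_head=dict(loss_decode=dict(class_weight=class_weight)))
```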
@@ -2,9 +2,10 @@ from .inference import (convert_SyncBN, inference_detector,
                        inference_multi_modality_detector, init_detector,
                        show_result_meshlab)
from .test import single_gpu_test
+from .train import train_model

__all__ = [
    'inference_detector', 'init_detector', 'single_gpu_test',
-    'show_result_meshlab', 'convert_SyncBN',
+    'show_result_meshlab', 'convert_SyncBN', 'train_model',
    'inference_multi_modality_detector'
]
@@ -12,7 +12,7 @@ from mmdet3d.core import (Box3DMode, DepthInstance3DBoxes,
                          show_result)
from mmdet3d.core.bbox import get_box_type
from mmdet3d.datasets.pipelines import Compose
-from mmdet3d.models import build_detector
+from mmdet3d.models import build_model


def convert_SyncBN(config):
@@ -52,7 +52,7 @@ def init_detector(config, checkpoint=None, device='cuda:0'):
    config.model.pretrained = None
    convert_SyncBN(config.model)
    config.model.train_cfg = None
-    model = build_detector(config.model, test_cfg=config.get('test_cfg'))
+    model = build_model(config.model, test_cfg=config.get('test_cfg'))
    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint)
        if 'CLASSES' in checkpoint['meta']:
......
import mmcv
-import os
import torch
from mmcv.image import tensor2imgs
+from os import path as osp
+
+from mmdet3d.models import Base3DDetector, Base3DSegmentor


def single_gpu_test(model,
@@ -35,11 +37,11 @@ def single_gpu_test(model,
            result = model(return_loss=False, rescale=True, **data)

        if show:
-            # Visualize the results of MMdetection3D model
+            # Visualize the results of MMDetection3D model
            # 'show_results' is MMdetection3D visualization API
-            if hasattr(model.module, 'show_results'):
+            if isinstance(model.module, (Base3DDetector, Base3DSegmentor)):
                model.module.show_results(data, result, out_dir)
-            # Visualize the results of MMdetection model
+            # Visualize the results of MMDetection model
            # 'show_result' is MMdetection visualization API
            else:
                batch_size = len(result)
@@ -60,8 +62,7 @@ def single_gpu_test(model,
                    img_show = mmcv.imresize(img_show, (ori_w, ori_h))

                    if out_dir:
-                        out_file = os.path.join(out_dir,
-                                                img_meta['ori_filename'])
+                        out_file = osp.join(out_dir, img_meta['ori_filename'])
                    else:
                        out_file = None
......
from mmdet.apis import train_detector
from mmseg.apis import train_segmentor
def train_model(model,
dataset,
cfg,
distributed=False,
validate=False,
timestamp=None,
meta=None):
"""A function wrapper for launching model training according to cfg.
Because we need different eval_hook in runner. Should be deprecated in the
future.
"""
if cfg.model.type in ['EncoderDecoder3D']:
train_segmentor(
model,
dataset,
cfg,
distributed=distributed,
validate=validate,
timestamp=timestamp,
meta=meta)
else:
train_detector(
model,
dataset,
cfg,
distributed=distributed,
validate=validate,
timestamp=timestamp,
meta=meta)
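A hedged usage sketch of this wrapper, mirroring what a typical `tools/train.py` entry point would do; the config path below is a hypothetical placeholder for one of the ScanNet segmentation configs added above:

```python
from mmcv import Config

from mmdet3d.apis import train_model
from mmdet3d.datasets import build_dataset
from mmdet3d.models import build_model

# assumption: any config whose model.type is 'EncoderDecoder3D'
cfg = Config.fromfile('path/to/pointnet2_scannet_seg_config.py')

model = build_model(cfg.model)
datasets = [build_dataset(cfg.data.train)]

# EncoderDecoder3D goes through mmseg's train_segmentor,
# every other model type keeps using mmdet's train_detector
train_model(model, datasets, cfg, distributed=False, validate=True)
```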
@@ -6,11 +6,13 @@ from os import path as osp
from torch.utils.data import Dataset

from mmdet.datasets import DATASETS
+from mmseg.datasets import DATASETS as SEG_DATASETS
from .pipelines import Compose
from .utils import get_loading_pipeline


@DATASETS.register_module()
+@SEG_DATASETS.register_module()
class Custom3DSegDataset(Dataset):
    """Customized 3D dataset for semantic segmentation task.
@@ -143,7 +145,7 @@ class Custom3DSegDataset(Dataset):
        results['pts_seg_fields'] = []
        results['mask_fields'] = []
        results['seg_fields'] = []
-        results['gt_bboxes_3d'] = []
+        results['bbox3d_fields'] = []

    def prepare_train_data(self, index):
        """Training data preparation.
......
@@ -3,11 +3,13 @@ from os import path as osp
from mmdet3d.core import show_seg_result
from mmdet.datasets import DATASETS
+from mmseg.datasets import DATASETS as SEG_DATASETS
from .custom_3d_seg import Custom3DSegDataset
from .pipelines import Compose


@DATASETS.register_module()
+@SEG_DATASETS.register_module()
class _S3DISSegDataset(Custom3DSegDataset):
    r"""S3DIS Dataset for Semantic Segmentation Task.
......
@@ -6,6 +6,7 @@ from os import path as osp
from mmdet3d.core import show_result, show_seg_result
from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet.datasets import DATASETS
+from mmseg.datasets import DATASETS as SEG_DATASETS
from .custom_3d import Custom3DDataset
from .custom_3d_seg import Custom3DSegDataset
from .pipelines import Compose
@@ -196,6 +197,7 @@ class ScanNetDataset(Custom3DDataset):

@DATASETS.register_module()
+@SEG_DATASETS.register_module()
class ScanNetSegDataset(Custom3DSegDataset):
    r"""ScanNet Dataset for Semantic Segmentation Task.
......
from .backbones import *  # noqa: F401,F403
from .builder import (FUSION_LAYERS, MIDDLE_ENCODERS, VOXEL_ENCODERS,
                      build_backbone, build_detector, build_fusion_layer,
-                      build_head, build_loss, build_middle_encoder, build_neck,
-                      build_roi_extractor, build_shared_head,
-                      build_voxel_encoder)
+                      build_head, build_loss, build_middle_encoder,
+                      build_model, build_neck, build_roi_extractor,
+                      build_shared_head, build_voxel_encoder)
from .decode_heads import *  # noqa: F401,F403
from .dense_heads import *  # noqa: F401,F403
from .detectors import *  # noqa: F401,F403
@@ -13,11 +13,12 @@ from .middle_encoders import *  # noqa: F401,F403
from .model_utils import *  # noqa: F401,F403
from .necks import *  # noqa: F401,F403
from .roi_heads import *  # noqa: F401,F403
+from .segmentors import *  # noqa: F401,F403
from .voxel_encoders import *  # noqa: F401,F403

__all__ = [
    'VOXEL_ENCODERS', 'MIDDLE_ENCODERS', 'FUSION_LAYERS', 'build_backbone',
    'build_neck', 'build_roi_extractor', 'build_shared_head', 'build_head',
-    'build_loss', 'build_detector', 'build_fusion_layer',
+    'build_loss', 'build_detector', 'build_fusion_layer', 'build_model',
    'build_middle_encoder', 'build_voxel_encoder'
]
@@ -3,6 +3,7 @@ from mmcv.utils import Registry
from mmdet.models.builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
                                  ROI_EXTRACTORS, SHARED_HEADS, build)
+from mmseg.models.builder import SEGMENTORS

VOXEL_ENCODERS = Registry('voxel_encoder')
MIDDLE_ENCODERS = Registry('middle_encoder')
@@ -52,6 +53,31 @@ def build_detector(cfg, train_cfg=None, test_cfg=None):
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
def build_segmentor(cfg, train_cfg=None, test_cfg=None):
"""Build segmentor."""
if train_cfg is not None or test_cfg is not None:
warnings.warn(
'train_cfg and test_cfg is deprecated, '
'please specify them in model', UserWarning)
assert cfg.get('train_cfg') is None or train_cfg is None, \
'train_cfg specified in both outer field and model field '
assert cfg.get('test_cfg') is None or test_cfg is None, \
'test_cfg specified in both outer field and model field '
return build(cfg, SEGMENTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
def build_model(cfg, train_cfg=None, test_cfg=None):
    """A function wrapper for building a 3D detector or segmentor according
    to cfg.

    Should be deprecated in the future.
    """
if cfg.type in ['EncoderDecoder3D']:
return build_segmentor(cfg, train_cfg=train_cfg, test_cfg=test_cfg)
else:
return build_detector(cfg, train_cfg=train_cfg, test_cfg=test_cfg)
def build_voxel_encoder(cfg):
    """Build voxel encoder."""
    return build(cfg, VOXEL_ENCODERS)
......
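For completeness, a minimal usage sketch of the new entry point (a hedged example, assuming mmseg is installed and the `_base_` model config introduced earlier in this diff is available at the relative location the ScanNet configs reference):

```python
from mmcv import Config

from mmdet3d.models import build_model

# this _base_ config is the PointNet++ SSG segmentor defined earlier in this diff
cfg = Config.fromfile('configs/_base_/models/pointnet2_ssg.py')

# cfg.model.type is 'EncoderDecoder3D', so build_model dispatches to
# build_segmentor and mmseg's SEGMENTORS registry; any detector type
# still goes through build_detector unchanged
model = build_model(cfg.model)
print(type(model).__name__)  # EncoderDecoder3D
```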
@@ -65,13 +65,13 @@ class Base3DDecodeHead(nn.Module, metaclass=ABCMeta):
        """Placeholder of forward function."""
        pass

-    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
+    def forward_train(self, inputs, img_metas, pts_semantic_mask, train_cfg):
        """Forward function for training.

        Args:
-            inputs (list[Tensor]): List of multi-level point features.
+            inputs (list[torch.Tensor]): List of multi-level point features.
            img_metas (list[dict]): Meta information of each sample.
-            gt_semantic_seg (torch.Tensor): Semantic segmentation masks
+            pts_semantic_mask (torch.Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config.
@@ -79,7 +79,7 @@ class Base3DDecodeHead(nn.Module, metaclass=ABCMeta):
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs)
-        losses = self.losses(seg_logits, gt_semantic_seg)
+        losses = self.losses(seg_logits, pts_semantic_mask)
        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
......
from .base import Base3DSegmentor
from .encoder_decoder import EncoderDecoder3D
__all__ = ['Base3DSegmentor', 'EncoderDecoder3D']
import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from mmcv.runner import auto_fp16
from os import path as osp
from mmdet3d.core import show_seg_result
from mmseg.models.segmentors import BaseSegmentor
class Base3DSegmentor(BaseSegmentor):
"""Base class for 3D segmentors.
The main difference with `BaseSegmentor` is that we modify the keys in
data_dict and use a 3D seg specific visualization function.
"""
def forward_test(self, points, img_metas, **kwargs):
"""Calls either simple_test or aug_test depending on the length of
outer list of points. If len(points) == 1, call simple_test. Otherwise
call aug_test to aggregate the test results by e.g. voting.
Args:
points (list[list[torch.Tensor]]): the outer list indicates
test-time augmentations and inner torch.Tensor should have a
shape BxNxC, which contains all points in the batch.
img_metas (list[list[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch.
"""
for var, name in [(points, 'points'), (img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError(f'{name} must be a list, but got {type(var)}')
num_augs = len(points)
if num_augs != len(img_metas):
raise ValueError(f'num of augmentations ({len(points)}) != '
f'num of image meta ({len(img_metas)})')
if num_augs == 1:
return self.simple_test(points[0], img_metas[0], **kwargs)
else:
return self.aug_test(points, img_metas, **kwargs)
@auto_fp16(apply_to=('points'))
def forward(self, return_loss=True, **kwargs):
"""Calls either forward_train or forward_test depending on whether
return_loss=True.
Note this setting will change the expected inputs. When
`return_loss=True`, points and img_metas are single-nested (i.e.
torch.Tensor and list[dict]), and when `return_loss=False`, points and
img_metas should be double nested (i.e. list[torch.Tensor],
list[list[dict]]), with the outer list indicating test time
augmentations.
"""
if return_loss:
return self.forward_train(**kwargs)
else:
return self.forward_test(**kwargs)
def show_results(self,
data,
result,
palette=None,
out_dir=None,
ignore_index=None):
"""Results visualization.
Args:
data (list[dict]): Input points and the information of the sample.
result (list[dict]): Prediction results.
palette (list[list[int]]] | np.ndarray | None): The palette of
segmentation map. If None is given, random palette will be
generated. Default: None
out_dir (str): Output directory of visualization result.
ignore_index (int, optional): The label index to be ignored, e.g.
unannotated points. If None is given, set to len(self.CLASSES).
Defaults to None.
"""
assert out_dir is not None, 'Expect out_dir, got none.'
if palette is None:
if self.PALETTE is None:
palette = np.random.randint(
0, 255, size=(len(self.CLASSES), 3))
else:
palette = self.PALETTE
palette = np.array(palette)
for batch_id in range(len(result)):
if isinstance(data['points'][0], DC):
points = data['points'][0]._data[0][batch_id].numpy()
elif mmcv.is_list_of(data['points'][0], torch.Tensor):
points = data['points'][0][batch_id]
else:
ValueError(f"Unsupported data type {type(data['points'][0])} "
f'for visualization!')
if isinstance(data['img_metas'][0], DC):
pts_filename = data['img_metas'][0]._data[0][batch_id][
'pts_filename']
elif mmcv.is_list_of(data['img_metas'][0], dict):
pts_filename = data['img_metas'][0][batch_id]['pts_filename']
else:
raise ValueError(
f"Unsupported data type {type(data['img_metas'][0])} "
f'for visualization!')
file_name = osp.split(pts_filename)[-1].split('.')[0]
pred_sem_mask = result[batch_id]['semantic_mask'].cpu().numpy()
show_seg_result(points, None, pred_sem_mask, out_dir, file_name,
palette, ignore_index)
import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F
from mmseg.core import add_prefix
from mmseg.models import SEGMENTORS
from ..builder import build_backbone, build_head, build_neck
from .base import Base3DSegmentor
@SEGMENTORS.register_module()
class EncoderDecoder3D(Base3DSegmentor):
"""3D Encoder Decoder segmentors.
EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
Note that auxiliary_head is only used for deep supervision during training,
which can be discarded during inference.
"""
def __init__(self,
backbone,
decode_head,
neck=None,
auxiliary_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(EncoderDecoder3D, self).__init__()
self.backbone = build_backbone(backbone)
if neck is not None:
self.neck = build_neck(neck)
self._init_decode_head(decode_head)
self._init_auxiliary_head(auxiliary_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
assert self.with_decode_head, \
'3D EncoderDecoder Segmentor should have a decode_head'
def _init_decode_head(self, decode_head):
"""Initialize ``decode_head``"""
self.decode_head = build_head(decode_head)
self.num_classes = self.decode_head.num_classes
def _init_auxiliary_head(self, auxiliary_head):
"""Initialize ``auxiliary_head``"""
if auxiliary_head is not None:
if isinstance(auxiliary_head, list):
self.auxiliary_head = nn.ModuleList()
for head_cfg in auxiliary_head:
self.auxiliary_head.append(build_head(head_cfg))
else:
self.auxiliary_head = build_head(auxiliary_head)
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone and heads.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
super(EncoderDecoder3D, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
self.decode_head.init_weights()
if self.with_auxiliary_head:
if isinstance(self.auxiliary_head, nn.ModuleList):
for aux_head in self.auxiliary_head:
aux_head.init_weights()
else:
self.auxiliary_head.init_weights()
def extract_feat(self, points):
"""Extract features from points."""
x = self.backbone(points)
if self.with_neck:
x = self.neck(x)
return x
def encode_decode(self, points, img_metas):
"""Encode points with backbone and decode into a semantic segmentation
map of the same size as input.
Args:
points (torch.Tensor): Input points of shape [B, N, 3+C].
img_metas (list[dict]): Meta information of each sample.
Returns:
torch.Tensor: Segmentation logits of shape [B, num_classes, N].
"""
x = self.extract_feat(points)
out = self._decode_head_forward_test(x, img_metas)
return out
def _decode_head_forward_train(self, x, img_metas, pts_semantic_mask):
"""Run forward function and calculate loss for decode head in
training."""
losses = dict()
loss_decode = self.decode_head.forward_train(x, img_metas,
pts_semantic_mask,
self.train_cfg)
losses.update(add_prefix(loss_decode, 'decode'))
return losses
def _decode_head_forward_test(self, x, img_metas):
"""Run forward function and calculate loss for decode head in
inference."""
seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg)
return seg_logits
def _auxiliary_head_forward_train(self, x, img_metas, pts_semantic_mask):
"""Run forward function and calculate loss for auxiliary head in
training."""
losses = dict()
if isinstance(self.auxiliary_head, nn.ModuleList):
for idx, aux_head in enumerate(self.auxiliary_head):
loss_aux = aux_head.forward_train(x, img_metas,
pts_semantic_mask,
self.train_cfg)
losses.update(add_prefix(loss_aux, f'aux_{idx}'))
else:
loss_aux = self.auxiliary_head.forward_train(
x, img_metas, pts_semantic_mask, self.train_cfg)
losses.update(add_prefix(loss_aux, 'aux'))
return losses
def forward_dummy(self, points):
"""Dummy forward function."""
seg_logit = self.encode_decode(points, None)
return seg_logit
def forward_train(self, points, img_metas, pts_semantic_mask):
"""Forward function for training.
Args:
points (list[torch.Tensor]): List of points of shape [N, C].
img_metas (list): Image metas.
pts_semantic_mask (list[torch.Tensor]): List of point-wise semantic
labels of shape [N].
Returns:
dict[str, Tensor]: Losses.
"""
points_cat = torch.stack(points)
pts_semantic_mask_cat = torch.stack(pts_semantic_mask)
# extract features using backbone
x = self.extract_feat(points_cat)
losses = dict()
loss_decode = self._decode_head_forward_train(x, img_metas,
pts_semantic_mask_cat)
losses.update(loss_decode)
if self.with_auxiliary_head:
loss_aux = self._auxiliary_head_forward_train(
x, img_metas, pts_semantic_mask_cat)
losses.update(loss_aux)
return losses
@staticmethod
def _input_generation(coords,
patch_center,
coord_max,
feats,
use_normalized_coord=False):
"""Generating model input.
Generate input by subtracting patch center and adding additional \
features. Currently support colors and normalized xyz as features.
Args:
coords (torch.Tensor): Sampled 3D point coordinate of shape [S, 3].
patch_center (torch.Tensor): Center coordinate of the patch.
coord_max (torch.Tensor): Max coordinate of all 3D points.
feats (torch.Tensor): Features of sampled points of shape [S, C].
use_normalized_coord (bool, optional): Whether to use normalized \
xyz as additional features. Defaults to False.
Returns:
torch.Tensor: The generated input data of shape [S, 3+C'].
"""
# subtract patch center, the z dimension is not centered
centered_coords = coords.clone()
centered_coords[:, 0] -= patch_center[0]
centered_coords[:, 1] -= patch_center[1]
# normalized coordinates as extra features
if use_normalized_coord:
normalized_coord = coords / coord_max
feats = torch.cat([feats, normalized_coord], dim=1)
points = torch.cat([centered_coords, feats], dim=1)
return points
def _sliding_patch_generation(self,
points,
num_points,
block_size,
sample_rate=0.5,
use_normalized_coord=False,
eps=1e-3):
"""Sampling points in a sliding window fashion.
First sample patches to cover all the input points.
Then sample points in each patch so that each patch can be split into batches of a fixed number of points.
Args:
points (torch.Tensor): Input points of shape [N, 3+C].
num_points (int): Number of points to be sampled in each patch.
block_size (float): Size of a patch to sample.
sample_rate (float, optional): Stride used in sliding patch.
Defaults to 0.5.
use_normalized_coord (bool, optional): Whether to use normalized \
xyz as additional features. Defaults to False.
eps (float, optional): A value added to patch boundary to guarantee
points coverage. Default 1e-3.
Returns:
tuple[torch.Tensor, torch.Tensor]:
- patch_points (torch.Tensor): Points of different patches of \
shape [K, N, 3+C].
- patch_idxs (torch.Tensor): Index of each point in \
`patch_points`, of shape [K, N].
"""
device = points.device
# we assume the first three dims are points' 3D coordinates
# and the rest dims are their per-point features
coords = points[:, :3]
feats = points[:, 3:]
coord_max = coords.max(0)[0]
coord_min = coords.min(0)[0]
stride = block_size * sample_rate
num_grid_x = int(
torch.ceil((coord_max[0] - coord_min[0] - block_size) /
stride).item() + 1)
num_grid_y = int(
torch.ceil((coord_max[1] - coord_min[1] - block_size) /
stride).item() + 1)
patch_points, patch_idxs = [], []
for idx_y in range(num_grid_y):
s_y = coord_min[1] + idx_y * stride
e_y = torch.min(s_y + block_size, coord_max[1])
s_y = e_y - block_size
for idx_x in range(num_grid_x):
s_x = coord_min[0] + idx_x * stride
e_x = torch.min(s_x + block_size, coord_max[0])
s_x = e_x - block_size
# extract points within this patch
cur_min = torch.tensor([s_x, s_y, coord_min[2]]).to(device)
cur_max = torch.tensor([e_x, e_y, coord_max[2]]).to(device)
cur_choice = ((coords >= cur_min - eps) &
(coords <= cur_max + eps)).all(dim=1)
if not cur_choice.any(): # no points in this patch
continue
# sample points in this patch to multiple batches
cur_center = cur_min + block_size / 2.0
point_idxs = torch.nonzero(cur_choice, as_tuple=True)[0]
num_batch = int(np.ceil(point_idxs.shape[0] / num_points))
point_size = int(num_batch * num_points)
replace = point_size > 2 * point_idxs.shape[0]
num_repeat = point_size - point_idxs.shape[0]
if replace: # duplicate
point_idxs_repeat = point_idxs[torch.randint(
0, point_idxs.shape[0],
size=(num_repeat, )).to(device)]
else:
point_idxs_repeat = point_idxs[torch.randperm(
point_idxs.shape[0])[:num_repeat]]
choices = torch.cat([point_idxs, point_idxs_repeat], dim=0)
choices = choices[torch.randperm(choices.shape[0])]
# construct model input
point_batches = self._input_generation(
coords[choices],
cur_center,
coord_max,
feats[choices],
use_normalized_coord=use_normalized_coord)
patch_points.append(point_batches)
patch_idxs.append(choices)
patch_points = torch.cat(patch_points, dim=0)
patch_idxs = torch.cat(patch_idxs, dim=0)
# make sure all points are sampled at least once
assert torch.unique(patch_idxs).shape[0] == points.shape[0], \
'some points are not sampled in sliding inference'
return patch_points, patch_idxs
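# A quick sanity check of the patch grid above: with the test_cfg values used
# in the ScanNet configs (block_size=1.5, sample_rate=0.5) the stride is
# 1.5 * 0.5 = 0.75, so a scene spanning 6m x 6m in xy gives
# num_grid_x = num_grid_y = ceil((6 - 1.5) / 0.75) + 1 = 7, i.e. up to 49
# overlapping patches, each clamped so its far edge never exceeds coord_max.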
def slide_inference(self, point, img_meta, rescale):
"""Inference by sliding-window with overlap.
Args:
point (torch.Tensor): Input points of shape [N, 3+C].
img_meta (dict): Meta information of input sample.
rescale (bool): Whether transform to original number of points.
Will be used for voxelization based segmentors.
Returns:
Tensor: The output segmentation map of shape [num_classes, N].
"""
num_points = self.test_cfg.num_points
block_size = self.test_cfg.block_size
sample_rate = self.test_cfg.sample_rate
use_normalized_coord = self.test_cfg.use_normalized_coord
batch_size = self.test_cfg.batch_size * num_points
# patch_points is of shape [K*N, 3+C], patch_idxs is of shape [K*N]
patch_points, patch_idxs = self._sliding_patch_generation(
point, num_points, block_size, sample_rate, use_normalized_coord)
feats_dim = patch_points.shape[1]
seg_logits = [] # save patch predictions
for batch_idx in range(0, patch_points.shape[0], batch_size):
batch_points = patch_points[batch_idx:batch_idx + batch_size]
batch_points = batch_points.view(-1, num_points, feats_dim)
# batch_seg_logit is of shape [B, num_classes, N]
batch_seg_logit = self.encode_decode(batch_points, img_meta)
batch_seg_logit = batch_seg_logit.transpose(1, 2).contiguous()
seg_logits.append(batch_seg_logit.view(-1, self.num_classes))
# aggregate per-point logits by indexing sum and dividing count
seg_logits = torch.cat(seg_logits, dim=0) # [K*N, num_classes]
expand_patch_idxs = patch_idxs.unsqueeze(1).repeat(1, self.num_classes)
preds = point.new_zeros((point.shape[0], self.num_classes)).\
scatter_add_(dim=0, index=expand_patch_idxs, src=seg_logits)
count_mat = torch.bincount(patch_idxs)
preds = preds / count_mat[:, None]
# TODO: if rescale and voxelization segmentor
return preds.transpose(0, 1)  # to [num_classes, N]
def whole_inference(self, points, img_metas, rescale):
"""Inference with full scene (one forward pass without sliding)."""
seg_logit = self.encode_decode(points, img_metas)
# TODO: if rescale and voxelization segmentor
return seg_logit
def inference(self, points, img_metas, rescale):
"""Inference with slide/whole style.
Args:
points (torch.Tensor): Input points of shape [B, N, 3+C].
img_metas (list[dict]): Meta information of each sample.
rescale (bool): Whether transform to original number of points.
Will be used for voxelization based segmentors.
Returns:
Tensor: The output segmentation map.
"""
assert self.test_cfg.mode in ['slide', 'whole']
if self.test_cfg.mode == 'slide':
seg_logit = torch.stack([
self.slide_inference(point, img_meta, rescale)
for point, img_meta in zip(points, img_metas)
], 0)
else:
seg_logit = self.whole_inference(points, img_metas, rescale)
output = F.softmax(seg_logit, dim=1)
return output
def simple_test(self, points, img_metas, rescale=True):
"""Simple test with single scene.
Args:
points (list[torch.Tensor]): List of points of shape [N, 3+C].
img_metas (list[dict]): Meta information of each sample.
rescale (bool): Whether transform to original number of points.
Will be used for voxelization based segmentors.
Defaults to True.
Returns:
list[dict]: The output prediction result with following keys:
- semantic_mask (Tensor): Segmentation mask of shape [N].
"""
# 3D segmentation requires per-point prediction, so it's impossible
# to use down-sampling to get a batch of scenes with the same num_points;
# therefore, we only support testing one scene at a time
seg_pred = []
for point, img_meta in zip(points, img_metas):
seg_prob = self.inference(point.unsqueeze(0), [img_meta],
rescale)[0]
seg_map = seg_prob.argmax(0) # [N]
# to cpu tensor for consistency with det3d
seg_map = seg_map.cpu()
seg_pred.append(seg_map)
# wrap in dict
seg_pred = [dict(semantic_mask=seg_map) for seg_map in seg_pred]
return seg_pred
def aug_test(self, points, img_metas, rescale=True):
"""Test with augmentations.
Args:
points (list[torch.Tensor]): List of points of shape [B, N, 3+C].
img_metas (list[list[dict]]): Meta information of each sample.
Outer list are different samples while inner is different augs.
rescale (bool): Whether transform to original number of points.
Will be used for voxelization based segmentors.
Defaults to True.
Returns:
list[dict]: The output prediction result with following keys:
- semantic_mask (Tensor): Segmentation mask of shape [N].
"""
# in aug_test, one scene going through different augmentations keeps
# the same number of points, so the augmented copies are stacked as a
# batch and their predicted seg probs are averaged over augmentations
seg_pred = []
for point, img_meta in zip(points, img_metas):
seg_prob = self.inference(point, img_meta, rescale)
seg_prob = seg_prob.mean(0) # [num_classes, N]
seg_map = seg_prob.argmax(0) # [N]
# to cpu tensor for consistency with det3d
seg_map = seg_map.cpu()
seg_pred.append(seg_map)
# wrap in dict
seg_pred = [dict(semantic_mask=seg_map) for seg_map in seg_pred]
return seg_pred
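The overlap handling in `slide_inference` above (the "faster sliding inference" commit) avoids Python loops by summing per-patch logits into a per-point buffer with `scatter_add_` and dividing by each point's visit count from `bincount`. A small self-contained sketch of just that aggregation step, with toy shapes assumed for illustration:

```python
import torch

num_classes, num_scene_points = 4, 6
# pretend two overlapping patches of 5 sampled points each were forwarded
patch_idxs = torch.tensor([0, 1, 2, 3, 4, 2, 3, 4, 5, 5])    # [K*N]
seg_logits = torch.rand(patch_idxs.shape[0], num_classes)     # [K*N, num_classes]

expand_idxs = patch_idxs.unsqueeze(1).repeat(1, num_classes)
preds = seg_logits.new_zeros((num_scene_points, num_classes))
preds.scatter_add_(dim=0, index=expand_idxs, src=seg_logits)  # sum logits per point
count = torch.bincount(patch_idxs, minlength=num_scene_points)
preds = preds / count[:, None]                                # average over overlaps
print(preds.shape)  # torch.Size([6, 4]), i.e. [N, num_classes] before transpose
```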
@@ -82,12 +82,14 @@ class Points_Sampler(nn.Module):
            if fps_sample_range == -1:
                sample_points_xyz = points_xyz[:, last_fps_end_index:]
-                sample_features = features[:, :, last_fps_end_index:]
+                sample_features = features[:, :, last_fps_end_index:] if \
+                    features is not None else None
            else:
                sample_points_xyz = \
                    points_xyz[:, last_fps_end_index:fps_sample_range]
-                sample_features = \
-                    features[:, :, last_fps_end_index:fps_sample_range]
+                sample_features = \
+                    features[:, :, last_fps_end_index:fps_sample_range] if \
+                    features is not None else None

            fps_idx = sampler(sample_points_xyz.contiguous(), sample_features,
                              npoint)
@@ -125,6 +127,8 @@ class FFPS_Sampler(nn.Module):

    def forward(self, points, features, npoint):
        """Sampling points with F-FPS."""
+        assert features is not None, \
+            'feature input to FFPS_Sampler should not be None'
        features_for_fps = torch.cat([points, features.transpose(1, 2)], dim=2)
        features_dist = calc_square_dist(
            features_for_fps, features_for_fps, norm=False)
@@ -143,6 +147,8 @@ class FS_Sampler(nn.Module):

    def forward(self, points, features, npoint):
        """Sampling points with FS_Sampling."""
+        assert features is not None, \
+            'feature input to FS_Sampler should not be None'
        features_for_fps = torch.cat([points, features.transpose(1, 2)], dim=2)
        features_dist = calc_square_dist(
            features_for_fps, features_for_fps, norm=False)
......
from mmcv.utils import Registry, build_from_cfg, print_log
-from mmdet.utils import get_root_logger

from .collect_env import collect_env
+from .logger import get_root_logger

__all__ = [
    'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env', 'print_log'
......