Unverified Commit 9d852f17 authored by Ziyi Wu, committed by GitHub

[Feature] Support PointNet++ Segmentor (#528)

* build BaseSegmentor for point sem seg

* add encoder-decoder segmentor

* update mmseg dependency

* fix linting errors

* wrap predicted seg_mask in dict

* add unit test

* use build_model to wrap detector and segmentor

* fix compatibility with mmseg

* faster sliding inference

* merge master

* configs for training on ScanNet

* fix CI errors

* add comments & fix typos

* hard-code class_weight into configs

* fix logger bugs

* update segmentor unit test

* logger use mmdet3d

* use eps to replace hard-coded 1e-3

* add comments

* replace np operation with torch code

* add comments for class_weight

* add comment for BaseSegmentor.simple_test

* rewrite EncoderDecoder3D to avoid inheriting from mmseg
parent 43d79534
......@@ -31,7 +31,7 @@ train_pipeline = [
block_size=1.5,
sample_rate=1.0,
ignore_index=len(class_names),
use_normalized_coord=True),
use_normalized_coord=False),
dict(type='NormalizePointsColor', color_mean=None),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'pts_semantic_mask'])
......
_base_ = './pointnet2_ssg.py'
# model settings
model = dict(
backbone=dict(
_delete_=True,
type='PointNet2SAMSG',
in_channels=6, # [xyz, rgb], should be modified with dataset
num_points=(1024, 256, 64, 16),
radii=((0.05, 0.1), (0.1, 0.2), (0.2, 0.4), (0.4, 0.8)),
num_samples=((16, 32), (16, 32), (16, 32), (16, 32)),
sa_channels=(((16, 16, 32), (32, 32, 64)), ((64, 64, 128), (64, 96,
128)),
((128, 196, 256), (128, 196, 256)), ((256, 256, 512),
(256, 384, 512))),
aggregation_channels=(None, None, None, None),
fps_mods=(('D-FPS'), ('D-FPS'), ('D-FPS'), ('D-FPS')),
fps_sample_range_lists=((-1), (-1), (-1), (-1)),
dilated_group=(False, False, False, False),
out_indices=(0, 1, 2, 3),
sa_cfg=dict(
type='PointSAModuleMSG',
pool_mod='max',
use_xyz=True,
normalize_xyz=False)),
decode_head=dict(
fp_channels=((1536, 256, 256), (512, 256, 256), (352, 256, 128),
(128, 128, 128, 128))))
# model settings
model = dict(
type='EncoderDecoder3D',
backbone=dict(
type='PointNet2SASSG',
in_channels=6, # [xyz, rgb], should be modified with dataset
num_points=(1024, 256, 64, 16),
radius=(0.1, 0.2, 0.4, 0.8),
num_samples=(32, 32, 32, 32),
sa_channels=((32, 32, 64), (64, 64, 128), (128, 128, 256), (256, 256,
512)),
fp_channels=(),
norm_cfg=dict(type='BN2d'),
sa_cfg=dict(
type='PointSAModule',
pool_mod='max',
use_xyz=True,
normalize_xyz=False)),
decode_head=dict(
type='PointNet2Head',
fp_channels=((768, 256, 256), (384, 256, 256), (320, 256, 128),
(128, 128, 128, 128)),
channels=128,
dropout_ratio=0.5,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
act_cfg=dict(type='ReLU'),
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=None, # should be modified with dataset
loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='slide'))
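A coupling worth noting across the hunks above: the backbone's in_channels=6 ([xyz, rgb]) matches use_normalized_coord=False in the train pipeline from the first hunk. A minimal hedged sketch of the override that would be needed if normalized xyz were enabled as extra per-point features (illustrative only, not part of this commit):

# hypothetical override (not in this commit): with use_normalized_coord=True in the
# point-sampling pipeline step, each point carries normalized xyz as 3 extra features,
# so the backbone input width grows from 6 to 9
model = dict(
    backbone=dict(
        in_channels=9))  # [xyz, rgb, normalized xyz]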
_base_ = [
'../_base_/datasets/scannet_seg-3d-20class.py',
'../_base_/models/pointnet2_msg.py', '../_base_/default_runtime.py'
]
# data settings
data = dict(samples_per_gpu=16)
evaluation = dict(interval=5)
# model settings
model = dict(
decode_head=dict(
num_classes=20,
ignore_index=20,
# `class_weight` is generated in data pre-processing, saved in
# `data/scannet/seg_info/train_label_weight.npy`
# you can copy paste the values here, or input the file path as
# `class_weight=data/scannet/seg_info/train_label_weight.npy`
loss_decode=dict(class_weight=[
2.389689, 2.7215734, 4.5944676, 4.8543367, 4.096086, 4.907941,
4.690836, 4.512031, 4.623311, 4.9242644, 5.358117, 5.360071,
5.019636, 4.967126, 5.3502126, 5.4023647, 5.4027233, 5.4169416,
5.3954206, 4.6971426
])),
test_cfg=dict(
num_points=8192,
block_size=1.5,
sample_rate=0.5,
use_normalized_coord=False,
batch_size=24))
# optimizer
lr = 0.001 # max learning rate
optimizer = dict(type='Adam', lr=lr, weight_decay=1e-4)
optimizer_config = dict(grad_clip=None)
lr_config = dict(policy='CosineAnnealing', warmup=None, min_lr=1e-5)
# runtime settings
checkpoint_config = dict(interval=5)
runner = dict(type='EpochBasedRunner', max_epochs=150)
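As the comments in the config above note, the hard-coded class_weight list mirrors data/scannet/seg_info/train_label_weight.npy. A small hedged sketch of regenerating that list from the file instead of copy-pasting it by hand (assuming the file holds a plain float array with one weight per class):

import numpy as np

# load the pre-computed ScanNet label weights produced during data pre-processing
weights = np.load('data/scannet/seg_info/train_label_weight.npy')
print([round(float(w), 7) for w in weights])  # e.g. [2.389689, 2.7215734, ...]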
_base_ = [
'../_base_/datasets/scannet_seg-3d-20class.py',
'../_base_/models/pointnet2_ssg.py', '../_base_/default_runtime.py'
]
# data settings
data = dict(samples_per_gpu=16)
evaluation = dict(interval=5)
# model settings
model = dict(
decode_head=dict(
num_classes=20,
ignore_index=20,
# `class_weight` is generated in data pre-processing, saved in
# `data/scannet/seg_info/train_label_weight.npy`
# you can copy paste the values here, or input the file path as
# `class_weight=data/scannet/seg_info/train_label_weight.npy`
loss_decode=dict(class_weight=[
2.389689, 2.7215734, 4.5944676, 4.8543367, 4.096086, 4.907941,
4.690836, 4.512031, 4.623311, 4.9242644, 5.358117, 5.360071,
5.019636, 4.967126, 5.3502126, 5.4023647, 5.4027233, 5.4169416,
5.3954206, 4.6971426
])),
test_cfg=dict(
num_points=8192,
block_size=1.5,
sample_rate=0.5,
use_normalized_coord=False,
batch_size=24))
# optimizer
lr = 0.001 # max learning rate
optimizer = dict(type='Adam', lr=lr, weight_decay=1e-4)
optimizer_config = dict(grad_clip=None)
lr_config = dict(policy='CosineAnnealing', warmup=None, min_lr=1e-5)
# runtime settings
checkpoint_config = dict(interval=5)
runner = dict(type='EpochBasedRunner', max_epochs=150)
......@@ -2,9 +2,10 @@ from .inference import (convert_SyncBN, inference_detector,
inference_multi_modality_detector, init_detector,
show_result_meshlab)
from .test import single_gpu_test
from .train import train_model
__all__ = [
'inference_detector', 'init_detector', 'single_gpu_test',
'show_result_meshlab', 'convert_SyncBN',
'show_result_meshlab', 'convert_SyncBN', 'train_model',
'inference_multi_modality_detector'
]
......@@ -12,7 +12,7 @@ from mmdet3d.core import (Box3DMode, DepthInstance3DBoxes,
show_result)
from mmdet3d.core.bbox import get_box_type
from mmdet3d.datasets.pipelines import Compose
from mmdet3d.models import build_detector
from mmdet3d.models import build_model
def convert_SyncBN(config):
......@@ -52,7 +52,7 @@ def init_detector(config, checkpoint=None, device='cuda:0'):
config.model.pretrained = None
convert_SyncBN(config.model)
config.model.train_cfg = None
model = build_detector(config.model, test_cfg=config.get('test_cfg'))
model = build_model(config.model, test_cfg=config.get('test_cfg'))
if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint)
if 'CLASSES' in checkpoint['meta']:
......
import mmcv
import os
import torch
from mmcv.image import tensor2imgs
from os import path as osp
from mmdet3d.models import Base3DDetector, Base3DSegmentor
def single_gpu_test(model,
......@@ -35,11 +37,11 @@ def single_gpu_test(model,
result = model(return_loss=False, rescale=True, **data)
if show:
# Visualize the results of MMdetection3D model
# Visualize the results of MMDetection3D model
# 'show_results' is the MMDetection3D visualization API
if hasattr(model.module, 'show_results'):
if isinstance(model.module, (Base3DDetector, Base3DSegmentor)):
model.module.show_results(data, result, out_dir=out_dir)
# Visualize the results of MMdetection model
# Visualize the results of MMDetection model
# 'show_result' is the MMDetection visualization API
else:
batch_size = len(result)
......@@ -60,8 +62,7 @@ def single_gpu_test(model,
img_show = mmcv.imresize(img_show, (ori_w, ori_h))
if out_dir:
out_file = os.path.join(out_dir,
img_meta['ori_filename'])
out_file = osp.join(out_dir, img_meta['ori_filename'])
else:
out_file = None
......
from mmdet.apis import train_detector
from mmseg.apis import train_segmentor
def train_model(model,
dataset,
cfg,
distributed=False,
validate=False,
timestamp=None,
meta=None):
"""A function wrapper for launching model training according to cfg.
Because we need a different eval_hook in the runner, this wrapper should be
deprecated in the future.
"""
if cfg.model.type in ['EncoderDecoder3D']:
train_segmentor(
model,
dataset,
cfg,
distributed=distributed,
validate=validate,
timestamp=timestamp,
meta=meta)
else:
train_detector(
model,
dataset,
cfg,
distributed=distributed,
validate=validate,
timestamp=timestamp,
meta=meta)
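A minimal usage sketch tying build_model and train_model together (hedged; the actual tools/train.py wiring is not part of this diff, and the config path below is a placeholder):

from mmcv import Config

from mmdet3d.apis import train_model
from mmdet3d.datasets import build_dataset
from mmdet3d.models import build_model

# works for both detector and segmentor configs; the dispatch happens on cfg.model.type
cfg = Config.fromfile('configs/pointnet2/pointnet2_ssg_scannet_seg.py')  # hypothetical path
model = build_model(cfg.model, test_cfg=cfg.get('test_cfg'))
datasets = [build_dataset(cfg.data.train)]
train_model(model, datasets, cfg, distributed=False, validate=True)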
......@@ -6,11 +6,13 @@ from os import path as osp
from torch.utils.data import Dataset
from mmdet.datasets import DATASETS
from mmseg.datasets import DATASETS as SEG_DATASETS
from .pipelines import Compose
from .utils import get_loading_pipeline
@DATASETS.register_module()
@SEG_DATASETS.register_module()
class Custom3DSegDataset(Dataset):
"""Customized 3D dataset for semantic segmentation task.
......@@ -143,7 +145,7 @@ class Custom3DSegDataset(Dataset):
results['pts_seg_fields'] = []
results['mask_fields'] = []
results['seg_fields'] = []
results['gt_bboxes_3d'] = []
results['bbox3d_fields'] = []
def prepare_train_data(self, index):
"""Training data preparation.
......
......@@ -3,11 +3,13 @@ from os import path as osp
from mmdet3d.core import show_seg_result
from mmdet.datasets import DATASETS
from mmseg.datasets import DATASETS as SEG_DATASETS
from .custom_3d_seg import Custom3DSegDataset
from .pipelines import Compose
@DATASETS.register_module()
@SEG_DATASETS.register_module()
class _S3DISSegDataset(Custom3DSegDataset):
r"""S3DIS Dataset for Semantic Segmentation Task.
......
......@@ -6,6 +6,7 @@ from os import path as osp
from mmdet3d.core import show_result, show_seg_result
from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet.datasets import DATASETS
from mmseg.datasets import DATASETS as SEG_DATASETS
from .custom_3d import Custom3DDataset
from .custom_3d_seg import Custom3DSegDataset
from .pipelines import Compose
......@@ -196,6 +197,7 @@ class ScanNetDataset(Custom3DDataset):
@DATASETS.register_module()
@SEG_DATASETS.register_module()
class ScanNetSegDataset(Custom3DSegDataset):
r"""ScanNet Dataset for Semantic Segmentation Task.
......
from .backbones import * # noqa: F401,F403
from .builder import (FUSION_LAYERS, MIDDLE_ENCODERS, VOXEL_ENCODERS,
build_backbone, build_detector, build_fusion_layer,
build_head, build_loss, build_middle_encoder, build_neck,
build_roi_extractor, build_shared_head,
build_voxel_encoder)
build_head, build_loss, build_middle_encoder,
build_model, build_neck, build_roi_extractor,
build_shared_head, build_voxel_encoder)
from .decode_heads import * # noqa: F401,F403
from .dense_heads import * # noqa: F401,F403
from .detectors import * # noqa: F401,F403
......@@ -13,11 +13,12 @@ from .middle_encoders import * # noqa: F401,F403
from .model_utils import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .roi_heads import * # noqa: F401,F403
from .segmentors import * # noqa: F401,F403
from .voxel_encoders import * # noqa: F401,F403
__all__ = [
'VOXEL_ENCODERS', 'MIDDLE_ENCODERS', 'FUSION_LAYERS', 'build_backbone',
'build_neck', 'build_roi_extractor', 'build_shared_head', 'build_head',
'build_loss', 'build_detector', 'build_fusion_layer',
'build_loss', 'build_detector', 'build_fusion_layer', 'build_model',
'build_middle_encoder', 'build_voxel_encoder'
]
......@@ -3,6 +3,7 @@ from mmcv.utils import Registry
from mmdet.models.builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
ROI_EXTRACTORS, SHARED_HEADS, build)
from mmseg.models.builder import SEGMENTORS
VOXEL_ENCODERS = Registry('voxel_encoder')
MIDDLE_ENCODERS = Registry('middle_encoder')
......@@ -52,6 +53,31 @@ def build_detector(cfg, train_cfg=None, test_cfg=None):
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
def build_segmentor(cfg, train_cfg=None, test_cfg=None):
"""Build segmentor."""
if train_cfg is not None or test_cfg is not None:
warnings.warn(
'train_cfg and test_cfg are deprecated, '
'please specify them in model', UserWarning)
assert cfg.get('train_cfg') is None or train_cfg is None, \
'train_cfg specified in both outer field and model field '
assert cfg.get('test_cfg') is None or test_cfg is None, \
'test_cfg specified in both outer field and model field '
return build(cfg, SEGMENTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
def build_model(cfg, train_cfg=None, test_cfg=None):
"""A function warpper for building 3D detector or segmentor according to
cfg.
Should be deprecated in the future.
"""
if cfg.type in ['EncoderDecoder3D']:
return build_segmentor(cfg, train_cfg=train_cfg, test_cfg=test_cfg)
else:
return build_detector(cfg, train_cfg=train_cfg, test_cfg=test_cfg)
def build_voxel_encoder(cfg):
"""Build voxel encoder."""
return build(cfg, VOXEL_ENCODERS)
......
......@@ -65,13 +65,13 @@ class Base3DDecodeHead(nn.Module, metaclass=ABCMeta):
"""Placeholder of forward function."""
pass
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
def forward_train(self, inputs, img_metas, pts_semantic_mask, train_cfg):
"""Forward function for training.
Args:
inputs (list[Tensor]): List of multi-level point features.
inputs (list[torch.Tensor]): List of multi-level point features.
img_metas (list[dict]): Meta information of each sample.
gt_semantic_seg (torch.Tensor): Semantic segmentation masks
pts_semantic_mask (torch.Tensor): Semantic segmentation masks
used if the architecture supports semantic segmentation task.
train_cfg (dict): The training config.
......@@ -79,7 +79,7 @@ class Base3DDecodeHead(nn.Module, metaclass=ABCMeta):
dict[str, Tensor]: a dictionary of loss components
"""
seg_logits = self.forward(inputs)
losses = self.losses(seg_logits, gt_semantic_seg)
losses = self.losses(seg_logits, pts_semantic_mask)
return losses
def forward_test(self, inputs, img_metas, test_cfg):
......
from .base import Base3DSegmentor
from .encoder_decoder import EncoderDecoder3D
__all__ = ['Base3DSegmentor', 'EncoderDecoder3D']
import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from mmcv.runner import auto_fp16
from os import path as osp
from mmdet3d.core import show_seg_result
from mmseg.models.segmentors import BaseSegmentor
class Base3DSegmentor(BaseSegmentor):
"""Base class for 3D segmentors.
The main difference from `BaseSegmentor` is that we modify the keys in
data_dict and use a 3D-seg-specific visualization function.
"""
def forward_test(self, points, img_metas, **kwargs):
"""Calls either simple_test or aug_test depending on the length of
outer list of points. If len(points) == 1, call simple_test. Otherwise
call aug_test to aggregate the test results by e.g. voting.
Args:
points (list[list[torch.Tensor]]): the outer list indicates
test-time augmentations and inner torch.Tensor should have a
shape BxNxC, which contains all points in the batch.
img_metas (list[list[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch.
"""
for var, name in [(points, 'points'), (img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError(f'{name} must be a list, but got {type(var)}')
num_augs = len(points)
if num_augs != len(img_metas):
raise ValueError(f'num of augmentations ({len(points)}) != '
f'num of image meta ({len(img_metas)})')
if num_augs == 1:
return self.simple_test(points[0], img_metas[0], **kwargs)
else:
return self.aug_test(points, img_metas, **kwargs)
@auto_fp16(apply_to=('points', ))
def forward(self, return_loss=True, **kwargs):
"""Calls either forward_train or forward_test depending on whether
return_loss=True.
Note this setting will change the expected inputs. When
`return_loss=True`, points and img_metas are single-nested (i.e.
torch.Tensor and list[dict]), and when `return_loss=False`, points and
img_metas should be double-nested (i.e. list[torch.Tensor],
list[list[dict]]), with the outer list indicating test time
augmentations.
"""
if return_loss:
return self.forward_train(**kwargs)
else:
return self.forward_test(**kwargs)
def show_results(self,
data,
result,
palette=None,
out_dir=None,
ignore_index=None):
"""Results visualization.
Args:
data (list[dict]): Input points and the information of the sample.
result (list[dict]): Prediction results.
palette (list[list[int]] | np.ndarray | None): The palette of
segmentation map. If None is given, random palette will be
generated. Default: None
out_dir (str): Output directory of visualization result.
ignore_index (int, optional): The label index to be ignored, e.g.
unannotated points. If None is given, set to len(self.CLASSES).
Defaults to None.
"""
assert out_dir is not None, 'Expect out_dir, got none.'
if palette is None:
if self.PALETTE is None:
palette = np.random.randint(
0, 255, size=(len(self.CLASSES), 3))
else:
palette = self.PALETTE
palette = np.array(palette)
for batch_id in range(len(result)):
if isinstance(data['points'][0], DC):
points = data['points'][0]._data[0][batch_id].numpy()
elif mmcv.is_list_of(data['points'][0], torch.Tensor):
points = data['points'][0][batch_id]
else:
ValueError(f"Unsupported data type {type(data['points'][0])} "
f'for visualization!')
if isinstance(data['img_metas'][0], DC):
pts_filename = data['img_metas'][0]._data[0][batch_id][
'pts_filename']
elif mmcv.is_list_of(data['img_metas'][0], dict):
pts_filename = data['img_metas'][0][batch_id]['pts_filename']
else:
raise ValueError(
f"Unsupported data type {type(data['img_metas'][0])} "
f'for visualization!')
file_name = osp.split(pts_filename)[-1].split('.')[0]
pred_sem_mask = result[batch_id]['semantic_mask'].cpu().numpy()
show_seg_result(points, None, pred_sem_mask, out_dir, file_name,
palette, ignore_index)
import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F
from mmseg.core import add_prefix
from mmseg.models import SEGMENTORS
from ..builder import build_backbone, build_head, build_neck
from .base import Base3DSegmentor
@SEGMENTORS.register_module()
class EncoderDecoder3D(Base3DSegmentor):
"""3D Encoder Decoder segmentors.
EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
Note that auxiliary_head is only used for deep supervision during training,
and can be discarded during inference.
"""
def __init__(self,
backbone,
decode_head,
neck=None,
auxiliary_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(EncoderDecoder3D, self).__init__()
self.backbone = build_backbone(backbone)
if neck is not None:
self.neck = build_neck(neck)
self._init_decode_head(decode_head)
self._init_auxiliary_head(auxiliary_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
assert self.with_decode_head, \
'3D EncoderDecoder Segmentor should have a decode_head'
def _init_decode_head(self, decode_head):
"""Initialize ``decode_head``"""
self.decode_head = build_head(decode_head)
self.num_classes = self.decode_head.num_classes
def _init_auxiliary_head(self, auxiliary_head):
"""Initialize ``auxiliary_head``"""
if auxiliary_head is not None:
if isinstance(auxiliary_head, list):
self.auxiliary_head = nn.ModuleList()
for head_cfg in auxiliary_head:
self.auxiliary_head.append(build_head(head_cfg))
else:
self.auxiliary_head = build_head(auxiliary_head)
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone and heads.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
super(EncoderDecoder3D, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
self.decode_head.init_weights()
if self.with_auxiliary_head:
if isinstance(self.auxiliary_head, nn.ModuleList):
for aux_head in self.auxiliary_head:
aux_head.init_weights()
else:
self.auxiliary_head.init_weights()
def extract_feat(self, points):
"""Extract features from points."""
x = self.backbone(points)
if self.with_neck:
x = self.neck(x)
return x
def encode_decode(self, points, img_metas):
"""Encode points with backbone and decode into a semantic segmentation
map of the same size as input.
Args:
points (torch.Tensor): Input points of shape [B, N, 3+C].
img_metas (list[dict]): Meta information of each sample.
Returns:
torch.Tensor: Segmentation logits of shape [B, num_classes, N].
"""
x = self.extract_feat(points)
out = self._decode_head_forward_test(x, img_metas)
return out
def _decode_head_forward_train(self, x, img_metas, pts_semantic_mask):
"""Run forward function and calculate loss for decode head in
training."""
losses = dict()
loss_decode = self.decode_head.forward_train(x, img_metas,
pts_semantic_mask,
self.train_cfg)
losses.update(add_prefix(loss_decode, 'decode'))
return losses
def _decode_head_forward_test(self, x, img_metas):
"""Run forward function and calculate loss for decode head in
inference."""
seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg)
return seg_logits
def _auxiliary_head_forward_train(self, x, img_metas, pts_semantic_mask):
"""Run forward function and calculate loss for auxiliary head in
training."""
losses = dict()
if isinstance(self.auxiliary_head, nn.ModuleList):
for idx, aux_head in enumerate(self.auxiliary_head):
loss_aux = aux_head.forward_train(x, img_metas,
pts_semantic_mask,
self.train_cfg)
losses.update(add_prefix(loss_aux, f'aux_{idx}'))
else:
loss_aux = self.auxiliary_head.forward_train(
x, img_metas, pts_semantic_mask, self.train_cfg)
losses.update(add_prefix(loss_aux, 'aux'))
return losses
def forward_dummy(self, points):
"""Dummy forward function."""
seg_logit = self.encode_decode(points, None)
return seg_logit
def forward_train(self, points, img_metas, pts_semantic_mask):
"""Forward function for training.
Args:
points (list[torch.Tensor]): List of points of shape [N, C].
img_metas (list): Image metas.
pts_semantic_mask (list[torch.Tensor]): List of point-wise semantic
labels of shape [N].
Returns:
dict[str, Tensor]: Losses.
"""
points_cat = torch.stack(points)
pts_semantic_mask_cat = torch.stack(pts_semantic_mask)
# extract features using backbone
x = self.extract_feat(points_cat)
losses = dict()
loss_decode = self._decode_head_forward_train(x, img_metas,
pts_semantic_mask_cat)
losses.update(loss_decode)
if self.with_auxiliary_head:
loss_aux = self._auxiliary_head_forward_train(
x, img_metas, pts_semantic_mask_cat)
losses.update(loss_aux)
return losses
@staticmethod
def _input_generation(coords,
patch_center,
coord_max,
feats,
use_normalized_coord=False):
"""Generating model input.
Generate input by subtracting patch center and adding additional \
features. Currently supports colors and normalized xyz as features.
Args:
coords (torch.Tensor): Sampled 3D point coordinate of shape [S, 3].
patch_center (torch.Tensor): Center coordinate of the patch.
coord_max (torch.Tensor): Max coordinate of all 3D points.
feats (torch.Tensor): Features of sampled points of shape [S, C].
use_normalized_coord (bool, optional): Whether to use normalized \
xyz as additional features. Defaults to False.
Returns:
torch.Tensor: The generated input data of shape [S, 3+C'].
"""
# subtract patch center, the z dimension is not centered
centered_coords = coords.clone()
centered_coords[:, 0] -= patch_center[0]
centered_coords[:, 1] -= patch_center[1]
# normalized coordinates as extra features
if use_normalized_coord:
normalized_coord = coords / coord_max
feats = torch.cat([feats, normalized_coord], dim=1)
points = torch.cat([centered_coords, feats], dim=1)
return points
def _sliding_patch_generation(self,
points,
num_points,
block_size,
sample_rate=0.5,
use_normalized_coord=False,
eps=1e-3):
"""Sampling points in a sliding window fashion.
First sample patches to cover all the input points.
Then sample a fixed number of points in each patch to form batches.
Args:
points (torch.Tensor): Input points of shape [N, 3+C].
num_points (int): Number of points to be sampled in each patch.
block_size (float): Size of a patch to sample.
sample_rate (float, optional): Stride of the sliding patch, as a ratio of
block_size.
Defaults to 0.5.
use_normalized_coord (bool, optional): Whether to use normalized \
xyz as additional features. Defaults to False.
eps (float, optional): A value added to patch boundary to guarantee
points coverage. Default 1e-3.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- patch_points (torch.Tensor): Points of different patches of \
shape [K, N, 3+C].
- patch_idxs (torch.Tensor): Index of each point in \
`patch_points`, of shape [K, N].
"""
device = points.device
# we assume the first three dims are points' 3D coordinates
# and the rest dims are their per-point features
coords = points[:, :3]
feats = points[:, 3:]
coord_max = coords.max(0)[0]
coord_min = coords.min(0)[0]
stride = block_size * sample_rate
num_grid_x = int(
torch.ceil((coord_max[0] - coord_min[0] - block_size) /
stride).item() + 1)
num_grid_y = int(
torch.ceil((coord_max[1] - coord_min[1] - block_size) /
stride).item() + 1)
patch_points, patch_idxs = [], []
for idx_y in range(num_grid_y):
s_y = coord_min[1] + idx_y * stride
e_y = torch.min(s_y + block_size, coord_max[1])
s_y = e_y - block_size
for idx_x in range(num_grid_x):
s_x = coord_min[0] + idx_x * stride
e_x = torch.min(s_x + block_size, coord_max[0])
s_x = e_x - block_size
# extract points within this patch
cur_min = torch.tensor([s_x, s_y, coord_min[2]]).to(device)
cur_max = torch.tensor([e_x, e_y, coord_max[2]]).to(device)
cur_choice = ((coords >= cur_min - eps) &
(coords <= cur_max + eps)).all(dim=1)
if not cur_choice.any(): # no points in this patch
continue
# sample points in this patch to multiple batches
cur_center = cur_min + block_size / 2.0
point_idxs = torch.nonzero(cur_choice, as_tuple=True)[0]
num_batch = int(np.ceil(point_idxs.shape[0] / num_points))
point_size = int(num_batch * num_points)
replace = point_size > 2 * point_idxs.shape[0]
num_repeat = point_size - point_idxs.shape[0]
if replace: # duplicate
point_idxs_repeat = point_idxs[torch.randint(
0, point_idxs.shape[0],
size=(num_repeat, )).to(device)]
else:
point_idxs_repeat = point_idxs[torch.randperm(
point_idxs.shape[0])[:num_repeat]]
choices = torch.cat([point_idxs, point_idxs_repeat], dim=0)
choices = choices[torch.randperm(choices.shape[0])]
# construct model input
point_batches = self._input_generation(
coords[choices],
cur_center,
coord_max,
feats[choices],
use_normalized_coord=use_normalized_coord)
patch_points.append(point_batches)
patch_idxs.append(choices)
patch_points = torch.cat(patch_points, dim=0)
patch_idxs = torch.cat(patch_idxs, dim=0)
# make sure all points are sampled at least once
assert torch.unique(patch_idxs).shape[0] == points.shape[0], \
'some points are not sampled in sliding inference'
return patch_points, patch_idxs
def slide_inference(self, point, img_meta, rescale):
"""Inference by sliding-window with overlap.
Args:
point (torch.Tensor): Input points of shape [N, 3+C].
img_meta (dict): Meta information of input sample.
rescale (bool): Whether to transform to the original number of points.
Will be used for voxelization-based segmentors.
Returns:
Tensor: The output segmentation map of shape [num_classes, N].
"""
num_points = self.test_cfg.num_points
block_size = self.test_cfg.block_size
sample_rate = self.test_cfg.sample_rate
use_normalized_coord = self.test_cfg.use_normalized_coord
batch_size = self.test_cfg.batch_size * num_points
# patch_points is of shape [K*N, 3+C], patch_idxs is of shape [K*N]
patch_points, patch_idxs = self._sliding_patch_generation(
point, num_points, block_size, sample_rate, use_normalized_coord)
feats_dim = patch_points.shape[1]
seg_logits = [] # save patch predictions
for batch_idx in range(0, patch_points.shape[0], batch_size):
batch_points = patch_points[batch_idx:batch_idx + batch_size]
batch_points = batch_points.view(-1, num_points, feats_dim)
# batch_seg_logit is of shape [B, num_classes, N]
batch_seg_logit = self.encode_decode(batch_points, img_meta)
batch_seg_logit = batch_seg_logit.transpose(1, 2).contiguous()
seg_logits.append(batch_seg_logit.view(-1, self.num_classes))
# aggregate per-point logits: sum over patches, then divide by per-point counts
seg_logits = torch.cat(seg_logits, dim=0) # [K*N, num_classes]
expand_patch_idxs = patch_idxs.unsqueeze(1).repeat(1, self.num_classes)
preds = point.new_zeros((point.shape[0], self.num_classes)).\
scatter_add_(dim=0, index=expand_patch_idxs, src=seg_logits)
count_mat = torch.bincount(patch_idxs)
preds = preds / count_mat[:, None]
# TODO: if rescale and voxelization segmentor
return preds.transpose(0, 1)  # to [num_classes, N]
def whole_inference(self, points, img_metas, rescale):
"""Inference with full scene (one forward pass without sliding)."""
seg_logit = self.encode_decode(points, img_metas)
# TODO: if rescale and voxelization segmentor
return seg_logit
def inference(self, points, img_metas, rescale):
"""Inference with slide/whole style.
Args:
points (torch.Tensor): Input points of shape [B, N, 3+C].
img_metas (list[dict]): Meta information of each sample.
rescale (bool): Whether to transform to the original number of points.
Will be used for voxelization-based segmentors.
Returns:
Tensor: The output segmentation map.
"""
assert self.test_cfg.mode in ['slide', 'whole']
if self.test_cfg.mode == 'slide':
seg_logit = torch.stack([
self.slide_inference(point, img_meta, rescale)
for point, img_meta in zip(points, img_metas)
], 0)
else:
seg_logit = self.whole_inference(points, img_metas, rescale)
output = F.softmax(seg_logit, dim=1)
return output
def simple_test(self, points, img_metas, rescale=True):
"""Simple test with single scene.
Args:
points (list[torch.Tensor]): List of points of shape [N, 3+C].
img_metas (list[dict]): Meta information of each sample.
rescale (bool): Whether to transform to the original number of points.
Will be used for voxelization-based segmentors.
Defaults to True.
Returns:
list[dict]: The output prediction result with following keys:
- semantic_mask (Tensor): Segmentation mask of shape [N].
"""
# 3D segmentation requires per-point prediction, so it's impossible
# to use down-sampling to get a batch of scenes with the same num_points.
# Therefore, we only support testing one scene at a time.
seg_pred = []
for point, img_meta in zip(points, img_metas):
seg_prob = self.inference(point.unsqueeze(0), [img_meta],
rescale)[0]
seg_map = seg_prob.argmax(0) # [N]
# to cpu tensor for consistency with det3d
seg_map = seg_map.cpu()
seg_pred.append(seg_map)
# wrap in dict
seg_pred = [dict(semantic_mask=seg_map) for seg_map in seg_pred]
return seg_pred
def aug_test(self, points, img_metas, rescale=True):
"""Test with augmentations.
Args:
points (list[torch.Tensor]): List of points of shape [B, N, 3+C].
img_metas (list[list[dict]]): Meta information of each sample.
The outer list indicates different samples while the inner list
indicates different augmentations.
rescale (bool): Whether to transform to the original number of points.
Will be used for voxelization-based segmentors.
Defaults to True.
Returns:
list[dict]: The output prediction result with following keys:
- semantic_mask (Tensor): Segmentation mask of shape [N].
"""
# in aug_test, different augmentations of one scene have the same number
# of points, so they are stacked as a batch
# to save memory, we aggregate the augmented seg logits scene by scene
seg_pred = []
for point, img_meta in zip(points, img_metas):
seg_prob = self.inference(point, img_meta, rescale)
seg_prob = seg_prob.mean(0) # [num_classes, N]
seg_map = seg_prob.argmax(0) # [N]
# to cpu tensor for consistency with det3d
seg_map = seg_map.cpu()
seg_pred.append(seg_map)
# wrap in dict
seg_pred = [dict(semantic_mask=seg_map) for seg_map in seg_pred]
return seg_pred
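To get a feel for the cost of slide_inference above, a hedged back-of-the-envelope using the ScanNet test_cfg values (block_size=1.5, sample_rate=0.5) and an assumed scene extent of 6 m x 4.5 m; the grid arithmetic mirrors _sliding_patch_generation:

import math

block_size, sample_rate = 1.5, 0.5
stride = block_size * sample_rate  # 0.75

extent_x, extent_y = 6.0, 4.5  # hypothetical scene extents in metres
num_grid_x = math.ceil((extent_x - block_size) / stride) + 1  # 7
num_grid_y = math.ceil((extent_y - block_size) / stride) + 1  # 5
print(num_grid_x * num_grid_y)  # at most 35 patches; empty patches are skipped

Each non-empty patch is padded to a multiple of num_points (8192 here) and the patches are then run through encode_decode in chunks of test_cfg.batch_size.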
......@@ -82,12 +82,14 @@ class Points_Sampler(nn.Module):
if fps_sample_range == -1:
sample_points_xyz = points_xyz[:, last_fps_end_index:]
sample_features = features[:, :, last_fps_end_index:]
sample_features = features[:, :, last_fps_end_index:] if \
features is not None else None
else:
sample_points_xyz = \
points_xyz[:, last_fps_end_index:fps_sample_range]
sample_features = \
features[:, :, last_fps_end_index:fps_sample_range]
features[:, :, last_fps_end_index:fps_sample_range] if \
features is not None else None
fps_idx = sampler(sample_points_xyz.contiguous(), sample_features,
npoint)
......@@ -125,6 +127,8 @@ class FFPS_Sampler(nn.Module):
def forward(self, points, features, npoint):
"""Sampling points with F-FPS."""
assert features is not None, \
'feature input to FFPS_Sampler should not be None'
features_for_fps = torch.cat([points, features.transpose(1, 2)], dim=2)
features_dist = calc_square_dist(
features_for_fps, features_for_fps, norm=False)
......@@ -143,6 +147,8 @@ class FS_Sampler(nn.Module):
def forward(self, points, features, npoint):
"""Sampling points with FS_Sampling."""
assert features is not None, \
'feature input to FS_Sampler should not be None'
features_for_fps = torch.cat([points, features.transpose(1, 2)], dim=2)
features_dist = calc_square_dist(
features_for_fps, features_for_fps, norm=False)
......
from mmcv.utils import Registry, build_from_cfg, print_log
from mmdet.utils import get_root_logger
from .collect_env import collect_env
from .logger import get_root_logger
__all__ = [
'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env', 'print_log'
......
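The last hunk switches get_root_logger from mmdet.utils to the new local mmdet3d logger module, so training logs are emitted under the mmdet3d name. A minimal hedged usage sketch (assuming the usual mmcv-style signature of get_root_logger):

from mmdet3d.utils import get_root_logger

logger = get_root_logger(log_file='work_dirs/example.log')  # hypothetical log path
logger.info('messages now go through the mmdet3d root logger')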