Unverified Commit 9d852f17 authored by Ziyi Wu's avatar Ziyi Wu Committed by GitHub
Browse files

[Feature] Support PointNet++ Segmentor (#528)

* build BaseSegmentor for point sem seg

* add encoder-decoder segmentor

* update mmseg dependency

* fix linting errors

* wrap predicted seg_mask in dict

* add unit test

* use build_model to wrap detector and segmentor

* fix compatibility with mmseg

* faster sliding inference

* merge master

* configs for training on ScanNet

* fix CI errors

* add comments & fix typos

* hard-code class_weight into configs

* fix logger bugs

* update segmentor unit test

* logger use mmdet3d

* use eps to replace hard-coded 1e-3

* add comments

* replace np operation with torch code

* add comments for class_weight

* add comment for BaseSegmentor.simple_test

* rewrite EncoderDecoder3D to avoid inheriting from mmseg
parent 43d79534
import logging
from mmcv.utils import get_logger
def get_root_logger(log_file=None, log_level=logging.INFO, name='mmdet3d'):
    """Get root logger and add a name filter to it.

    The logger will be initialized if it has not been initialized. By default a
    StreamHandler will be added. If `log_file` is specified, a FileHandler will
    also be added. The name of the root logger is the top-level package name,
    e.g., "mmdet3d".

    Args:
        log_file (str, optional): File path of log. Defaults to None.
        log_level (int, optional): The level of logger.
            Defaults to logging.INFO.
        name (str, optional): The name of the root logger, also used as the
            filter name. Defaults to 'mmdet3d'.

    Returns:
        :obj:`logging.Logger`: The obtained logger
    """
    logger = get_logger(name=name, log_file=log_file, log_level=log_level)

    # Attach a standard name filter: it lets through records whose logger name
    # is `name` or a descendant (`name.*`). Previously a filter object was
    # created but never attached, and its patched callback called
    # `record.find`, which `logging.LogRecord` does not provide.
    # Guard against stacking a new filter on every call for the same name.
    already_filtered = any(
        isinstance(existing, logging.Filter) and existing.name == name
        for existing in logger.filters)
    if not already_filtered:
        logger.addFilter(logging.Filter(name))

    return logger
......@@ -58,6 +58,20 @@ def test_pointnet2_sa_ssg():
assert sa_indices[1].shape == torch.Size([1, 32])
assert sa_indices[2].shape == torch.Size([1, 16])
# test only xyz input without features
cfg['in_channels'] = 3
self = build_backbone(cfg)
self.cuda()
ret_dict = self(xyz[..., :3])
assert len(fp_xyz) == len(fp_features) == len(fp_indices) == 3
assert len(sa_xyz) == len(sa_features) == len(sa_indices) == 3
assert fp_features[0].shape == torch.Size([1, 16, 16])
assert fp_features[1].shape == torch.Size([1, 16, 32])
assert fp_features[2].shape == torch.Size([1, 16, 100])
assert sa_features[0].shape == torch.Size([1, 3, 100])
assert sa_features[1].shape == torch.Size([1, 16, 32])
assert sa_features[2].shape == torch.Size([1, 16, 16])
def test_multi_backbone():
if not torch.cuda.is_available():
......
......@@ -61,12 +61,12 @@ def test_pn2_decode_head_loss():
assert seg_logits.shape == torch.Size([2, 20, 4096])
# test loss
gt_semantic_seg = torch.randint(0, 20, (2, 4096)).long().cuda()
losses = self.losses(seg_logits, gt_semantic_seg)
pts_semantic_mask = torch.randint(0, 20, (2, 4096)).long().cuda()
losses = self.losses(seg_logits, pts_semantic_mask)
assert losses['loss_sem_seg'].item() > 0
# test loss with ignore_index
ignore_index_mask = torch.ones_like(gt_semantic_seg) * 20
ignore_index_mask = torch.ones_like(pts_semantic_mask) * 20
losses = self.losses(seg_logits, ignore_index_mask)
assert losses['loss_sem_seg'].item() == 0
......@@ -78,5 +78,5 @@ def test_pn2_decode_head_loss():
loss_weight=1.0)
self = build_head(pn2_decode_head_cfg)
self.cuda()
losses = self.losses(seg_logits, gt_semantic_seg)
losses = self.losses(seg_logits, pts_semantic_mask)
assert losses['loss_sem_seg'].item() > 0
import copy
import numpy as np
import pytest
import torch
from os.path import dirname, exists, join
from mmdet3d.models.builder import build_segmentor
from mmdet.apis import set_random_seed
def _get_config_directory():
"""Find the predefined detector config directory."""
try:
# Assume we are running in the source mmdetection3d repo
repo_dpath = dirname(dirname(dirname(__file__)))
except NameError:
# For IPython development when this __file__ is not defined
import mmdet3d
repo_dpath = dirname(dirname(mmdet3d.__file__))
config_dpath = join(repo_dpath, 'configs')
if not exists(config_dpath):
raise Exception('Cannot find config path')
return config_dpath
def _get_config_module(fname):
    """Load a config file from the config directory.

    Args:
        fname (str): Path of the config file, relative to the directory
            returned by ``_get_config_directory``.

    Returns:
        The object produced by ``mmcv.Config.fromfile`` for that file.
    """
    from mmcv import Config
    return Config.fromfile(join(_get_config_directory(), fname))
def _get_segmentor_cfg(fname):
    """Build an independent model config for creating a segmentor.

    Everything returned is a deep copy of what the config file defines, so a
    test may modify it freely without affecting other tests.

    Args:
        fname (str): Config file path, relative to the config directory.

    Returns:
        A deep copy of ``config.model`` whose ``train_cfg`` and ``test_cfg``
        entries are replaced by ``mmcv.Config`` wrappers around deep copies
        of the original entries.
    """
    import mmcv
    loaded = _get_config_module(fname)
    segmentor_cfg = copy.deepcopy(loaded.model)
    # Wrap the train/test settings so they are standalone Config objects
    segmentor_cfg.update(
        train_cfg=mmcv.Config(copy.deepcopy(loaded.model.train_cfg)),
        test_cfg=mmcv.Config(copy.deepcopy(loaded.model.test_cfg)))
    return segmentor_cfg
def test_pointnet2_ssg():
    """Exercise the segmentor built from the PointNet++ SSG ScanNet config.

    Covers ``forward_train``, ``forward`` in both loss and inference modes,
    ``simple_test`` and ``aug_test``. Skipped when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        pytest.skip('test requires GPU and torch+cuda')

    def _check_mask_shapes(outputs):
        # One predicted label per input point of each of the two scenes
        assert outputs[0]['semantic_mask'].shape == torch.Size([500])
        assert outputs[1]['semantic_mask'].shape == torch.Size([200])

    set_random_seed(0, True)
    cfg = _get_segmentor_cfg(
        'pointnet2/pointnet2_ssg_16x2_scannet-3d-20class.py')
    cfg.test_cfg.num_points = 32
    segmentor = build_segmentor(cfg).cuda()

    # Two random training samples with 1024 six-channel points each
    points = [torch.rand(1024, 6).float().cuda() for _ in range(2)]
    img_metas = [dict(), dict()]
    gt_masks = [
        torch.randint(0, 20, (1024, )).long().cuda() for _ in range(2)
    ]

    # forward_train returns a non-negative decode loss
    losses = segmentor.forward_train(points, img_metas, gt_masks)
    assert losses['decode.loss_sem_seg'].item() >= 0

    # forward(return_loss=True) must reproduce the forward_train loss once
    # the random seed is reset
    set_random_seed(0, True)
    train_inputs = dict(
        points=points, img_metas=img_metas, pts_semantic_mask=gt_masks)
    forward_losses = segmentor.forward(return_loss=True, **train_inputs)
    assert np.allclose(losses['decode.loss_sem_seg'].item(),
                       forward_losses['decode.loss_sem_seg'].item())

    # masks filled entirely with label 20 (presumably the config's
    # ignore_index -- confirm against the config) must give zero loss
    ignore_masks = [torch.ones_like(gt_masks[0]) * 20 for _ in range(2)]
    losses = segmentor.forward_train(points, img_metas, ignore_masks)
    assert losses['decode.loss_sem_seg'].item() == 0

    # simple_test on two scenes of different sizes
    segmentor.eval()
    with torch.no_grad():
        scene_points = [
            torch.randn(500, 6).float().cuda() * 3.0,
            torch.randn(200, 6).float().cuda() * 2.5
        ]
        results = segmentor.simple_test(scene_points, img_metas)
        _check_mask_shapes(results)

    # forward(return_loss=False) with singly wrapped inputs
    with torch.no_grad():
        test_inputs = dict(points=[scene_points], img_metas=[img_metas])
        results = segmentor.forward(return_loss=False, **test_inputs)
        _check_mask_shapes(results)

    # aug_test: each scene tensor carries a leading dimension of size 2
    with torch.no_grad():
        scene_points = [
            torch.randn(2, 500, 6).float().cuda() * 3.0,
            torch.randn(2, 200, 6).float().cuda() * 2.5
        ]
        img_metas = [[dict(), dict()], [dict(), dict()]]
        results = segmentor.aug_test(scene_points, img_metas)
        _check_mask_shapes(results)

    # forward(return_loss=False) with the aug_test-style inputs
    with torch.no_grad():
        aug_inputs = dict(points=scene_points, img_metas=img_metas)
        results = segmentor.forward(return_loss=False, **aug_inputs)
        _check_mask_shapes(results)
def test_pointnet2_msg():
    """Exercise the segmentor built from the PointNet++ MSG ScanNet config.

    Covers ``forward_train``, ``simple_test`` and ``aug_test``. Skipped when
    CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        pytest.skip('test requires GPU and torch+cuda')

    def _check_mask_shapes(outputs):
        # One predicted label per input point of each of the two scenes
        assert outputs[0]['semantic_mask'].shape == torch.Size([500])
        assert outputs[1]['semantic_mask'].shape == torch.Size([200])

    set_random_seed(0, True)
    cfg = _get_segmentor_cfg(
        'pointnet2/pointnet2_msg_16x2_scannet-3d-20class.py')
    cfg.test_cfg.num_points = 32
    segmentor = build_segmentor(cfg).cuda()

    # Two random training samples with 1024 six-channel points each
    points = [torch.rand(1024, 6).float().cuda() for _ in range(2)]
    img_metas = [dict(), dict()]
    gt_masks = [
        torch.randint(0, 20, (1024, )).long().cuda() for _ in range(2)
    ]

    # forward_train returns a non-negative decode loss
    losses = segmentor.forward_train(points, img_metas, gt_masks)
    assert losses['decode.loss_sem_seg'].item() >= 0

    # masks filled entirely with label 20 (presumably the config's
    # ignore_index -- confirm against the config) must give zero loss
    ignore_masks = [torch.ones_like(gt_masks[0]) * 20 for _ in range(2)]
    losses = segmentor.forward_train(points, img_metas, ignore_masks)
    assert losses['decode.loss_sem_seg'].item() == 0

    # simple_test on two scenes of different sizes
    segmentor.eval()
    with torch.no_grad():
        scene_points = [
            torch.randn(500, 6).float().cuda() * 3.0,
            torch.randn(200, 6).float().cuda() * 2.5
        ]
        results = segmentor.simple_test(scene_points, img_metas)
        _check_mask_shapes(results)

    # aug_test: each scene tensor carries a leading dimension of size 2
    with torch.no_grad():
        scene_points = [
            torch.randn(2, 500, 6).float().cuda() * 3.0,
            torch.randn(2, 200, 6).float().cuda() * 2.5
        ]
        img_metas = [[dict(), dict()], [dict(), dict()]]
        results = segmentor.aug_test(scene_points, img_metas)
        _check_mask_shapes(results)
......@@ -16,12 +16,12 @@ def _get_config_directory():
return config_dpath
def test_config_build_detector():
def test_config_build_model():
"""Test that all detection models defined in the configs can be
initialized."""
from mmcv import Config
from mmdet3d.models import build_detector
from mmdet3d.models import build_model
config_dpath = _get_config_directory()
print('Found config_dpath = {!r}'.format(config_dpath))
......@@ -46,7 +46,7 @@ def test_config_build_detector():
if 'pretrained' in config_mod.model:
config_mod.model['pretrained'] = None
detector = build_detector(config_mod.model)
detector = build_model(config_mod.model)
assert detector is not None
if 'roi_head' in config_mod.model.keys():
......
......@@ -11,7 +11,7 @@ from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
from mmdet3d.apis import single_gpu_test
from mmdet3d.datasets import build_dataloader, build_dataset
from mmdet3d.models import build_detector
from mmdet3d.models import build_model
from mmdet.apis import multi_gpu_test, set_random_seed
from mmdet.datasets import replace_ImageToTensor
......@@ -165,7 +165,7 @@ def main():
# build the model and load checkpoint
cfg.model.train_cfg = None
model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
model = build_model(cfg.model, test_cfg=cfg.get('test_cfg'))
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
wrap_fp16_model(model)
......
......@@ -2,7 +2,6 @@ from __future__ import division
import argparse
import copy
import logging
import mmcv
import os
import time
......@@ -12,11 +11,14 @@ from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from os import path as osp
from mmdet3d import __version__
from mmdet import __version__ as mmdet_version
from mmdet3d import __version__ as mmdet3d_version
from mmdet3d.apis import train_model
from mmdet3d.datasets import build_dataset
from mmdet3d.models import build_detector
from mmdet3d.models import build_model
from mmdet3d.utils import collect_env, get_root_logger
from mmdet.apis import set_random_seed, train_detector
from mmdet.apis import set_random_seed
from mmseg import __version__ as mmseg_version
def parse_args():
......@@ -139,11 +141,15 @@ def main():
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
# add a logging filter
logging_filter = logging.Filter('mmdet')
logging_filter.filter = lambda record: record.find('mmdet') != -1
# specify logger name, if we still use 'mmdet', the output info will be
# filtered and won't be saved in the log_file
# TODO: ugly workaround to judge whether we are training det or seg model
if cfg.model.type in ['EncoderDecoder3D']:
logger_name = 'mmseg'
else:
logger_name = 'mmdet'
logger = get_root_logger(
log_file=log_file, log_level=cfg.log_level, name=logger_name)
# init the meta dict to record some important information such as
# environment info and seed, which will be logged
......@@ -170,7 +176,7 @@ def main():
meta['seed'] = args.seed
meta['exp_name'] = osp.basename(args.config)
model = build_detector(
model = build_model(
cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg'))
......@@ -193,12 +199,16 @@ def main():
# save mmdet version, config file content and class names in
# checkpoints as meta data
cfg.checkpoint_config.meta = dict(
mmdet_version=__version__,
mmdet_version=mmdet_version,
mmseg_version=mmseg_version,
mmdet3d_version=mmdet3d_version,
config=cfg.pretty_text,
CLASSES=datasets[0].CLASSES)
CLASSES=datasets[0].CLASSES,
PALETTE=datasets[0].PALETTE # for segmentors
if hasattr(datasets[0], 'PALETTE') else None)
# add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES
train_detector(
train_model(
model,
datasets,
cfg,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment