Unverified Commit 9d852f17 authored by Ziyi Wu's avatar Ziyi Wu Committed by GitHub
Browse files

[Feature] Support PointNet++ Segmentor (#528)

* build BaseSegmentor for point sem seg

* add encoder-decoder segmentor

* update mmseg dependency

* fix linting errors

* wrap predicted seg_mask in dict

* add unit test

* use build_model to wrap detector and segmentor

* fix compatibility with mmseg

* faster sliding inference

* merge master

* configs for training on ScanNet

* fix CI errors

* add comments & fix typos

* hard-code class_weight into configs

* fix logger bugs

* update segmentor unit test

* logger use mmdet3d

* use eps to replace hard-coded 1e-3

* add comments

* replace np operation with torch code

* add comments for class_weight

* add comment for BaseSegmentor.simple_test

* rewrite EncoderDecoder3D to avoid inheriting from mmseg
parent 43d79534
import logging
from mmcv.utils import get_logger
def get_root_logger(log_file=None, log_level=logging.INFO, name='mmdet3d'):
    """Get root logger and add a keyword filter to it.

    The logger will be initialized if it has not been initialized. By default
    a StreamHandler will be added. If `log_file` is specified, a FileHandler
    will also be added. The name of the root logger is the top-level package
    name, e.g., "mmdet3d".

    Args:
        log_file (str, optional): File path of log. Defaults to None.
        log_level (int, optional): The level of logger.
            Defaults to logging.INFO.
        name (str, optional): The name of the root logger, also used as a
            filter keyword. Defaults to 'mmdet3d'.

    Returns:
        :obj:`logging.Logger`: The obtained logger.
    """
    logger = get_logger(name=name, log_file=log_file, log_level=log_level)
    # Add a keyword filter so only records whose logger name contains `name`
    # pass through. The original code called ``record.find(name)`` (a
    # LogRecord has no ``find`` method, which would raise AttributeError)
    # and never attached the filter to the logger, making it dead code.
    logging_filter = logging.Filter(name)
    logging_filter.filter = lambda record: record.name.find(name) != -1
    logger.addFilter(logging_filter)
    return logger
...@@ -58,6 +58,20 @@ def test_pointnet2_sa_ssg(): ...@@ -58,6 +58,20 @@ def test_pointnet2_sa_ssg():
assert sa_indices[1].shape == torch.Size([1, 32]) assert sa_indices[1].shape == torch.Size([1, 32])
assert sa_indices[2].shape == torch.Size([1, 16]) assert sa_indices[2].shape == torch.Size([1, 16])
# test only xyz input without features
cfg['in_channels'] = 3
self = build_backbone(cfg)
self.cuda()
ret_dict = self(xyz[..., :3])
assert len(fp_xyz) == len(fp_features) == len(fp_indices) == 3
assert len(sa_xyz) == len(sa_features) == len(sa_indices) == 3
assert fp_features[0].shape == torch.Size([1, 16, 16])
assert fp_features[1].shape == torch.Size([1, 16, 32])
assert fp_features[2].shape == torch.Size([1, 16, 100])
assert sa_features[0].shape == torch.Size([1, 3, 100])
assert sa_features[1].shape == torch.Size([1, 16, 32])
assert sa_features[2].shape == torch.Size([1, 16, 16])
def test_multi_backbone(): def test_multi_backbone():
if not torch.cuda.is_available(): if not torch.cuda.is_available():
......
...@@ -61,12 +61,12 @@ def test_pn2_decode_head_loss(): ...@@ -61,12 +61,12 @@ def test_pn2_decode_head_loss():
assert seg_logits.shape == torch.Size([2, 20, 4096]) assert seg_logits.shape == torch.Size([2, 20, 4096])
# test loss # test loss
gt_semantic_seg = torch.randint(0, 20, (2, 4096)).long().cuda() pts_semantic_mask = torch.randint(0, 20, (2, 4096)).long().cuda()
losses = self.losses(seg_logits, gt_semantic_seg) losses = self.losses(seg_logits, pts_semantic_mask)
assert losses['loss_sem_seg'].item() > 0 assert losses['loss_sem_seg'].item() > 0
# test loss with ignore_index # test loss with ignore_index
ignore_index_mask = torch.ones_like(gt_semantic_seg) * 20 ignore_index_mask = torch.ones_like(pts_semantic_mask) * 20
losses = self.losses(seg_logits, ignore_index_mask) losses = self.losses(seg_logits, ignore_index_mask)
assert losses['loss_sem_seg'].item() == 0 assert losses['loss_sem_seg'].item() == 0
...@@ -78,5 +78,5 @@ def test_pn2_decode_head_loss(): ...@@ -78,5 +78,5 @@ def test_pn2_decode_head_loss():
loss_weight=1.0) loss_weight=1.0)
self = build_head(pn2_decode_head_cfg) self = build_head(pn2_decode_head_cfg)
self.cuda() self.cuda()
losses = self.losses(seg_logits, gt_semantic_seg) losses = self.losses(seg_logits, pts_semantic_mask)
assert losses['loss_sem_seg'].item() > 0 assert losses['loss_sem_seg'].item() > 0
import copy
import numpy as np
import pytest
import torch
from os.path import dirname, exists, join
from mmdet3d.models.builder import build_segmentor
from mmdet.apis import set_random_seed
def _get_config_directory():
    """Find the predefined detector config directory."""
    try:
        # When running from a source checkout, this file sits three
        # directory levels below the repository root.
        repo_root = dirname(dirname(dirname(__file__)))
    except NameError:
        # ``__file__`` may be undefined (e.g. in IPython sessions); fall
        # back to locating the installed mmdet3d package instead.
        import mmdet3d
        repo_root = dirname(dirname(mmdet3d.__file__))
    config_dir = join(repo_root, 'configs')
    if not exists(config_dir):
        raise Exception('Cannot find config path')
    return config_dir
def _get_config_module(fname):
    """Load a configuration file as an mmcv Config module."""
    from mmcv import Config
    cfg_fpath = join(_get_config_directory(), fname)
    return Config.fromfile(cfg_fpath)
def _get_segmentor_cfg(fname):
    """Grab configs necessary to create a segmentor.

    These are deep copied to allow for safe modification of parameters
    without influencing other tests.
    """
    import mmcv
    cfg = _get_config_module(fname)
    model_cfg = copy.deepcopy(cfg.model)
    # Wrap train/test cfgs in mmcv.Config and attach them to the model dict.
    model_cfg.update(
        train_cfg=mmcv.Config(copy.deepcopy(cfg.model.train_cfg)),
        test_cfg=mmcv.Config(copy.deepcopy(cfg.model.test_cfg)))
    return model_cfg
def test_pointnet2_ssg():
    """Smoke-test the PointNet++ (SSG) segmentor end to end.

    Covers forward_train, the forward() dispatcher, ignore-index loss,
    simple_test and aug_test. Requires CUDA; skipped otherwise.
    """
    if not torch.cuda.is_available():
        pytest.skip('test requires GPU and torch+cuda')
    set_random_seed(0, True)
    pn2_ssg_cfg = _get_segmentor_cfg(
        'pointnet2/pointnet2_ssg_16x2_scannet-3d-20class.py')
    # shrink the sliding-window sample size so the test runs fast
    pn2_ssg_cfg.test_cfg.num_points = 32
    self = build_segmentor(pn2_ssg_cfg).cuda()
    # two scenes of 1024 points, 6 channels (presumably xyz + color — TODO
    # confirm against the config's point feature layout)
    points = [torch.rand(1024, 6).float().cuda() for _ in range(2)]
    img_metas = [dict(), dict()]
    gt_masks = [torch.randint(0, 20, (1024, )).long().cuda() for _ in range(2)]
    # test forward_train
    losses = self.forward_train(points, img_metas, gt_masks)
    assert losses['decode.loss_sem_seg'].item() >= 0
    # test forward function: re-seed so forward() reproduces forward_train()
    set_random_seed(0, True)
    data_dict = dict(
        points=points, img_metas=img_metas, pts_semantic_mask=gt_masks)
    forward_losses = self.forward(return_loss=True, **data_dict)
    assert np.allclose(losses['decode.loss_sem_seg'].item(),
                       forward_losses['decode.loss_sem_seg'].item())
    # test loss with ignore_index: every label is 20 (the ignore index for a
    # 20-class config), so the loss must be exactly zero
    ignore_masks = [torch.ones_like(gt_masks[0]) * 20 for _ in range(2)]
    losses = self.forward_train(points, img_metas, ignore_masks)
    assert losses['decode.loss_sem_seg'].item() == 0
    # test simple_test
    self.eval()
    with torch.no_grad():
        scene_points = [
            torch.randn(500, 6).float().cuda() * 3.0,
            torch.randn(200, 6).float().cuda() * 2.5
        ]
        results = self.simple_test(scene_points, img_metas)
        # one semantic label per input point
        assert results[0]['semantic_mask'].shape == torch.Size([500])
        assert results[1]['semantic_mask'].shape == torch.Size([200])
    # test forward function calling simple_test (inputs wrapped in an extra
    # list, matching the batch format the dispatcher expects)
    with torch.no_grad():
        data_dict = dict(points=[scene_points], img_metas=[img_metas])
        results = self.forward(return_loss=False, **data_dict)
        assert results[0]['semantic_mask'].shape == torch.Size([500])
        assert results[1]['semantic_mask'].shape == torch.Size([200])
    # test aug_test: two augmented views per scene (leading dim of 2)
    with torch.no_grad():
        scene_points = [
            torch.randn(2, 500, 6).float().cuda() * 3.0,
            torch.randn(2, 200, 6).float().cuda() * 2.5
        ]
        img_metas = [[dict(), dict()], [dict(), dict()]]
        results = self.aug_test(scene_points, img_metas)
        # augmented predictions are merged back to one mask per scene
        assert results[0]['semantic_mask'].shape == torch.Size([500])
        assert results[1]['semantic_mask'].shape == torch.Size([200])
    # test forward function calling aug_test
    with torch.no_grad():
        data_dict = dict(points=scene_points, img_metas=img_metas)
        results = self.forward(return_loss=False, **data_dict)
        assert results[0]['semantic_mask'].shape == torch.Size([500])
        assert results[1]['semantic_mask'].shape == torch.Size([200])
def test_pointnet2_msg():
    """Smoke-test the PointNet++ (MSG) segmentor.

    Covers forward_train, ignore-index loss, simple_test and aug_test.
    Requires CUDA; skipped otherwise.
    """
    if not torch.cuda.is_available():
        pytest.skip('test requires GPU and torch+cuda')
    set_random_seed(0, True)
    pn2_msg_cfg = _get_segmentor_cfg(
        'pointnet2/pointnet2_msg_16x2_scannet-3d-20class.py')
    # shrink the sliding-window sample size so the test runs fast
    pn2_msg_cfg.test_cfg.num_points = 32
    self = build_segmentor(pn2_msg_cfg).cuda()
    # two scenes of 1024 points, 6 channels (presumably xyz + color — TODO
    # confirm against the config's point feature layout)
    points = [torch.rand(1024, 6).float().cuda() for _ in range(2)]
    img_metas = [dict(), dict()]
    gt_masks = [torch.randint(0, 20, (1024, )).long().cuda() for _ in range(2)]
    # test forward_train
    losses = self.forward_train(points, img_metas, gt_masks)
    assert losses['decode.loss_sem_seg'].item() >= 0
    # test loss with ignore_index: every label is 20 (the ignore index for a
    # 20-class config), so the loss must be exactly zero
    ignore_masks = [torch.ones_like(gt_masks[0]) * 20 for _ in range(2)]
    losses = self.forward_train(points, img_metas, ignore_masks)
    assert losses['decode.loss_sem_seg'].item() == 0
    # test simple_test
    self.eval()
    with torch.no_grad():
        scene_points = [
            torch.randn(500, 6).float().cuda() * 3.0,
            torch.randn(200, 6).float().cuda() * 2.5
        ]
        results = self.simple_test(scene_points, img_metas)
        # one semantic label per input point
        assert results[0]['semantic_mask'].shape == torch.Size([500])
        assert results[1]['semantic_mask'].shape == torch.Size([200])
    # test aug_test: two augmented views per scene (leading dim of 2)
    with torch.no_grad():
        scene_points = [
            torch.randn(2, 500, 6).float().cuda() * 3.0,
            torch.randn(2, 200, 6).float().cuda() * 2.5
        ]
        img_metas = [[dict(), dict()], [dict(), dict()]]
        results = self.aug_test(scene_points, img_metas)
        # augmented predictions are merged back to one mask per scene
        assert results[0]['semantic_mask'].shape == torch.Size([500])
        assert results[1]['semantic_mask'].shape == torch.Size([200])
...@@ -16,12 +16,12 @@ def _get_config_directory(): ...@@ -16,12 +16,12 @@ def _get_config_directory():
return config_dpath return config_dpath
def test_config_build_detector(): def test_config_build_model():
"""Test that all detection models defined in the configs can be """Test that all detection models defined in the configs can be
initialized.""" initialized."""
from mmcv import Config from mmcv import Config
from mmdet3d.models import build_detector from mmdet3d.models import build_model
config_dpath = _get_config_directory() config_dpath = _get_config_directory()
print('Found config_dpath = {!r}'.format(config_dpath)) print('Found config_dpath = {!r}'.format(config_dpath))
...@@ -46,7 +46,7 @@ def test_config_build_detector(): ...@@ -46,7 +46,7 @@ def test_config_build_detector():
if 'pretrained' in config_mod.model: if 'pretrained' in config_mod.model:
config_mod.model['pretrained'] = None config_mod.model['pretrained'] = None
detector = build_detector(config_mod.model) detector = build_model(config_mod.model)
assert detector is not None assert detector is not None
if 'roi_head' in config_mod.model.keys(): if 'roi_head' in config_mod.model.keys():
......
...@@ -11,7 +11,7 @@ from mmcv.runner import (get_dist_info, init_dist, load_checkpoint, ...@@ -11,7 +11,7 @@ from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
from mmdet3d.apis import single_gpu_test from mmdet3d.apis import single_gpu_test
from mmdet3d.datasets import build_dataloader, build_dataset from mmdet3d.datasets import build_dataloader, build_dataset
from mmdet3d.models import build_detector from mmdet3d.models import build_model
from mmdet.apis import multi_gpu_test, set_random_seed from mmdet.apis import multi_gpu_test, set_random_seed
from mmdet.datasets import replace_ImageToTensor from mmdet.datasets import replace_ImageToTensor
...@@ -165,7 +165,7 @@ def main(): ...@@ -165,7 +165,7 @@ def main():
# build the model and load checkpoint # build the model and load checkpoint
cfg.model.train_cfg = None cfg.model.train_cfg = None
model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) model = build_model(cfg.model, test_cfg=cfg.get('test_cfg'))
fp16_cfg = cfg.get('fp16', None) fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None: if fp16_cfg is not None:
wrap_fp16_model(model) wrap_fp16_model(model)
......
...@@ -2,7 +2,6 @@ from __future__ import division ...@@ -2,7 +2,6 @@ from __future__ import division
import argparse import argparse
import copy import copy
import logging
import mmcv import mmcv
import os import os
import time import time
...@@ -12,11 +11,14 @@ from mmcv import Config, DictAction ...@@ -12,11 +11,14 @@ from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist from mmcv.runner import get_dist_info, init_dist
from os import path as osp from os import path as osp
from mmdet3d import __version__ from mmdet import __version__ as mmdet_version
from mmdet3d import __version__ as mmdet3d_version
from mmdet3d.apis import train_model
from mmdet3d.datasets import build_dataset from mmdet3d.datasets import build_dataset
from mmdet3d.models import build_detector from mmdet3d.models import build_model
from mmdet3d.utils import collect_env, get_root_logger from mmdet3d.utils import collect_env, get_root_logger
from mmdet.apis import set_random_seed, train_detector from mmdet.apis import set_random_seed
from mmseg import __version__ as mmseg_version
def parse_args(): def parse_args():
...@@ -139,11 +141,15 @@ def main(): ...@@ -139,11 +141,15 @@ def main():
# init the logger before other steps # init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, f'{timestamp}.log') log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) # specify logger name, if we still use 'mmdet', the output info will be
# filtered and won't be saved in the log_file
# add a logging filter # TODO: ugly workaround to judge whether we are training det or seg model
logging_filter = logging.Filter('mmdet') if cfg.model.type in ['EncoderDecoder3D']:
logging_filter.filter = lambda record: record.find('mmdet') != -1 logger_name = 'mmseg'
else:
logger_name = 'mmdet'
logger = get_root_logger(
log_file=log_file, log_level=cfg.log_level, name=logger_name)
# init the meta dict to record some important information such as # init the meta dict to record some important information such as
# environment info and seed, which will be logged # environment info and seed, which will be logged
...@@ -170,7 +176,7 @@ def main(): ...@@ -170,7 +176,7 @@ def main():
meta['seed'] = args.seed meta['seed'] = args.seed
meta['exp_name'] = osp.basename(args.config) meta['exp_name'] = osp.basename(args.config)
model = build_detector( model = build_model(
cfg.model, cfg.model,
train_cfg=cfg.get('train_cfg'), train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg')) test_cfg=cfg.get('test_cfg'))
...@@ -193,12 +199,16 @@ def main(): ...@@ -193,12 +199,16 @@ def main():
# save mmdet version, config file content and class names in # save mmdet version, config file content and class names in
# checkpoints as meta data # checkpoints as meta data
cfg.checkpoint_config.meta = dict( cfg.checkpoint_config.meta = dict(
mmdet_version=__version__, mmdet_version=mmdet_version,
mmseg_version=mmseg_version,
mmdet3d_version=mmdet3d_version,
config=cfg.pretty_text, config=cfg.pretty_text,
CLASSES=datasets[0].CLASSES) CLASSES=datasets[0].CLASSES,
PALETTE=datasets[0].PALETTE # for segmentors
if hasattr(datasets[0], 'PALETTE') else None)
# add an attribute for visualization convenience # add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES model.CLASSES = datasets[0].CLASSES
train_detector( train_model(
model, model,
datasets, datasets,
cfg, cfg,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment