".github/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "fea2b7bd37421a6cd66757f04f10862125e4255c"
Commit 720caedf authored by dingchang's avatar dingchang Committed by Tai-Wang
Browse files

[Feature] Support DGCNN (v1.0.0.dev0) (#896)

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* fix typo

* fix typo

* fix typo

* del gf&fa registry (wo reuse pointnet module)

* fix typo

* add benchmark and add copyright header (for DGCNN only)

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* support dgcnn
parent bf4e71c2
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
def test_dgcnn_gf_module():
if not torch.cuda.is_available():
pytest.skip()
from mmdet3d.ops import DGCNNGFModule
self = DGCNNGFModule(
mlp_channels=[18, 64, 64],
num_sample=20,
knn_mod='D-KNN',
radius=None,
norm_cfg=dict(type='BN2d'),
act_cfg=dict(type='ReLU'),
pool_mod='max').cuda()
assert self.mlps[0].layer0.conv.in_channels == 18
assert self.mlps[0].layer0.conv.out_channels == 64
xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin', np.float32)
# (B, N, C)
xyz = torch.from_numpy(xyz).view(1, -1, 3).cuda()
points = xyz.repeat([1, 1, 3])
# test forward
new_points = self(points)
assert new_points.shape == torch.Size([1, 200, 64])
# test F-KNN mod
self = DGCNNGFModule(
mlp_channels=[6, 64, 64],
num_sample=20,
knn_mod='F-KNN',
radius=None,
norm_cfg=dict(type='BN2d'),
act_cfg=dict(type='ReLU'),
pool_mod='max').cuda()
# test forward
new_points = self(xyz)
assert new_points.shape == torch.Size([1, 200, 64])
# test ball query
self = DGCNNGFModule(
mlp_channels=[6, 64, 64],
num_sample=20,
knn_mod='F-KNN',
radius=0.2,
norm_cfg=dict(type='BN2d'),
act_cfg=dict(type='ReLU'),
pool_mod='max').cuda()
def test_dgcnn_fa_module():
if not torch.cuda.is_available():
pytest.skip()
from mmdet3d.ops import DGCNNFAModule
self = DGCNNFAModule(mlp_channels=[24, 16]).cuda()
assert self.mlps.layer0.conv.in_channels == 24
assert self.mlps.layer0.conv.out_channels == 16
points = [torch.rand(1, 200, 12).float().cuda() for _ in range(3)]
fa_points = self(points)
assert fa_points.shape == torch.Size([1, 200, 40])
def test_dgcnn_fp_module():
if not torch.cuda.is_available():
pytest.skip()
from mmdet3d.ops import DGCNNFPModule
self = DGCNNFPModule(mlp_channels=[24, 16]).cuda()
assert self.mlps.layer0.conv.in_channels == 24
assert self.mlps.layer0.conv.out_channels == 16
xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin',
np.float32).reshape((-1, 6))
# (B, N, 3)
xyz = torch.from_numpy(xyz).view(1, -1, 3).cuda()
points = xyz.repeat([1, 1, 8]).cuda()
fp_points = self(points)
assert fp_points.shape == torch.Size([1, 200, 16])
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
from mmcv.cnn.bricks import ConvModule
from mmdet3d.models.builder import build_head
def test_dgcnn_decode_head_loss():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
dgcnn_decode_head_cfg = dict(
type='DGCNNHead',
fp_channels=(1024, 512),
channels=256,
num_classes=13,
dropout_ratio=0.5,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=None,
loss_weight=1.0),
ignore_index=13)
self = build_head(dgcnn_decode_head_cfg)
self.cuda()
assert isinstance(self.conv_seg, torch.nn.Conv1d)
assert self.conv_seg.in_channels == 256
assert self.conv_seg.out_channels == 13
assert self.conv_seg.kernel_size == (1, )
assert isinstance(self.pre_seg_conv, ConvModule)
assert isinstance(self.pre_seg_conv.conv, torch.nn.Conv1d)
assert self.pre_seg_conv.conv.in_channels == 512
assert self.pre_seg_conv.conv.out_channels == 256
assert self.pre_seg_conv.conv.kernel_size == (1, )
assert isinstance(self.pre_seg_conv.bn, torch.nn.BatchNorm1d)
assert self.pre_seg_conv.bn.num_features == 256
# test forward
fa_points = torch.rand(2, 4096, 1024).float().cuda()
input_dict = dict(fa_points=fa_points)
seg_logits = self(input_dict)
assert seg_logits.shape == torch.Size([2, 13, 4096])
# test loss
pts_semantic_mask = torch.randint(0, 13, (2, 4096)).long().cuda()
losses = self.losses(seg_logits, pts_semantic_mask)
assert losses['loss_sem_seg'].item() > 0
# test loss with ignore_index
ignore_index_mask = torch.ones_like(pts_semantic_mask) * 13
losses = self.losses(seg_logits, ignore_index_mask)
assert losses['loss_sem_seg'].item() == 0
# test loss with class_weight
dgcnn_decode_head_cfg['loss_decode'] = dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=np.random.rand(13),
loss_weight=1.0)
self = build_head(dgcnn_decode_head_cfg)
self.cuda()
losses = self.losses(seg_logits, pts_semantic_mask)
assert losses['loss_sem_seg'].item() > 0
......@@ -304,3 +304,48 @@ def test_paconv_cuda_ssg():
results = self.forward(return_loss=False, **data_dict)
assert results[0]['semantic_mask'].shape == torch.Size([200])
assert results[1]['semantic_mask'].shape == torch.Size([100])
def test_dgcnn():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
set_random_seed(0, True)
dgcnn_cfg = _get_segmentor_cfg(
'dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py')
dgcnn_cfg.test_cfg.num_points = 32
self = build_segmentor(dgcnn_cfg).cuda()
points = [torch.rand(4096, 9).float().cuda() for _ in range(2)]
img_metas = [dict(), dict()]
gt_masks = [torch.randint(0, 13, (4096, )).long().cuda() for _ in range(2)]
# test forward_train
losses = self.forward_train(points, img_metas, gt_masks)
assert losses['decode.loss_sem_seg'].item() >= 0
# test loss with ignore_index
ignore_masks = [torch.ones_like(gt_masks[0]) * 13 for _ in range(2)]
losses = self.forward_train(points, img_metas, ignore_masks)
assert losses['decode.loss_sem_seg'].item() == 0
# test simple_test
self.eval()
with torch.no_grad():
scene_points = [
torch.randn(500, 6).float().cuda() * 3.0,
torch.randn(200, 6).float().cuda() * 2.5
]
results = self.simple_test(scene_points, img_metas)
assert results[0]['semantic_mask'].shape == torch.Size([500])
assert results[1]['semantic_mask'].shape == torch.Size([200])
# test aug_test
with torch.no_grad():
scene_points = [
torch.randn(2, 500, 6).float().cuda() * 3.0,
torch.randn(2, 200, 6).float().cuda() * 2.5
]
img_metas = [[dict(), dict()], [dict(), dict()]]
results = self.aug_test(scene_points, img_metas)
assert results[0]['semantic_mask'].shape == torch.Size([500])
assert results[1]['semantic_mask'].shape == torch.Size([200])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment