"tests/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "f63a62b82ea1de1e75996c70e354cbcc6d78464a"
Unverified Commit 276b2648 authored by Danila Rukhovich's avatar Danila Rukhovich Committed by GitHub
Browse files

[Feature] Support MinkowskiEngine with MinkResNet (#1422)

* add minkowski engine with MinkResNet

* fix docstring

* mmdet -> mmdet3d for registry import

* fix import

* fix documentation
parent 0287048a
...@@ -76,14 +76,14 @@ You can check the supported CUDA version for precompiled packages on the [PyTorc ...@@ -76,14 +76,14 @@ You can check the supported CUDA version for precompiled packages on the [PyTorc
`E.g. 1` If you have CUDA 10.1 installed under `/usr/local/cuda` and would like to install `E.g. 1` If you have CUDA 10.1 installed under `/usr/local/cuda` and would like to install
PyTorch 1.5, you need to install the prebuilt PyTorch with CUDA 10.1. PyTorch 1.5, you need to install the prebuilt PyTorch with CUDA 10.1.
```python ```shell
conda install pytorch==1.5.0 cudatoolkit=10.1 torchvision==0.6.0 -c pytorch conda install pytorch==1.5.0 cudatoolkit=10.1 torchvision==0.6.0 -c pytorch
``` ```
`E.g. 2` If you have CUDA 9.2 installed under `/usr/local/cuda` and would like to install `E.g. 2` If you have CUDA 9.2 installed under `/usr/local/cuda` and would like to install
PyTorch 1.3.1., you need to install the prebuilt PyTorch with CUDA 9.2. PyTorch 1.3.1., you need to install the prebuilt PyTorch with CUDA 9.2.
```python ```shell
conda install pytorch=1.3.1 cudatoolkit=9.2 torchvision=0.4.2 -c pytorch conda install pytorch=1.3.1 cudatoolkit=9.2 torchvision=0.4.2 -c pytorch
``` ```
...@@ -192,6 +192,13 @@ you can install it before installing MMCV. ...@@ -192,6 +192,13 @@ you can install it before installing MMCV.
4. Some dependencies are optional. Simply running `pip install -v -e .` will only install the minimum runtime requirements. To use optional dependencies like `albumentations` and `imagecorruptions` either install them manually with `pip install -r requirements/optional.txt` or specify desired extras when calling `pip` (e.g. `pip install -v -e .[optional]`). Valid keys for the extras field are: `all`, `tests`, `build`, and `optional`. 4. Some dependencies are optional. Simply running `pip install -v -e .` will only install the minimum runtime requirements. To use optional dependencies like `albumentations` and `imagecorruptions` either install them manually with `pip install -r requirements/optional.txt` or specify desired extras when calling `pip` (e.g. `pip install -v -e .[optional]`). Valid keys for the extras field are: `all`, `tests`, `build`, and `optional`.
We also support Minkowski Engine as a sparse convolution backend. If necessary please follow original [installation guide](https://github.com/NVIDIA/MinkowskiEngine#installation) or use `pip`:
```shell
conda install openblas-devel -c anaconda
pip install -U git+https://github.com/NVIDIA/MinkowskiEngine -v --no-deps --install-option="--blas_include_dirs=/opt/conda/include" --install-option="--blas=openblas"
```
5. The code can not be built for CPU only environment (where CUDA isn't available) for now. 5. The code can not be built for CPU only environment (where CUDA isn't available) for now.
## Another option: Docker Image ## Another option: Docker Image
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from mmdet.models.backbones import SSDVGG, HRNet, ResNet, ResNetV1d, ResNeXt from mmdet.models.backbones import SSDVGG, HRNet, ResNet, ResNetV1d, ResNeXt
from .dgcnn import DGCNNBackbone from .dgcnn import DGCNNBackbone
from .dla import DLANet from .dla import DLANet
from .mink_resnet import MinkResNet
from .multi_backbone import MultiBackbone from .multi_backbone import MultiBackbone
from .nostem_regnet import NoStemRegNet from .nostem_regnet import NoStemRegNet
from .pointnet2_sa_msg import PointNet2SAMSG from .pointnet2_sa_msg import PointNet2SAMSG
...@@ -11,5 +12,5 @@ from .second import SECOND ...@@ -11,5 +12,5 @@ from .second import SECOND
__all__ = [ __all__ = [
'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet',
'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG', 'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG',
'MultiBackbone', 'DLANet' 'MultiBackbone', 'DLANet', 'MinkResNet'
] ]
# Copyright (c) OpenMMLab. All rights reserved.
# Follow https://github.com/NVIDIA/MinkowskiEngine/blob/master/examples/resnet.py # noqa
# and mmcv.cnn.ResNet
try:
import MinkowskiEngine as ME
from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck
except ImportError:
import warnings
warnings.warn(
'Please follow `getting_started.md` to install MinkowskiEngine.`')
# blocks are used in the static part of MinkResNet
BasicBlock, Bottleneck = None, None
import torch.nn as nn
from mmdet3d.models.builder import BACKBONES
@BACKBONES.register_module()
class MinkResNet(nn.Module):
r"""Minkowski ResNet backbone. See `4D Spatio-Temporal ConvNets
<https://arxiv.org/abs/1904.08755>`_ for more details.
Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
in_channels (ont): Number of input channels, 3 for RGB.
num_stages (int, optional): Resnet stages. Default: 4.
pool (bool, optional): Add max pooling after first conv if True.
Default: True.
"""
arch_settings = {
18: (BasicBlock, (2, 2, 2, 2)),
34: (BasicBlock, (3, 4, 6, 3)),
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def __init__(self, depth, in_channels, num_stages=4, pool=True):
super(MinkResNet, self).__init__()
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet')
assert 4 >= num_stages >= 1
block, stage_blocks = self.arch_settings[depth]
stage_blocks = stage_blocks[:num_stages]
self.num_stages = num_stages
self.pool = pool
self.inplanes = 64
self.conv1 = ME.MinkowskiConvolution(
in_channels, self.inplanes, kernel_size=3, stride=2, dimension=3)
# May be BatchNorm is better, but we follow original implementation.
self.norm1 = ME.MinkowskiInstanceNorm(self.inplanes)
self.relu = ME.MinkowskiReLU(inplace=True)
if self.pool:
self.maxpool = ME.MinkowskiMaxPooling(
kernel_size=2, stride=2, dimension=3)
for i, num_blocks in enumerate(stage_blocks):
setattr(
self, f'layer{i}',
self._make_layer(block, 64 * 2**i, stage_blocks[i], stride=2))
def init_weights(self):
for m in self.modules():
if isinstance(m, ME.MinkowskiConvolution):
ME.utils.kaiming_normal_(
m.kernel, mode='fan_out', nonlinearity='relu')
if isinstance(m, ME.MinkowskiBatchNorm):
nn.init.constant_(m.bn.weight, 1)
nn.init.constant_(m.bn.bias, 0)
def _make_layer(self, block, planes, blocks, stride):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
ME.MinkowskiConvolution(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
dimension=3),
ME.MinkowskiBatchNorm(planes * block.expansion))
layers = []
layers.append(
block(
self.inplanes,
planes,
stride=stride,
downsample=downsample,
dimension=3))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, stride=1, dimension=3))
return nn.Sequential(*layers)
def forward(self, x):
"""Forward pass of ResNet.
Args:
x (ME.SparseTensor): Input sparse tensor.
Returns:
list[ME.SparseTensor]: Output sparse tensors.
"""
x = self.conv1(x)
x = self.norm1(x)
x = self.relu(x)
if self.pool:
x = self.maxpool(x)
outs = []
for i in range(self.num_stages):
x = getattr(self, f'layer{i}')(x)
outs.append(x)
return outs
...@@ -353,3 +353,55 @@ def test_dla_net(): ...@@ -353,3 +353,55 @@ def test_dla_net():
assert results[3].shape == torch.Size([4, 128, 4, 4]) assert results[3].shape == torch.Size([4, 128, 4, 4])
assert results[4].shape == torch.Size([4, 256, 2, 2]) assert results[4].shape == torch.Size([4, 256, 2, 2])
assert results[5].shape == torch.Size([4, 512, 1, 1]) assert results[5].shape == torch.Size([4, 512, 1, 1])
def test_mink_resnet():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
try:
import MinkowskiEngine as ME
except ImportError:
pytest.skip('test requires MinkowskiEngine installation')
coordinates, features = [], []
np.random.seed(42)
# batch of 2 point clouds
for i in range(2):
c = torch.from_numpy(np.random.rand(500, 3) * 100)
coordinates.append(c.float().cuda())
f = torch.from_numpy(np.random.rand(500, 3))
features.append(f.float().cuda())
tensor_coordinates, tensor_features = ME.utils.sparse_collate(
coordinates, features)
x = ME.SparseTensor(
features=tensor_features, coordinates=tensor_coordinates)
# MinkResNet34 with 4 outputs
cfg = dict(type='MinkResNet', depth=34, in_channels=3)
self = build_backbone(cfg).cuda()
self.init_weights()
y = self(x)
assert len(y) == 4
assert y[0].F.shape == torch.Size([900, 64])
assert y[0].tensor_stride[0] == 8
assert y[1].F.shape == torch.Size([472, 128])
assert y[1].tensor_stride[0] == 16
assert y[2].F.shape == torch.Size([105, 256])
assert y[2].tensor_stride[0] == 32
assert y[3].F.shape == torch.Size([16, 512])
assert y[3].tensor_stride[0] == 64
# MinkResNet50 with 2 outputs
cfg = dict(
type='MinkResNet', depth=34, in_channels=3, num_stages=2, pool=False)
self = build_backbone(cfg).cuda()
self.init_weights()
y = self(x)
assert len(y) == 2
assert y[0].F.shape == torch.Size([985, 64])
assert y[0].tensor_stride[0] == 4
assert y[1].F.shape == torch.Size([900, 128])
assert y[1].tensor_stride[0] == 8
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment