Unverified Commit b6571d22 authored by Sun Jiahao's avatar Sun Jiahao Committed by GitHub
Browse files

[Fix] Fix spconv block (#2531)

parent 22aaa47f
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union from typing import Optional, Tuple, Union
from mmcv.cnn import build_conv_layer, build_norm_layer from mmcv.cnn import build_conv_layer, build_norm_layer
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
...@@ -35,6 +35,7 @@ class SparseBottleneck(Bottleneck, SparseModule): ...@@ -35,6 +35,7 @@ class SparseBottleneck(Bottleneck, SparseModule):
stride (int or Tuple[int]): Stride of the first block. Defaults to 1. stride (int or Tuple[int]): Stride of the first block. Defaults to 1.
downsample (Module, optional): Down sample module for block. downsample (Module, optional): Down sample module for block.
Defaults to None. Defaults to None.
indice_key (str): Indice key for spconv. Default to None.
conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
convolution layer. Defaults to None. convolution layer. Defaults to None.
norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
...@@ -48,10 +49,16 @@ class SparseBottleneck(Bottleneck, SparseModule): ...@@ -48,10 +49,16 @@ class SparseBottleneck(Bottleneck, SparseModule):
planes: int, planes: int,
stride: Union[int, Tuple[int]] = 1, stride: Union[int, Tuple[int]] = 1,
downsample: nn.Module = None, downsample: nn.Module = None,
indice_key=None,
conv_cfg: OptConfigType = None, conv_cfg: OptConfigType = None,
norm_cfg: OptConfigType = None) -> None: norm_cfg: OptConfigType = None) -> None:
SparseModule.__init__(self) SparseModule.__init__(self)
if conv_cfg is None:
conv_cfg = dict(type='SubMConv3d')
conv_cfg.setdefault('indice_key', indice_key)
if norm_cfg is None:
norm_cfg = dict(type='BN1d')
Bottleneck.__init__( Bottleneck.__init__(
self, self,
inplanes, inplanes,
...@@ -76,7 +83,7 @@ class SparseBottleneck(Bottleneck, SparseModule): ...@@ -76,7 +83,7 @@ class SparseBottleneck(Bottleneck, SparseModule):
out = replace_feature(out, self.bn3(out.features)) out = replace_feature(out, self.bn3(out.features))
if self.downsample is not None: if self.downsample is not None:
identity = self.downsample(x) identity = self.downsample(x).features
out = replace_feature(out, out.features + identity) out = replace_feature(out, out.features + identity)
out = replace_feature(out, self.relu(out.features)) out = replace_feature(out, self.relu(out.features))
...@@ -95,6 +102,7 @@ class SparseBasicBlock(BasicBlock, SparseModule): ...@@ -95,6 +102,7 @@ class SparseBasicBlock(BasicBlock, SparseModule):
stride (int or Tuple[int]): Stride of the first block. Defaults to 1. stride (int or Tuple[int]): Stride of the first block. Defaults to 1.
downsample (Module, optional): Down sample module for block. downsample (Module, optional): Down sample module for block.
Defaults to None. Defaults to None.
indice_key (str): Indice key for spconv. Default to None.
conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
convolution layer. Defaults to None. convolution layer. Defaults to None.
norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
...@@ -108,9 +116,15 @@ class SparseBasicBlock(BasicBlock, SparseModule): ...@@ -108,9 +116,15 @@ class SparseBasicBlock(BasicBlock, SparseModule):
planes: int, planes: int,
stride: Union[int, Tuple[int]] = 1, stride: Union[int, Tuple[int]] = 1,
downsample: nn.Module = None, downsample: nn.Module = None,
indice_key: Optional[str] = None,
conv_cfg: OptConfigType = None, conv_cfg: OptConfigType = None,
norm_cfg: OptConfigType = None) -> None: norm_cfg: OptConfigType = None) -> None:
SparseModule.__init__(self) SparseModule.__init__(self)
if conv_cfg is None:
conv_cfg = dict(type='SubMConv3d')
conv_cfg.setdefault('indice_key', indice_key)
if norm_cfg is None:
norm_cfg = dict(type='BN1d')
BasicBlock.__init__( BasicBlock.__init__(
self, self,
inplanes, inplanes,
...@@ -132,7 +146,7 @@ class SparseBasicBlock(BasicBlock, SparseModule): ...@@ -132,7 +146,7 @@ class SparseBasicBlock(BasicBlock, SparseModule):
out = replace_feature(out, self.norm2(out.features)) out = replace_feature(out, self.norm2(out.features))
if self.downsample is not None: if self.downsample is not None:
identity = self.downsample(x) identity = self.downsample(x).features
out = replace_feature(out, out.features + identity) out = replace_feature(out, out.features + identity)
out = replace_feature(out, self.relu(out.features)) out = replace_feature(out, self.relu(out.features))
...@@ -140,17 +154,16 @@ class SparseBasicBlock(BasicBlock, SparseModule): ...@@ -140,17 +154,16 @@ class SparseBasicBlock(BasicBlock, SparseModule):
return out return out
def make_sparse_convmodule( def make_sparse_convmodule(in_channels: int,
in_channels: int,
out_channels: int, out_channels: int,
kernel_size: Union[int, Tuple[int]], kernel_size: Union[int, Tuple[int]],
indice_key: str, indice_key: Optional[str] = None,
stride: Union[int, Tuple[int]] = 1, stride: Union[int, Tuple[int]] = 1,
padding: Union[int, Tuple[int]] = 0, padding: Union[int, Tuple[int]] = 0,
conv_type: str = 'SubMConv3d', conv_type: str = 'SubMConv3d',
norm_cfg: OptConfigType = None, norm_cfg: OptConfigType = None,
order: Tuple[str] = ('conv', 'norm', 'act') order: Tuple[str] = ('conv', 'norm', 'act'),
) -> SparseSequential: **kwargs) -> SparseSequential:
"""Make sparse convolution module. """Make sparse convolution module.
Args: Args:
...@@ -175,6 +188,8 @@ def make_sparse_convmodule( ...@@ -175,6 +188,8 @@ def make_sparse_convmodule(
assert set(order) | {'conv', 'norm', 'act'} == {'conv', 'norm', 'act'} assert set(order) | {'conv', 'norm', 'act'} == {'conv', 'norm', 'act'}
conv_cfg = dict(type=conv_type, indice_key=indice_key) conv_cfg = dict(type=conv_type, indice_key=indice_key)
if norm_cfg is None:
norm_cfg = dict(type='BN1d')
layers = list() layers = list()
for layer in order: for layer in order:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment