"git@developer.sourcefind.cn:change/sglang.git" did not exist on "76c48a0913b993bb185f4c4c2d884b28689e0007"
Unverified Commit b6571d22 authored by Sun Jiahao's avatar Sun Jiahao Committed by GitHub
Browse files

[Fix] Fix spconv block (#2531)

parent 22aaa47f
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union
from typing import Optional, Tuple, Union
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
......@@ -35,6 +35,7 @@ class SparseBottleneck(Bottleneck, SparseModule):
stride (int or Tuple[int]): Stride of the first block. Defaults to 1.
downsample (Module, optional): Down sample module for block.
Defaults to None.
indice_key (str): Indice key for spconv. Default to None.
conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
convolution layer. Defaults to None.
norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
......@@ -48,10 +49,16 @@ class SparseBottleneck(Bottleneck, SparseModule):
planes: int,
stride: Union[int, Tuple[int]] = 1,
downsample: nn.Module = None,
indice_key=None,
conv_cfg: OptConfigType = None,
norm_cfg: OptConfigType = None) -> None:
SparseModule.__init__(self)
if conv_cfg is None:
conv_cfg = dict(type='SubMConv3d')
conv_cfg.setdefault('indice_key', indice_key)
if norm_cfg is None:
norm_cfg = dict(type='BN1d')
Bottleneck.__init__(
self,
inplanes,
......@@ -76,7 +83,7 @@ class SparseBottleneck(Bottleneck, SparseModule):
out = replace_feature(out, self.bn3(out.features))
if self.downsample is not None:
identity = self.downsample(x)
identity = self.downsample(x).features
out = replace_feature(out, out.features + identity)
out = replace_feature(out, self.relu(out.features))
......@@ -95,6 +102,7 @@ class SparseBasicBlock(BasicBlock, SparseModule):
stride (int or Tuple[int]): Stride of the first block. Defaults to 1.
downsample (Module, optional): Down sample module for block.
Defaults to None.
indice_key (str): Indice key for spconv. Default to None.
conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
convolution layer. Defaults to None.
norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
......@@ -108,9 +116,15 @@ class SparseBasicBlock(BasicBlock, SparseModule):
planes: int,
stride: Union[int, Tuple[int]] = 1,
downsample: nn.Module = None,
indice_key: Optional[str] = None,
conv_cfg: OptConfigType = None,
norm_cfg: OptConfigType = None) -> None:
SparseModule.__init__(self)
if conv_cfg is None:
conv_cfg = dict(type='SubMConv3d')
conv_cfg.setdefault('indice_key', indice_key)
if norm_cfg is None:
norm_cfg = dict(type='BN1d')
BasicBlock.__init__(
self,
inplanes,
......@@ -132,7 +146,7 @@ class SparseBasicBlock(BasicBlock, SparseModule):
out = replace_feature(out, self.norm2(out.features))
if self.downsample is not None:
identity = self.downsample(x)
identity = self.downsample(x).features
out = replace_feature(out, out.features + identity)
out = replace_feature(out, self.relu(out.features))
......@@ -140,17 +154,16 @@ class SparseBasicBlock(BasicBlock, SparseModule):
return out
def make_sparse_convmodule(
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int]],
indice_key: str,
stride: Union[int, Tuple[int]] = 1,
padding: Union[int, Tuple[int]] = 0,
conv_type: str = 'SubMConv3d',
norm_cfg: OptConfigType = None,
order: Tuple[str] = ('conv', 'norm', 'act')
) -> SparseSequential:
def make_sparse_convmodule(in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int]],
indice_key: Optional[str] = None,
stride: Union[int, Tuple[int]] = 1,
padding: Union[int, Tuple[int]] = 0,
conv_type: str = 'SubMConv3d',
norm_cfg: OptConfigType = None,
order: Tuple[str] = ('conv', 'norm', 'act'),
**kwargs) -> SparseSequential:
"""Make sparse convolution module.
Args:
......@@ -175,6 +188,8 @@ def make_sparse_convmodule(
assert set(order) | {'conv', 'norm', 'act'} == {'conv', 'norm', 'act'}
conv_cfg = dict(type=conv_type, indice_key=indice_key)
if norm_cfg is None:
norm_cfg = dict(type='BN1d')
layers = list()
for layer in order:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment