Commit 7193e54c authored by wuyuefeng's avatar wuyuefeng
Browse files

SparseUNet

parent 66315452
from .pillar_scatter import PointPillarsScatter from .pillar_scatter import PointPillarsScatter
from .sparse_encoder import SparseEncoder from .sparse_encoder import SparseEncoder
from .sparse_unet import SparseUnet from .sparse_unet import SparseUNet
__all__ = ['PointPillarsScatter', 'SparseEncoder', 'SparseUnet'] __all__ = ['PointPillarsScatter', 'SparseEncoder', 'SparseUNet']
...@@ -8,7 +8,7 @@ from ..registry import MIDDLE_ENCODERS ...@@ -8,7 +8,7 @@ from ..registry import MIDDLE_ENCODERS
@MIDDLE_ENCODERS.register_module @MIDDLE_ENCODERS.register_module
class SparseUnet(nn.Module): class SparseUNet(nn.Module):
def __init__(self, def __init__(self,
in_channels, in_channels,
...@@ -24,7 +24,7 @@ class SparseUnet(nn.Module): ...@@ -24,7 +24,7 @@ class SparseUnet(nn.Module):
decoder_channels=((64, 64, 64), (64, 64, 32), (32, 32, 16), decoder_channels=((64, 64, 64), (64, 64, 32), (32, 32, 16),
(16, 16, 16)), (16, 16, 16)),
decoder_paddings=((1, 0), (1, 0), (0, 0), (0, 1))): decoder_paddings=((1, 0), (1, 0), (0, 0), (0, 1))):
"""SparseUnet for PartA^2 """SparseUNet for PartA^2
See https://arxiv.org/abs/1907.03670 for more detials. See https://arxiv.org/abs/1907.03670 for more detials.
...@@ -99,7 +99,7 @@ class SparseUnet(nn.Module): ...@@ -99,7 +99,7 @@ class SparseUnet(nn.Module):
) )
def forward(self, voxel_features, coors, batch_size): def forward(self, voxel_features, coors, batch_size):
"""Forward of SparseUnet """Forward of SparseUNet
Args: Args:
voxel_features (torch.float32): shape [N, C] voxel_features (torch.float32): shape [N, C]
...@@ -128,8 +128,6 @@ class SparseUnet(nn.Module): ...@@ -128,8 +128,6 @@ class SparseUnet(nn.Module):
N, C, D, H, W = spatial_features.shape N, C, D, H, W = spatial_features.shape
spatial_features = spatial_features.view(N, C * D, H, W) spatial_features = spatial_features.view(N, C * D, H, W)
ret = {'spatial_features': spatial_features}
# for segmentation head, with output shape: # for segmentation head, with output shape:
# [400, 352, 11] <- [200, 176, 5] # [400, 352, 11] <- [200, 176, 5]
# [800, 704, 21] <- [400, 352, 11] # [800, 704, 21] <- [400, 352, 11]
...@@ -149,7 +147,8 @@ class SparseUnet(nn.Module): ...@@ -149,7 +147,8 @@ class SparseUnet(nn.Module):
seg_features = decode_features[-1].features seg_features = decode_features[-1].features
ret.update({'seg_features': seg_features}) ret = dict(
spatial_features=spatial_features, seg_features=seg_features)
return ret return ret
...@@ -159,7 +158,7 @@ class SparseUnet(nn.Module): ...@@ -159,7 +158,7 @@ class SparseUnet(nn.Module):
Args: Args:
x_lateral (SparseConvTensor): lateral tensor x_lateral (SparseConvTensor): lateral tensor
x_bottom (SparseConvTensor): tensor from bottom layer x_bottom (SparseConvTensor): feature from bottom layer
lateral_layer (SparseBasicBlock): convolution for lateral tensor lateral_layer (SparseBasicBlock): convolution for lateral tensor
merge_layer (SparseSequential): convolution for merging features merge_layer (SparseSequential): convolution for merging features
upsample_layer (SparseSequential): convolution for upsampling upsample_layer (SparseSequential): convolution for upsampling
...@@ -167,12 +166,11 @@ class SparseUnet(nn.Module): ...@@ -167,12 +166,11 @@ class SparseUnet(nn.Module):
Returns: Returns:
SparseConvTensor: upsampled feature SparseConvTensor: upsampled feature
""" """
x_trans = lateral_layer(x_lateral) x = lateral_layer(x_lateral)
x = x_trans x.features = torch.cat((x_bottom.features, x.features), dim=1)
x.features = torch.cat((x_bottom.features, x_trans.features), dim=1) x_merge = merge_layer(x)
x_m = merge_layer(x) x = self.reduce_channel(x, x_merge.features.shape[1])
x = self.reduce_channel(x, x_m.features.shape[1]) x.features = x_merge.features + x.features
x.features = x_m.features + x.features
x = upsample_layer(x) x = upsample_layer(x)
return x return x
......
...@@ -4,9 +4,9 @@ import mmdet3d.ops.spconv as spconv ...@@ -4,9 +4,9 @@ import mmdet3d.ops.spconv as spconv
from mmdet3d.ops import SparseBasicBlock, SparseBasicBlockV0 from mmdet3d.ops import SparseBasicBlock, SparseBasicBlockV0
def test_SparseUnet(): def test_SparseUNet():
from mmdet3d.models.middle_encoders.sparse_unet import SparseUnet from mmdet3d.models.middle_encoders.sparse_unet import SparseUNet
self = SparseUnet( self = SparseUNet(
in_channels=4, output_shape=[41, 1600, 1408], pre_act=False) in_channels=4, output_shape=[41, 1600, 1408], pre_act=False)
# test encoder layers # test encoder layers
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment