Commit 7193e54c authored by wuyuefeng's avatar wuyuefeng
Browse files

SparseUNet

parent 66315452
from .pillar_scatter import PointPillarsScatter
from .sparse_encoder import SparseEncoder
from .sparse_unet import SparseUnet
from .sparse_unet import SparseUNet
__all__ = ['PointPillarsScatter', 'SparseEncoder', 'SparseUnet']
__all__ = ['PointPillarsScatter', 'SparseEncoder', 'SparseUNet']
......@@ -8,7 +8,7 @@ from ..registry import MIDDLE_ENCODERS
@MIDDLE_ENCODERS.register_module
class SparseUnet(nn.Module):
class SparseUNet(nn.Module):
def __init__(self,
in_channels,
......@@ -24,7 +24,7 @@ class SparseUnet(nn.Module):
decoder_channels=((64, 64, 64), (64, 64, 32), (32, 32, 16),
(16, 16, 16)),
decoder_paddings=((1, 0), (1, 0), (0, 0), (0, 1))):
"""SparseUnet for PartA^2
"""SparseUNet for PartA^2
See https://arxiv.org/abs/1907.03670 for more detials.
......@@ -99,7 +99,7 @@ class SparseUnet(nn.Module):
)
def forward(self, voxel_features, coors, batch_size):
"""Forward of SparseUnet
"""Forward of SparseUNet
Args:
voxel_features (torch.float32): shape [N, C]
......@@ -128,8 +128,6 @@ class SparseUnet(nn.Module):
N, C, D, H, W = spatial_features.shape
spatial_features = spatial_features.view(N, C * D, H, W)
ret = {'spatial_features': spatial_features}
# for segmentation head, with output shape:
# [400, 352, 11] <- [200, 176, 5]
# [800, 704, 21] <- [400, 352, 11]
......@@ -149,7 +147,8 @@ class SparseUnet(nn.Module):
seg_features = decode_features[-1].features
ret.update({'seg_features': seg_features})
ret = dict(
spatial_features=spatial_features, seg_features=seg_features)
return ret
......@@ -159,7 +158,7 @@ class SparseUnet(nn.Module):
Args:
x_lateral (SparseConvTensor): lateral tensor
x_bottom (SparseConvTensor): tensor from bottom layer
x_bottom (SparseConvTensor): feature from bottom layer
lateral_layer (SparseBasicBlock): convolution for lateral tensor
merge_layer (SparseSequential): convolution for merging features
upsample_layer (SparseSequential): convolution for upsampling
......@@ -167,12 +166,11 @@ class SparseUnet(nn.Module):
Returns:
SparseConvTensor: upsampled feature
"""
x_trans = lateral_layer(x_lateral)
x = x_trans
x.features = torch.cat((x_bottom.features, x_trans.features), dim=1)
x_m = merge_layer(x)
x = self.reduce_channel(x, x_m.features.shape[1])
x.features = x_m.features + x.features
x = lateral_layer(x_lateral)
x.features = torch.cat((x_bottom.features, x.features), dim=1)
x_merge = merge_layer(x)
x = self.reduce_channel(x, x_merge.features.shape[1])
x.features = x_merge.features + x.features
x = upsample_layer(x)
return x
......
......@@ -4,9 +4,9 @@ import mmdet3d.ops.spconv as spconv
from mmdet3d.ops import SparseBasicBlock, SparseBasicBlockV0
def test_SparseUnet():
from mmdet3d.models.middle_encoders.sparse_unet import SparseUnet
self = SparseUnet(
def test_SparseUNet():
from mmdet3d.models.middle_encoders.sparse_unet import SparseUNet
self = SparseUNet(
in_channels=4, output_shape=[41, 1600, 1408], pre_act=False)
# test encoder layers
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment