Commit e6f6151c authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

add RETURN_ENCODED_TENSOR config for Spconv UNet

parent b6fbc50d
...@@ -91,8 +91,8 @@ class UNetV2(nn.Module): ...@@ -91,8 +91,8 @@ class UNetV2(nn.Module):
block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'), block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'),
) )
last_pad = 0 if self.model_cfg.get('RETURN_ENCODED_TENSOR', True):
last_pad = self.model_cfg.get('last_pad', last_pad) last_pad = self.model_cfg.get('last_pad', 0)
self.conv_out = spconv.SparseSequential( self.conv_out = spconv.SparseSequential(
# [200, 150, 5] -> [200, 150, 2] # [200, 150, 5] -> [200, 150, 2]
...@@ -101,6 +101,8 @@ class UNetV2(nn.Module): ...@@ -101,6 +101,8 @@ class UNetV2(nn.Module):
norm_fn(128), norm_fn(128),
nn.ReLU(), nn.ReLU(),
) )
else:
self.conv_out = None
# decoder # decoder
# [400, 352, 11] <- [200, 176, 5] # [400, 352, 11] <- [200, 176, 5]
...@@ -181,9 +183,12 @@ class UNetV2(nn.Module): ...@@ -181,9 +183,12 @@ class UNetV2(nn.Module):
x_conv3 = self.conv3(x_conv2) x_conv3 = self.conv3(x_conv2)
x_conv4 = self.conv4(x_conv3) x_conv4 = self.conv4(x_conv3)
if self.conv_out is not None:
# for detection head # for detection head
# [200, 176, 5] -> [200, 176, 2] # [200, 176, 5] -> [200, 176, 2]
out = self.conv_out(x_conv4) out = self.conv_out(x_conv4)
batch_dict['encoded_spconv_tensor'] = out
batch_dict['encoded_spconv_tensor_stride'] = 8
# for segmentation head # for segmentation head
# [400, 352, 11] <- [200, 176, 5] # [400, 352, 11] <- [200, 176, 5]
...@@ -201,6 +206,4 @@ class UNetV2(nn.Module): ...@@ -201,6 +206,4 @@ class UNetV2(nn.Module):
point_cloud_range=self.point_cloud_range point_cloud_range=self.point_cloud_range
) )
batch_dict['point_coords'] = torch.cat((x_up1.indices[:, 0:1].float(), point_coords), dim=1) batch_dict['point_coords'] = torch.cat((x_up1.indices[:, 0:1].float(), point_coords), dim=1)
batch_dict['encoded_spconv_tensor'] = out
batch_dict['encoded_spconv_tensor_stride'] = 8
return batch_dict return batch_dict
...@@ -12,6 +12,7 @@ MODEL: ...@@ -12,6 +12,7 @@ MODEL:
BACKBONE_3D: BACKBONE_3D:
NAME: UNetV2 NAME: UNetV2
RETURN_ENCODED_TENSOR: False
POINT_HEAD: POINT_HEAD:
NAME: PointHeadBox NAME: PointHeadBox
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment