Unverified Commit 867e41a0 authored by yukang's avatar yukang Committed by GitHub
Browse files

Update spconv_backbone_focal.py

parent 57c3b507
...@@ -25,14 +25,16 @@ class SparseSequentialBatchdict(spconv.SparseSequential): ...@@ -25,14 +25,16 @@ class SparseSequentialBatchdict(spconv.SparseSequential):
super(SparseSequentialBatchdict, self).__init__(*args, **kwargs) super(SparseSequentialBatchdict, self).__init__(*args, **kwargs)
def forward(self, input, batch_dict=None): def forward(self, input, batch_dict=None):
loss = 0
for k, module in self._modules.items(): for k, module in self._modules.items():
if module is None: if module is None:
continue continue
if isinstance(module, (FocalSparseConv,)): if isinstance(module, (FocalSparseConv,)):
input, batch_dict = module(input, batch_dict) input, batch_dict, _loss = module(input, batch_dict)
loss += _loss
else: else:
input = module(input) input = module(input)
return input, batch_dict return input, batch_dict, loss
def post_act_block(in_channels, out_channels, kernel_size, indice_key=None, stride=1, padding=0, def post_act_block(in_channels, out_channels, kernel_size, indice_key=None, stride=1, padding=0,
...@@ -225,22 +227,20 @@ class VoxelBackBone8xFocal(nn.Module): ...@@ -225,22 +227,20 @@ class VoxelBackBone8xFocal(nn.Module):
batch_size=batch_size batch_size=batch_size
) )
batch_dict['loss_box_of_pts'] = 0 loss_img = 0
x = self.conv_input(input_sp_tensor) x = self.conv_input(input_sp_tensor)
x_conv1, batch_dict = self.conv1(x, batch_dict) x_conv1, batch_dict, loss1 = self.conv1(x, batch_dict)
loss_box_of_pts = 0
if self.use_img: if self.use_img:
x_image = self.semseg(batch_dict['images'])['layer1_feat2d'] x_image = self.semseg(batch_dict['images'])['layer1_feat2d']
x_conv1, batch_dict, loss_box_of_pts = self.conv_focal_multimodal(x_conv1, batch_dict, x_image) x_conv1, batch_dict, loss_img = self.conv_focal_multimodal(x_conv1, batch_dict, x_image)
self.forward_ret_dict['loss_box_of_pts'] = loss_box_of_pts x_conv2, batch_dict, loss2 = self.conv2(x_conv1, batch_dict)
x_conv3, batch_dict, loss3 = self.conv3(x_conv2, batch_dict)
x_conv2, batch_dict = self.conv2(x_conv1, batch_dict) x_conv4, batch_dict, loss4 = self.conv4(x_conv3, batch_dict)
x_conv3, batch_dict = self.conv3(x_conv2, batch_dict)
x_conv4, batch_dict = self.conv4(x_conv3, batch_dict)
self.forward_ret_dict['loss_box_of_pts'] = loss1 + loss2 + loss3 + loss4 + loss_img
# for detection head # for detection head
# [200, 176, 5] -> [200, 176, 2] # [200, 176, 5] -> [200, 176, 2]
out = self.conv_out(x_conv4) out = self.conv_out(x_conv4)
...@@ -267,4 +267,3 @@ class VoxelBackBone8xFocal(nn.Module): ...@@ -267,4 +267,3 @@ class VoxelBackBone8xFocal(nn.Module):
}) })
return batch_dict return batch_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment