"vscode:/vscode.git/clone" did not exist on "d07ae9c4c4ada5ce80ffebb8f17f4ad4f13077bd"
Unverified Commit 32567b04 authored by Shaoshuai Shi's avatar Shaoshuai Shi Committed by GitHub
Browse files

Merge pull request #192 from sshaoshuai/master

Release OpenPCDet v0.3.0
parents 853b759b 04e0d4f0
from .base_bev_backbone import BaseBEVBackbone from .base_bev_backbone import BaseBEVBackbone
__all__ = { __all__ = {
'BaseBEVBackbone': BaseBEVBackbone 'BaseBEVBackbone': BaseBEVBackbone
} }
\ No newline at end of file
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -7,13 +8,20 @@ class BaseBEVBackbone(nn.Module): ...@@ -7,13 +8,20 @@ class BaseBEVBackbone(nn.Module):
super().__init__() super().__init__()
self.model_cfg = model_cfg self.model_cfg = model_cfg
assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == len(self.model_cfg.NUM_FILTERS) if self.model_cfg.get('LAYER_NUMS', None) is not None:
assert len(self.model_cfg.UPSAMPLE_STRIDES) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS) assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == len(self.model_cfg.NUM_FILTERS)
layer_nums = self.model_cfg.LAYER_NUMS layer_nums = self.model_cfg.LAYER_NUMS
layer_strides = self.model_cfg.LAYER_STRIDES layer_strides = self.model_cfg.LAYER_STRIDES
num_filters = self.model_cfg.NUM_FILTERS num_filters = self.model_cfg.NUM_FILTERS
num_upsample_filters = self.model_cfg.NUM_UPSAMPLE_FILTERS else:
upsample_strides = self.model_cfg.UPSAMPLE_STRIDES layer_nums = layer_strides = num_filters = []
if self.model_cfg.get('UPSAMPLE_STRIDES', None) is not None:
assert len(self.model_cfg.UPSAMPLE_STRIDES) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS)
num_upsample_filters = self.model_cfg.NUM_UPSAMPLE_FILTERS
upsample_strides = self.model_cfg.UPSAMPLE_STRIDES
else:
upsample_strides = num_upsample_filters = []
num_levels = len(layer_nums) num_levels = len(layer_nums)
c_in_list = [input_channels, *num_filters[:-1]] c_in_list = [input_channels, *num_filters[:-1]]
...@@ -37,15 +45,28 @@ class BaseBEVBackbone(nn.Module): ...@@ -37,15 +45,28 @@ class BaseBEVBackbone(nn.Module):
]) ])
self.blocks.append(nn.Sequential(*cur_layers)) self.blocks.append(nn.Sequential(*cur_layers))
if len(upsample_strides) > 0: if len(upsample_strides) > 0:
self.deblocks.append(nn.Sequential( stride = upsample_strides[idx]
nn.ConvTranspose2d( if stride >= 1:
num_filters[idx], num_upsample_filters[idx], self.deblocks.append(nn.Sequential(
upsample_strides[idx], nn.ConvTranspose2d(
stride=upsample_strides[idx], bias=False num_filters[idx], num_upsample_filters[idx],
), upsample_strides[idx],
nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01), stride=upsample_strides[idx], bias=False
nn.ReLU() ),
)) nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
nn.ReLU()
))
else:
stride = np.round(1 / stride).astype(np.int)
self.deblocks.append(nn.Sequential(
nn.Conv2d(
num_filters[idx], num_upsample_filters[idx],
stride,
stride=stride, bias=False
),
nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
nn.ReLU()
))
c_in = sum(num_upsample_filters) c_in = sum(num_upsample_filters)
if len(upsample_strides) > num_levels: if len(upsample_strides) > num_levels:
......
from .spconv_backbone import VoxelBackBone8x from .pointnet2_backbone import PointNet2Backbone, PointNet2MSG
from .spconv_backbone import VoxelBackBone8x, VoxelResBackBone8x
from .spconv_unet import UNetV2 from .spconv_unet import UNetV2
__all__ = { __all__ = {
'VoxelBackBone8x': VoxelBackBone8x, 'VoxelBackBone8x': VoxelBackBone8x,
'UNetV2': UNetV2 'UNetV2': UNetV2,
'PointNet2Backbone': PointNet2Backbone,
'PointNet2MSG': PointNet2MSG,
'VoxelResBackBone8x': VoxelResBackBone8x,
} }
from .voxel_set_abstraction import VoxelSetAbstraction from .voxel_set_abstraction import VoxelSetAbstraction
__all__ = { __all__ = {
'VoxelSetAbstraction': VoxelSetAbstraction 'VoxelSetAbstraction': VoxelSetAbstraction
} }
import torch import torch
import torch.nn as nn import torch.nn as nn
from ....utils import common_utils
from ....ops.pointnet2.pointnet2_stack import pointnet2_modules as pointnet2_stack_modules from ....ops.pointnet2.pointnet2_stack import pointnet2_modules as pointnet2_stack_modules
from ....ops.pointnet2.pointnet2_stack import pointnet2_utils as pointnet2_stack_utils from ....ops.pointnet2.pointnet2_stack import pointnet2_utils as pointnet2_stack_utils
from ....utils import common_utils
def bilinear_interpolate_torch(im, x, y): def bilinear_interpolate_torch(im, x, y):
...@@ -236,4 +237,3 @@ class VoxelSetAbstraction(nn.Module): ...@@ -236,4 +237,3 @@ class VoxelSetAbstraction(nn.Module):
batch_dict['point_features'] = point_features # (BxN, C) batch_dict['point_features'] = point_features # (BxN, C)
batch_dict['point_coords'] = point_coords # (BxN, 4) batch_dict['point_coords'] = point_coords # (BxN, 4)
return batch_dict return batch_dict
import torch
import torch.nn as nn
from ...ops.pointnet2.pointnet2_batch import pointnet2_modules
from ...ops.pointnet2.pointnet2_stack import pointnet2_modules as pointnet2_modules_stack
from ...ops.pointnet2.pointnet2_stack import pointnet2_utils as pointnet2_utils_stack
class PointNet2MSG(nn.Module):
def __init__(self, model_cfg, input_channels, **kwargs):
super().__init__()
self.model_cfg = model_cfg
self.SA_modules = nn.ModuleList()
channel_in = input_channels - 3
self.num_points_each_layer = []
skip_channel_list = [input_channels - 3]
for k in range(self.model_cfg.SA_CONFIG.NPOINTS.__len__()):
mlps = self.model_cfg.SA_CONFIG.MLPS[k].copy()
channel_out = 0
for idx in range(mlps.__len__()):
mlps[idx] = [channel_in] + mlps[idx]
channel_out += mlps[idx][-1]
self.SA_modules.append(
pointnet2_modules.PointnetSAModuleMSG(
npoint=self.model_cfg.SA_CONFIG.NPOINTS[k],
radii=self.model_cfg.SA_CONFIG.RADIUS[k],
nsamples=self.model_cfg.SA_CONFIG.NSAMPLE[k],
mlps=mlps,
use_xyz=self.model_cfg.SA_CONFIG.get('USE_XYZ', True),
)
)
skip_channel_list.append(channel_out)
channel_in = channel_out
self.FP_modules = nn.ModuleList()
for k in range(self.model_cfg.FP_MLPS.__len__()):
pre_channel = self.model_cfg.FP_MLPS[k + 1][-1] if k + 1 < len(self.model_cfg.FP_MLPS) else channel_out
self.FP_modules.append(
pointnet2_modules.PointnetFPModule(
mlp=[pre_channel + skip_channel_list[k]] + self.model_cfg.FP_MLPS[k]
)
)
self.num_point_features = self.model_cfg.FP_MLPS[0][-1]
def break_up_pc(self, pc):
batch_idx = pc[:, 0]
xyz = pc[:, 1:4].contiguous()
features = (pc[:, 4:].contiguous() if pc.size(-1) > 4 else None)
return batch_idx, xyz, features
def forward(self, batch_dict):
"""
Args:
batch_dict:
batch_size: int
vfe_features: (num_voxels, C)
points: (num_points, 4 + C), [batch_idx, x, y, z, ...]
Returns:
batch_dict:
encoded_spconv_tensor: sparse tensor
point_features: (N, C)
"""
batch_size = batch_dict['batch_size']
points = batch_dict['points']
batch_idx, xyz, features = self.break_up_pc(points)
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
for bs_idx in range(batch_size):
xyz_batch_cnt[bs_idx] = (batch_idx == bs_idx).sum()
assert xyz_batch_cnt.min() == xyz_batch_cnt.max()
xyz = xyz.view(batch_size, -1, 3)
features = features.view(batch_size, -1, features.shape[-1]).permute(0, 2, 1) if features is not None else None
l_xyz, l_features = [xyz], [features]
for i in range(len(self.SA_modules)):
li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
l_xyz.append(li_xyz)
l_features.append(li_features)
for i in range(-1, -(len(self.FP_modules) + 1), -1):
l_features[i - 1] = self.FP_modules[i](
l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i]
) # (B, C, N)
point_features = l_features[0].permute(0, 2, 1).contiguous() # (B, N, C)
batch_dict['point_features'] = point_features.view(-1, point_features.shape[-1])
batch_dict['point_coords'] = torch.cat((batch_idx[:, None].float(), l_xyz[0].view(-1, 3)), dim=1)
return batch_dict
class PointNet2Backbone(nn.Module):
"""
DO NOT USE THIS CURRENTLY SINCE IT MAY HAVE POTENTIAL BUGS, 20200723
"""
def __init__(self, model_cfg, input_channels, **kwargs):
assert False, 'DO NOT USE THIS CURRENTLY SINCE IT MAY HAVE POTENTIAL BUGS, 20200723'
super().__init__()
self.model_cfg = model_cfg
self.SA_modules = nn.ModuleList()
channel_in = input_channels - 3
self.num_points_each_layer = []
skip_channel_list = [input_channels]
for k in range(self.model_cfg.SA_CONFIG.NPOINTS.__len__()):
self.num_points_each_layer.append(self.model_cfg.SA_CONFIG.NPOINTS[k])
mlps = self.model_cfg.SA_CONFIG.MLPS[k].copy()
channel_out = 0
for idx in range(mlps.__len__()):
mlps[idx] = [channel_in] + mlps[idx]
channel_out += mlps[idx][-1]
self.SA_modules.append(
pointnet2_modules_stack.StackSAModuleMSG(
radii=self.model_cfg.SA_CONFIG.RADIUS[k],
nsamples=self.model_cfg.SA_CONFIG.NSAMPLE[k],
mlps=mlps,
use_xyz=self.model_cfg.SA_CONFIG.get('USE_XYZ', True),
)
)
skip_channel_list.append(channel_out)
channel_in = channel_out
self.FP_modules = nn.ModuleList()
for k in range(self.model_cfg.FP_MLPS.__len__()):
pre_channel = self.model_cfg.FP_MLPS[k + 1][-1] if k + 1 < len(self.model_cfg.FP_MLPS) else channel_out
self.FP_modules.append(
pointnet2_modules_stack.StackPointnetFPModule(
mlp=[pre_channel + skip_channel_list[k]] + self.model_cfg.FP_MLPS[k]
)
)
self.num_point_features = self.model_cfg.FP_MLPS[0][-1]
def break_up_pc(self, pc):
batch_idx = pc[:, 0]
xyz = pc[:, 1:4].contiguous()
features = (pc[:, 4:].contiguous() if pc.size(-1) > 4 else None)
return batch_idx, xyz, features
def forward(self, batch_dict):
"""
Args:
batch_dict:
batch_size: int
vfe_features: (num_voxels, C)
points: (num_points, 4 + C), [batch_idx, x, y, z, ...]
Returns:
batch_dict:
encoded_spconv_tensor: sparse tensor
point_features: (N, C)
"""
batch_size = batch_dict['batch_size']
points = batch_dict['points']
batch_idx, xyz, features = self.break_up_pc(points)
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
for bs_idx in range(batch_size):
xyz_batch_cnt[bs_idx] = (batch_idx == bs_idx).sum()
l_xyz, l_features, l_batch_cnt = [xyz], [features], [xyz_batch_cnt]
for i in range(len(self.SA_modules)):
new_xyz_list = []
for k in range(batch_size):
if len(l_xyz) == 1:
cur_xyz = l_xyz[0][batch_idx == k]
else:
last_num_points = self.num_points_each_layer[i - 1]
cur_xyz = l_xyz[-1][k * last_num_points: (k + 1) * last_num_points]
cur_pt_idxs = pointnet2_utils_stack.furthest_point_sample(
cur_xyz[None, :, :].contiguous(), self.num_points_each_layer[i]
).long()[0]
if cur_xyz.shape[0] < self.num_points_each_layer[i]:
empty_num = self.num_points_each_layer[i] - cur_xyz.shape[1]
cur_pt_idxs[0, -empty_num:] = cur_pt_idxs[0, :empty_num]
new_xyz_list.append(cur_xyz[cur_pt_idxs])
new_xyz = torch.cat(new_xyz_list, dim=0)
new_xyz_batch_cnt = xyz.new_zeros(batch_size).int().fill_(self.num_points_each_layer[i])
li_xyz, li_features = self.SA_modules[i](
xyz=l_xyz[i], features=l_features[i], xyz_batch_cnt=l_batch_cnt[i],
new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt
)
l_xyz.append(li_xyz)
l_features.append(li_features)
l_batch_cnt.append(new_xyz_batch_cnt)
l_features[0] = points[:, 1:]
for i in range(-1, -(len(self.FP_modules) + 1), -1):
l_features[i - 1] = self.FP_modules[i](
unknown=l_xyz[i - 1], unknown_batch_cnt=l_batch_cnt[i - 1],
known=l_xyz[i], known_batch_cnt=l_batch_cnt[i],
unknown_feats=l_features[i - 1], known_feats=l_features[i]
)
batch_dict['point_features'] = l_features[0]
batch_dict['point_coords'] = torch.cat((batch_idx[:, None].float(), l_xyz[0]), dim=1)
return batch_dict
import torch.nn as nn
import spconv
from functools import partial from functools import partial
import spconv
import torch.nn as nn
def post_act_block(in_channels, out_channels, kernel_size, indice_key=None, stride=1, padding=0, def post_act_block(in_channels, out_channels, kernel_size, indice_key=None, stride=1, padding=0,
conv_type='subm', norm_fn=None): conv_type='subm', norm_fn=None):
...@@ -25,6 +26,45 @@ def post_act_block(in_channels, out_channels, kernel_size, indice_key=None, stri ...@@ -25,6 +26,45 @@ def post_act_block(in_channels, out_channels, kernel_size, indice_key=None, stri
return m return m
class SparseBasicBlock(spconv.SparseModule):
expansion = 1
def __init__(self, inplanes, planes, stride=1, norm_fn=None, downsample=None, indice_key=None):
super(SparseBasicBlock, self).__init__()
assert norm_fn is not None
bias = norm_fn is not None
self.conv1 = spconv.SubMConv3d(
inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=bias, indice_key=indice_key
)
self.bn1 = norm_fn(planes)
self.relu = nn.ReLU()
self.conv2 = spconv.SubMConv3d(
planes, planes, kernel_size=3, stride=stride, padding=1, bias=bias, indice_key=indice_key
)
self.bn2 = norm_fn(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out.features = self.bn1(out.features)
out.features = self.relu(out.features)
out = self.conv2(out)
out.features = self.bn2(out.features)
if self.downsample is not None:
identity = self.downsample(x)
out.features += identity.features
out.features = self.relu(out.features)
return out
class VoxelBackBone8x(nn.Module): class VoxelBackBone8x(nn.Module):
def __init__(self, model_cfg, input_channels, grid_size, **kwargs): def __init__(self, model_cfg, input_channels, grid_size, **kwargs):
super().__init__() super().__init__()
...@@ -121,3 +161,101 @@ class VoxelBackBone8x(nn.Module): ...@@ -121,3 +161,101 @@ class VoxelBackBone8x(nn.Module):
}) })
return batch_dict return batch_dict
class VoxelResBackBone8x(nn.Module):
def __init__(self, model_cfg, input_channels, grid_size, **kwargs):
super().__init__()
self.model_cfg = model_cfg
norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
self.sparse_shape = grid_size[::-1] + [1, 0, 0]
self.conv_input = spconv.SparseSequential(
spconv.SubMConv3d(input_channels, 16, 3, padding=1, bias=False, indice_key='subm1'),
norm_fn(16),
nn.ReLU(),
)
block = post_act_block
self.conv1 = spconv.SparseSequential(
SparseBasicBlock(16, 16, norm_fn=norm_fn, indice_key='res1'),
SparseBasicBlock(16, 16, norm_fn=norm_fn, indice_key='res1'),
)
self.conv2 = spconv.SparseSequential(
# [1600, 1408, 41] <- [800, 704, 21]
block(16, 32, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv2', conv_type='spconv'),
SparseBasicBlock(32, 32, norm_fn=norm_fn, indice_key='res2'),
SparseBasicBlock(32, 32, norm_fn=norm_fn, indice_key='res2'),
)
self.conv3 = spconv.SparseSequential(
# [800, 704, 21] <- [400, 352, 11]
block(32, 64, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv3', conv_type='spconv'),
SparseBasicBlock(64, 64, norm_fn=norm_fn, indice_key='res3'),
SparseBasicBlock(64, 64, norm_fn=norm_fn, indice_key='res3'),
)
self.conv4 = spconv.SparseSequential(
# [400, 352, 11] <- [200, 176, 5]
block(64, 128, 3, norm_fn=norm_fn, stride=2, padding=(0, 1, 1), indice_key='spconv4', conv_type='spconv'),
SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res4'),
SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res4'),
)
last_pad = 0
last_pad = self.model_cfg.get('last_pad', last_pad)
self.conv_out = spconv.SparseSequential(
# [200, 150, 5] -> [200, 150, 2]
spconv.SparseConv3d(128, 128, (3, 1, 1), stride=(2, 1, 1), padding=last_pad,
bias=False, indice_key='spconv_down2'),
norm_fn(128),
nn.ReLU(),
)
self.num_point_features = 128
def forward(self, batch_dict):
"""
Args:
batch_dict:
batch_size: int
vfe_features: (num_voxels, C)
voxel_coords: (num_voxels, 4), [batch_idx, z_idx, y_idx, x_idx]
Returns:
batch_dict:
encoded_spconv_tensor: sparse tensor
"""
voxel_features, voxel_coords = batch_dict['voxel_features'], batch_dict['voxel_coords']
batch_size = batch_dict['batch_size']
input_sp_tensor = spconv.SparseConvTensor(
features=voxel_features,
indices=voxel_coords.int(),
spatial_shape=self.sparse_shape,
batch_size=batch_size
)
x = self.conv_input(input_sp_tensor)
x_conv1 = self.conv1(x)
x_conv2 = self.conv2(x_conv1)
x_conv3 = self.conv3(x_conv2)
x_conv4 = self.conv4(x_conv3)
# for detection head
# [200, 176, 5] -> [200, 176, 2]
out = self.conv_out(x_conv4)
batch_dict.update({
'encoded_spconv_tensor': out,
'encoded_spconv_tensor_stride': 8
})
batch_dict.update({
'multi_scale_3d_features': {
'x_conv1': x_conv1,
'x_conv2': x_conv2,
'x_conv3': x_conv3,
'x_conv4': x_conv4,
}
})
return batch_dict
from functools import partial
import spconv
import torch import torch
import torch.nn as nn import torch.nn as nn
import spconv
from functools import partial
from .spconv_backbone import post_act_block
from ...utils import common_utils from ...utils import common_utils
from .spconv_backbone import post_act_block
class SparseBasicBlock(spconv.SparseModule): class SparseBasicBlock(spconv.SparseModule):
...@@ -91,16 +93,18 @@ class UNetV2(nn.Module): ...@@ -91,16 +93,18 @@ class UNetV2(nn.Module):
block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'), block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'),
) )
last_pad = 0 if self.model_cfg.get('RETURN_ENCODED_TENSOR', True):
last_pad = self.model_cfg.get('last_pad', last_pad) last_pad = self.model_cfg.get('last_pad', 0)
self.conv_out = spconv.SparseSequential( self.conv_out = spconv.SparseSequential(
# [200, 150, 5] -> [200, 150, 2] # [200, 150, 5] -> [200, 150, 2]
spconv.SparseConv3d(64, 128, (3, 1, 1), stride=(2, 1, 1), padding=last_pad, spconv.SparseConv3d(64, 128, (3, 1, 1), stride=(2, 1, 1), padding=last_pad,
bias=False, indice_key='spconv_down2'), bias=False, indice_key='spconv_down2'),
norm_fn(128), norm_fn(128),
nn.ReLU(), nn.ReLU(),
) )
else:
self.conv_out = None
# decoder # decoder
# [400, 352, 11] <- [200, 176, 5] # [400, 352, 11] <- [200, 176, 5]
...@@ -181,9 +185,12 @@ class UNetV2(nn.Module): ...@@ -181,9 +185,12 @@ class UNetV2(nn.Module):
x_conv3 = self.conv3(x_conv2) x_conv3 = self.conv3(x_conv2)
x_conv4 = self.conv4(x_conv3) x_conv4 = self.conv4(x_conv3)
# for detection head if self.conv_out is not None:
# [200, 176, 5] -> [200, 176, 2] # for detection head
out = self.conv_out(x_conv4) # [200, 176, 5] -> [200, 176, 2]
out = self.conv_out(x_conv4)
batch_dict['encoded_spconv_tensor'] = out
batch_dict['encoded_spconv_tensor_stride'] = 8
# for segmentation head # for segmentation head
# [400, 352, 11] <- [200, 176, 5] # [400, 352, 11] <- [200, 176, 5]
...@@ -201,6 +208,4 @@ class UNetV2(nn.Module): ...@@ -201,6 +208,4 @@ class UNetV2(nn.Module):
point_cloud_range=self.point_cloud_range point_cloud_range=self.point_cloud_range
) )
batch_dict['point_coords'] = torch.cat((x_up1.indices[:, 0:1].float(), point_coords), dim=1) batch_dict['point_coords'] = torch.cat((x_up1.indices[:, 0:1].float(), point_coords), dim=1)
batch_dict['encoded_spconv_tensor'] = out
batch_dict['encoded_spconv_tensor_stride'] = 8
return batch_dict return batch_dict
from .vfe_template import VFETemplate
from .mean_vfe import MeanVFE from .mean_vfe import MeanVFE
from .pillar_vfe import PillarVFE from .pillar_vfe import PillarVFE
from .vfe_template import VFETemplate
__all__ = { __all__ = {
'VFETemplate': VFETemplate, 'VFETemplate': VFETemplate,
......
import torch import torch
from .vfe_template import VFETemplate from .vfe_template import VFETemplate
......
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .vfe_template import VFETemplate from .vfe_template import VFETemplate
......
from .anchor_head_template import AnchorHeadTemplate from .anchor_head_multi import AnchorHeadMulti
from .anchor_head_single import AnchorHeadSingle from .anchor_head_single import AnchorHeadSingle
from .point_intra_part_head import PointIntraPartOffsetHead from .anchor_head_template import AnchorHeadTemplate
from .point_head_box import PointHeadBox
from .point_head_simple import PointHeadSimple from .point_head_simple import PointHeadSimple
from .anchor_head_multi import AnchorHeadMulti from .point_intra_part_head import PointIntraPartOffsetHead
__all__ = { __all__ = {
'AnchorHeadTemplate': AnchorHeadTemplate, 'AnchorHeadTemplate': AnchorHeadTemplate,
'AnchorHeadSingle': AnchorHeadSingle, 'AnchorHeadSingle': AnchorHeadSingle,
'PointIntraPartOffsetHead': PointIntraPartOffsetHead, 'PointIntraPartOffsetHead': PointIntraPartOffsetHead,
'PointHeadSimple': PointHeadSimple, 'PointHeadSimple': PointHeadSimple,
'PointHeadBox': PointHeadBox,
'AnchorHeadMulti': AnchorHeadMulti, 'AnchorHeadMulti': AnchorHeadMulti,
} }
import numpy as np import numpy as np
import torch
import torch.nn as nn import torch.nn as nn
from .anchor_head_template import AnchorHeadTemplate
from ..backbones_2d import BaseBEVBackbone from ..backbones_2d import BaseBEVBackbone
import torch from .anchor_head_template import AnchorHeadTemplate
class SingleHead(BaseBEVBackbone): class SingleHead(BaseBEVBackbone):
def __init__(self, model_cfg, input_channels, num_class, num_anchors_per_location, code_size, encode_conv_cfg=None): def __init__(self, model_cfg, input_channels, num_class, num_anchors_per_location, code_size, rpn_head_cfg=None,
super().__init__(encode_conv_cfg, input_channels) head_label_indices=None, separate_reg_config=None):
super().__init__(rpn_head_cfg, input_channels)
self.num_anchors_per_location = num_anchors_per_location self.num_anchors_per_location = num_anchors_per_location
self.num_class = num_class self.num_class = num_class
self.code_size = code_size self.code_size = code_size
self.model_cfg = model_cfg self.model_cfg = model_cfg
self.separate_reg_config = separate_reg_config
self.register_buffer('head_label_indices', head_label_indices)
self.conv_cls = nn.Conv2d( if self.separate_reg_config is not None:
input_channels, self.num_anchors_per_location * self.num_class, code_size_cnt = 0
kernel_size=1 self.conv_box = nn.ModuleDict()
) self.conv_box_names = []
self.conv_box = nn.Conv2d( num_middle_conv = self.separate_reg_config.NUM_MIDDLE_CONV
input_channels, self.num_anchors_per_location * self.code_size, num_middle_filter = self.separate_reg_config.NUM_MIDDLE_FILTER
kernel_size=1 conv_cls_list = []
) c_in = input_channels
for k in range(num_middle_conv):
conv_cls_list.extend([
nn.Conv2d(
c_in, num_middle_filter,
kernel_size=3, stride=1, padding=1, bias=False
),
nn.BatchNorm2d(num_middle_filter),
nn.ReLU()
])
c_in = num_middle_filter
conv_cls_list.append(nn.Conv2d(
c_in, self.num_anchors_per_location * self.num_class,
kernel_size=3, stride=1, padding=1
))
self.conv_cls = nn.Sequential(*conv_cls_list)
for reg_config in self.separate_reg_config.REG_LIST:
reg_name, reg_channel = reg_config.split(':')
reg_channel = int(reg_channel)
cur_conv_list = []
c_in = input_channels
for k in range(num_middle_conv):
cur_conv_list.extend([
nn.Conv2d(
c_in, num_middle_filter,
kernel_size=3, stride=1, padding=1, bias=False
),
nn.BatchNorm2d(num_middle_filter),
nn.ReLU()
])
c_in = num_middle_filter
cur_conv_list.append(nn.Conv2d(
c_in, self.num_anchors_per_location * int(reg_channel),
kernel_size=3, stride=1, padding=1, bias=True
))
code_size_cnt += reg_channel
self.conv_box[f'conv_{reg_name}'] = nn.Sequential(*cur_conv_list)
self.conv_box_names.append(f'conv_{reg_name}')
for m in self.conv_box.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
assert code_size_cnt == code_size, f'Code size does not match: {code_size_cnt}:{code_size}'
else:
self.conv_cls = nn.Conv2d(
input_channels, self.num_anchors_per_location * self.num_class,
kernel_size=1
)
self.conv_box = nn.Conv2d(
input_channels, self.num_anchors_per_location * self.code_size,
kernel_size=1
)
if self.model_cfg.get('USE_DIRECTION_CLASSIFIER', None) is not None: if self.model_cfg.get('USE_DIRECTION_CLASSIFIER', None) is not None:
self.conv_dir_cls = nn.Conv2d( self.conv_dir_cls = nn.Conv2d(
...@@ -31,19 +91,29 @@ class SingleHead(BaseBEVBackbone): ...@@ -31,19 +91,29 @@ class SingleHead(BaseBEVBackbone):
) )
else: else:
self.conv_dir_cls = None self.conv_dir_cls = None
self.use_multihead = self.model_cfg.get('USE_MULTI_HEAD', False) self.use_multihead = self.model_cfg.get('USE_MULTIHEAD', False)
self.init_weights() self.init_weights()
def init_weights(self): def init_weights(self):
pi = 0.01 pi = 0.01
nn.init.constant_(self.conv_cls.bias, -np.log((1 - pi) / pi)) if isinstance(self.conv_cls, nn.Conv2d):
nn.init.constant_(self.conv_cls.bias, -np.log((1 - pi) / pi))
else:
nn.init.constant_(self.conv_cls[-1].bias, -np.log((1 - pi) / pi))
def forward(self, spatial_features_2d): def forward(self, spatial_features_2d):
ret_dict = {} ret_dict = {}
spatial_features_2d = super().forward({'spatial_features': spatial_features_2d})['spatial_features_2d'] spatial_features_2d = super().forward({'spatial_features': spatial_features_2d})['spatial_features_2d']
cls_preds = self.conv_cls(spatial_features_2d) cls_preds = self.conv_cls(spatial_features_2d)
box_preds = self.conv_box(spatial_features_2d)
if self.separate_reg_config is None:
box_preds = self.conv_box(spatial_features_2d)
else:
box_preds_list = []
for reg_name in self.conv_box_names:
box_preds_list.append(self.conv_box[reg_name](spatial_features_2d))
box_preds = torch.cat(box_preds_list, dim=1)
if not self.use_multihead: if not self.use_multihead:
box_preds = box_preds.permute(0, 2, 3, 1).contiguous() box_preds = box_preds.permute(0, 2, 3, 1).contiguous()
...@@ -56,13 +126,14 @@ class SingleHead(BaseBEVBackbone): ...@@ -56,13 +126,14 @@ class SingleHead(BaseBEVBackbone):
cls_preds = cls_preds.view(-1, self.num_anchors_per_location, cls_preds = cls_preds.view(-1, self.num_anchors_per_location,
self.num_class, H, W).permute(0, 1, 3, 4, 2).contiguous() self.num_class, H, W).permute(0, 1, 3, 4, 2).contiguous()
box_preds = box_preds.view(batch_size, -1, self.code_size) box_preds = box_preds.view(batch_size, -1, self.code_size)
cls_preds = cls_preds.view(batch_size, -1, self.num_class).unsqueeze(-1) cls_preds = cls_preds.view(batch_size, -1, self.num_class)
if self.conv_dir_cls is not None: if self.conv_dir_cls is not None:
dir_cls_preds = self.conv_dir_cls(spatial_features_2d) dir_cls_preds = self.conv_dir_cls(spatial_features_2d)
if self.use_multihead: if self.use_multihead:
dir_cls_preds = dir_cls_preds.view( dir_cls_preds = dir_cls_preds.view(
-1, self.num_anchors_per_location, self.model_cfg.NUM_DIR_BINS, H, W).permute(0, 1, 3, 4, 2).contiguous() -1, self.num_anchors_per_location, self.model_cfg.NUM_DIR_BINS, H, W).permute(0, 1, 3, 4,
2).contiguous()
dir_cls_preds = dir_cls_preds.view(batch_size, -1, self.model_cfg.NUM_DIR_BINS) dir_cls_preds = dir_cls_preds.view(batch_size, -1, self.model_cfg.NUM_DIR_BINS)
else: else:
dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).contiguous() dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).contiguous()
...@@ -78,12 +149,27 @@ class SingleHead(BaseBEVBackbone): ...@@ -78,12 +149,27 @@ class SingleHead(BaseBEVBackbone):
class AnchorHeadMulti(AnchorHeadTemplate): class AnchorHeadMulti(AnchorHeadTemplate):
def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range, predict_boxes_when_training=True): def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range,
predict_boxes_when_training=True):
super().__init__( super().__init__(
model_cfg=model_cfg, num_class=num_class, class_names=class_names, grid_size=grid_size, point_cloud_range=point_cloud_range, predict_boxes_when_training=predict_boxes_when_training model_cfg=model_cfg, num_class=num_class, class_names=class_names, grid_size=grid_size,
point_cloud_range=point_cloud_range, predict_boxes_when_training=predict_boxes_when_training
) )
self.model_cfg = model_cfg self.model_cfg = model_cfg
self.make_multihead(input_channels) self.separate_multihead = self.model_cfg.get('SEPARATE_MULTIHEAD', False)
if self.model_cfg.get('SHARED_CONV_NUM_FILTER', None) is not None:
shared_conv_num_filter = self.model_cfg.SHARED_CONV_NUM_FILTER
self.shared_conv = nn.Sequential(
nn.Conv2d(input_channels, shared_conv_num_filter, 3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(shared_conv_num_filter, eps=1e-3, momentum=0.01),
nn.ReLU(),
)
else:
self.shared_conv = None
shared_conv_num_filter = input_channels
self.rpn_heads = None
self.make_multihead(shared_conv_num_filter)
def make_multihead(self, input_channels): def make_multihead(self, input_channels):
rpn_head_cfgs = self.model_cfg.RPN_HEAD_CFGS rpn_head_cfgs = self.model_cfg.RPN_HEAD_CFGS
...@@ -91,34 +177,46 @@ class AnchorHeadMulti(AnchorHeadTemplate): ...@@ -91,34 +177,46 @@ class AnchorHeadMulti(AnchorHeadTemplate):
class_names = [] class_names = []
for rpn_head_cfg in rpn_head_cfgs: for rpn_head_cfg in rpn_head_cfgs:
class_names.extend(rpn_head_cfg['HEAD_CLS_NAME']) class_names.extend(rpn_head_cfg['HEAD_CLS_NAME'])
for rpn_head_cfg in rpn_head_cfgs: for rpn_head_cfg in rpn_head_cfgs:
num_anchors_per_location = sum([self.num_anchors_per_location[class_names.index(head_cls)] for head_cls in rpn_head_cfg['HEAD_CLS_NAME']]) num_anchors_per_location = sum([self.num_anchors_per_location[class_names.index(head_cls)]
rpn_head = SingleHead(self.model_cfg, input_channels, self.num_class, num_anchors_per_location, self.box_coder.code_size, rpn_head_cfg) for head_cls in rpn_head_cfg['HEAD_CLS_NAME']])
head_label_indices = torch.from_numpy(np.array([
self.class_names.index(cur_name) + 1 for cur_name in rpn_head_cfg['HEAD_CLS_NAME']
]))
rpn_head = SingleHead(
self.model_cfg, input_channels,
len(rpn_head_cfg['HEAD_CLS_NAME']) if self.separate_multihead else self.num_class,
num_anchors_per_location, self.box_coder.code_size, rpn_head_cfg,
head_label_indices=head_label_indices,
separate_reg_config=self.model_cfg.get('SEPARATE_REG_CONFIG', None)
)
rpn_heads.append(rpn_head) rpn_heads.append(rpn_head)
self.rpn_heads = nn.ModuleList(rpn_heads) self.rpn_heads = nn.ModuleList(rpn_heads)
def forward(self, data_dict): def forward(self, data_dict):
spatial_features_2d = data_dict['spatial_features_2d'] spatial_features_2d = data_dict['spatial_features_2d']
if self.shared_conv is not None:
spatial_features_2d = self.shared_conv(spatial_features_2d)
ret_dicts = [] ret_dicts = []
for rpn_head in self.rpn_heads: for rpn_head in self.rpn_heads:
ret_dicts.append(rpn_head(spatial_features_2d)) ret_dicts.append(rpn_head(spatial_features_2d))
cls_preds = torch.cat([ret_dict['cls_preds'] for ret_dict in ret_dicts], dim=1) cls_preds = [ret_dict['cls_preds'] for ret_dict in ret_dicts]
box_preds = torch.cat([ret_dict['box_preds'] for ret_dict in ret_dicts], dim=1) box_preds = [ret_dict['box_preds'] for ret_dict in ret_dicts]
ret = { ret = {
'cls_preds': cls_preds, 'cls_preds': cls_preds if self.separate_multihead else torch.cat(cls_preds, dim=1),
'box_preds': box_preds, 'box_preds': box_preds if self.separate_multihead else torch.cat(box_preds, dim=1),
} }
if self.model_cfg.get('USE_DIRECTION_CLASSIFIER', False): if self.model_cfg.get('USE_DIRECTION_CLASSIFIER', False):
dir_cls_preds = torch.cat([ret_dict['dir_cls_preds'] for ret_dict in ret_dicts], dim=1) dir_cls_preds = [ret_dict['dir_cls_preds'] for ret_dict in ret_dicts]
ret['dir_cls_preds'] = dir_cls_preds ret['dir_cls_preds'] = dir_cls_preds if self.separate_multihead else torch.cat(dir_cls_preds, dim=1)
else:
dir_cls_preds = None
self.forward_ret_dict.update(ret) self.forward_ret_dict.update(ret)
if self.training: if self.training:
targets_dict = self.assign_targets( targets_dict = self.assign_targets(
gt_boxes=data_dict['gt_boxes'] gt_boxes=data_dict['gt_boxes']
...@@ -128,10 +226,148 @@ class AnchorHeadMulti(AnchorHeadTemplate): ...@@ -128,10 +226,148 @@ class AnchorHeadMulti(AnchorHeadTemplate):
if not self.training or self.predict_boxes_when_training: if not self.training or self.predict_boxes_when_training:
batch_cls_preds, batch_box_preds = self.generate_predicted_boxes( batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
batch_size=data_dict['batch_size'], batch_size=data_dict['batch_size'],
cls_preds=cls_preds, box_preds=box_preds, dir_cls_preds=dir_cls_preds cls_preds=ret['cls_preds'], box_preds=ret['box_preds'], dir_cls_preds=ret.get('dir_cls_preds', None)
) )
if isinstance(batch_cls_preds, list):
multihead_label_mapping = []
for idx in range(len(batch_cls_preds)):
multihead_label_mapping.append(self.rpn_heads[idx].head_label_indices)
data_dict['multihead_label_mapping'] = multihead_label_mapping
data_dict['batch_cls_preds'] = batch_cls_preds data_dict['batch_cls_preds'] = batch_cls_preds
data_dict['batch_box_preds'] = batch_box_preds data_dict['batch_box_preds'] = batch_box_preds
data_dict['cls_preds_normalized'] = False data_dict['cls_preds_normalized'] = False
return data_dict return data_dict
def get_cls_layer_loss(self):
loss_weights = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
if 'pos_cls_weight' in loss_weights:
pos_cls_weight = loss_weights['pos_cls_weight']
neg_cls_weight = loss_weights['neg_cls_weight']
else:
pos_cls_weight = neg_cls_weight = 1.0
cls_preds = self.forward_ret_dict['cls_preds']
box_cls_labels = self.forward_ret_dict['box_cls_labels']
if not isinstance(cls_preds, list):
cls_preds = [cls_preds]
batch_size = int(cls_preds[0].shape[0])
cared = box_cls_labels >= 0 # [N, num_anchors]
positives = box_cls_labels > 0
negatives = box_cls_labels == 0
negative_cls_weights = negatives * 1.0 * neg_cls_weight
cls_weights = (negative_cls_weights + pos_cls_weight * positives).float()
reg_weights = positives.float()
if self.num_class == 1:
# class agnostic
box_cls_labels[positives] = 1
pos_normalizer = positives.sum(1, keepdim=True).float()
reg_weights /= torch.clamp(pos_normalizer, min=1.0)
cls_weights /= torch.clamp(pos_normalizer, min=1.0)
cls_targets = box_cls_labels * cared.type_as(box_cls_labels)
one_hot_targets = torch.zeros(
*list(cls_targets.shape), self.num_class + 1, dtype=cls_preds[0].dtype, device=cls_targets.device
)
one_hot_targets.scatter_(-1, cls_targets.unsqueeze(dim=-1).long(), 1.0)
one_hot_targets = one_hot_targets[..., 1:]
start_idx = c_idx = 0
cls_losses = 0
for idx, cls_pred in enumerate(cls_preds):
cur_num_class = self.rpn_heads[idx].num_class
cls_pred = cls_pred.view(batch_size, -1, cur_num_class)
if self.separate_multihead:
one_hot_target = one_hot_targets[:, start_idx:start_idx + cls_pred.shape[1],
c_idx:c_idx + cur_num_class]
c_idx += cur_num_class
else:
one_hot_target = one_hot_targets[:, start_idx:start_idx + cls_pred.shape[1]]
cls_weight = cls_weights[:, start_idx:start_idx + cls_pred.shape[1]]
cls_loss_src = self.cls_loss_func(cls_pred, one_hot_target, weights=cls_weight) # [N, M]
cls_loss = cls_loss_src.sum() / batch_size
cls_loss = cls_loss * loss_weights['cls_weight']
cls_losses += cls_loss
start_idx += cls_pred.shape[1]
assert start_idx == one_hot_targets.shape[1]
tb_dict = {
'rpn_loss_cls': cls_losses.item()
}
return cls_losses, tb_dict
def get_box_reg_layer_loss(self):
box_preds = self.forward_ret_dict['box_preds']
box_dir_cls_preds = self.forward_ret_dict.get('dir_cls_preds', None)
box_reg_targets = self.forward_ret_dict['box_reg_targets']
box_cls_labels = self.forward_ret_dict['box_cls_labels']
positives = box_cls_labels > 0
reg_weights = positives.float()
pos_normalizer = positives.sum(1, keepdim=True).float()
reg_weights /= torch.clamp(pos_normalizer, min=1.0)
if not isinstance(box_preds, list):
box_preds = [box_preds]
batch_size = int(box_preds[0].shape[0])
if isinstance(self.anchors, list):
if self.use_multihead:
anchors = torch.cat(
[anchor.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchor.shape[-1])
for anchor in self.anchors], dim=0
)
else:
anchors = torch.cat(self.anchors, dim=-3)
else:
anchors = self.anchors
anchors = anchors.view(1, -1, anchors.shape[-1]).repeat(batch_size, 1, 1)
start_idx = 0
box_losses = 0
tb_dict = {}
for idx, box_pred in enumerate(box_preds):
box_pred = box_pred.view(
batch_size, -1,
box_pred.shape[-1] // self.num_anchors_per_location if not self.use_multihead else box_pred.shape[-1]
)
box_reg_target = box_reg_targets[:, start_idx:start_idx + box_pred.shape[1]]
reg_weight = reg_weights[:, start_idx:start_idx + box_pred.shape[1]]
# sin(a - b) = sinacosb-cosasinb
if box_dir_cls_preds is not None:
box_pred_sin, reg_target_sin = self.add_sin_difference(box_pred, box_reg_target)
loc_loss_src = self.reg_loss_func(box_pred_sin, reg_target_sin, weights=reg_weight) # [N, M]
else:
loc_loss_src = self.reg_loss_func(box_pred, box_reg_target, weights=reg_weight) # [N, M]
loc_loss = loc_loss_src.sum() / batch_size
loc_loss = loc_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['loc_weight']
box_losses += loc_loss
tb_dict['rpn_loss_loc'] = tb_dict.get('rpn_loss_loc', 0) + loc_loss.item()
if box_dir_cls_preds is not None:
if not isinstance(box_dir_cls_preds, list):
box_dir_cls_preds = [box_dir_cls_preds]
dir_targets = self.get_direction_target(
anchors, box_reg_targets,
dir_offset=self.model_cfg.DIR_OFFSET,
num_bins=self.model_cfg.NUM_DIR_BINS
)
box_dir_cls_pred = box_dir_cls_preds[idx]
dir_logit = box_dir_cls_pred.view(batch_size, -1, self.model_cfg.NUM_DIR_BINS)
weights = positives.type_as(dir_logit)
weights /= torch.clamp(weights.sum(-1, keepdim=True), min=1.0)
weight = weights[:, start_idx:start_idx + box_pred.shape[1]]
dir_target = dir_targets[:, start_idx:start_idx + box_pred.shape[1]]
dir_loss = self.dir_loss_func(dir_logit, dir_target, weights=weight)
dir_loss = dir_loss.sum() / batch_size
dir_loss = dir_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['dir_weight']
box_losses += dir_loss
tb_dict['rpn_loss_dir'] = tb_dict.get('rpn_loss_dir', 0) + dir_loss.item()
start_idx += box_pred.shape[1]
return box_losses, tb_dict
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
from .anchor_head_template import AnchorHeadTemplate from .anchor_head_template import AnchorHeadTemplate
...@@ -72,4 +73,3 @@ class AnchorHeadSingle(AnchorHeadTemplate): ...@@ -72,4 +73,3 @@ class AnchorHeadSingle(AnchorHeadTemplate):
data_dict['cls_preds_normalized'] = False data_dict['cls_preds_normalized'] = False
return data_dict return data_dict
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from ...utils import box_coder_utils, common_utils, loss_utils
from .target_assigner.anchor_generator import AnchorGenerator from .target_assigner.anchor_generator import AnchorGenerator
from .target_assigner.atss_target_assigner import ATSSTargetAssigner from .target_assigner.atss_target_assigner import ATSSTargetAssigner
from .target_assigner.axis_aligned_target_assigner import AxisAlignedTargetAssigner from .target_assigner.axis_aligned_target_assigner import AxisAlignedTargetAssigner
from ...utils import box_coder_utils, loss_utils, common_utils
class AnchorHeadTemplate(nn.Module): class AnchorHeadTemplate(nn.Module):
...@@ -14,44 +15,53 @@ class AnchorHeadTemplate(nn.Module): ...@@ -14,44 +15,53 @@ class AnchorHeadTemplate(nn.Module):
self.num_class = num_class self.num_class = num_class
self.class_names = class_names self.class_names = class_names
self.predict_boxes_when_training = predict_boxes_when_training self.predict_boxes_when_training = predict_boxes_when_training
self.use_multihead = self.model_cfg.get('USE_MULTI_HEAD', False) self.use_multihead = self.model_cfg.get('USE_MULTIHEAD', False)
anchor_target_cfg = self.model_cfg.TARGET_ASSIGNER_CONFIG anchor_target_cfg = self.model_cfg.TARGET_ASSIGNER_CONFIG
self.box_coder = getattr(box_coder_utils, anchor_target_cfg.BOX_CODER)( self.box_coder = getattr(box_coder_utils, anchor_target_cfg.BOX_CODER)(
num_dir_bins=anchor_target_cfg.get('NUM_DIR_BINS', 6) num_dir_bins=anchor_target_cfg.get('NUM_DIR_BINS', 6),
**anchor_target_cfg.get('BOX_CODER_CONFIG', {})
) )
anchor_generator_cfg = self.model_cfg.ANCHOR_GENERATOR_CONFIG anchor_generator_cfg = self.model_cfg.ANCHOR_GENERATOR_CONFIG
anchors, self.num_anchors_per_location = self.generate_anchors( anchors, self.num_anchors_per_location = self.generate_anchors(
anchor_generator_cfg, grid_size=grid_size, point_cloud_range=point_cloud_range anchor_generator_cfg, grid_size=grid_size, point_cloud_range=point_cloud_range,
anchor_ndim=self.box_coder.code_size
) )
self.anchors = [x.cuda() for x in anchors] self.anchors = [x.cuda() for x in anchors]
self.target_assigner = self.get_target_assigner(anchor_target_cfg, anchor_generator_cfg) self.target_assigner = self.get_target_assigner(anchor_target_cfg)
self.forward_ret_dict = {} self.forward_ret_dict = {}
self.build_losses(self.model_cfg.LOSS_CONFIG) self.build_losses(self.model_cfg.LOSS_CONFIG)
@staticmethod @staticmethod
def generate_anchors(anchor_generator_cfg, grid_size, point_cloud_range): def generate_anchors(anchor_generator_cfg, grid_size, point_cloud_range, anchor_ndim=7):
anchor_generator = AnchorGenerator( anchor_generator = AnchorGenerator(
anchor_range=point_cloud_range, anchor_range=point_cloud_range,
anchor_generator_config=anchor_generator_cfg anchor_generator_config=anchor_generator_cfg
) )
feature_map_size = [grid_size[:2] // config['feature_map_stride'] for config in anchor_generator_cfg] feature_map_size = [grid_size[:2] // config['feature_map_stride'] for config in anchor_generator_cfg]
anchors_list, num_anchors_per_location_list = anchor_generator.generate_anchors(feature_map_size) anchors_list, num_anchors_per_location_list = anchor_generator.generate_anchors(feature_map_size)
if anchor_ndim != 7:
for idx, anchors in enumerate(anchors_list):
pad_zeros = anchors.new_zeros([*anchors.shape[0:-1], anchor_ndim - 7])
new_anchors = torch.cat((anchors, pad_zeros), dim=-1)
anchors_list[idx] = new_anchors
return anchors_list, num_anchors_per_location_list return anchors_list, num_anchors_per_location_list
def get_target_assigner(self, anchor_target_cfg, anchor_generator_cfg): def get_target_assigner(self, anchor_target_cfg):
if anchor_target_cfg.NAME == 'ATSS': if anchor_target_cfg.NAME == 'ATSS':
target_assigner = ATSSTargetAssigner( target_assigner = ATSSTargetAssigner(
topk=anchor_target_cfg.TOPK, topk=anchor_target_cfg.TOPK,
box_coder=self.box_coder, box_coder=self.box_coder,
use_multihead=self.use_multihead,
match_height=anchor_target_cfg.MATCH_HEIGHT match_height=anchor_target_cfg.MATCH_HEIGHT
) )
elif anchor_target_cfg.NAME == 'AxisAlignedTargetAssigner': elif anchor_target_cfg.NAME == 'AxisAlignedTargetAssigner':
target_assigner = AxisAlignedTargetAssigner( target_assigner = AxisAlignedTargetAssigner(
anchor_target_cfg=anchor_target_cfg, model_cfg=self.model_cfg,
anchor_generator_cfg=anchor_generator_cfg,
class_names=self.class_names, class_names=self.class_names,
box_coder=self.box_coder, box_coder=self.box_coder,
match_height=anchor_target_cfg.MATCH_HEIGHT match_height=anchor_target_cfg.MATCH_HEIGHT
...@@ -65,9 +75,11 @@ class AnchorHeadTemplate(nn.Module): ...@@ -65,9 +75,11 @@ class AnchorHeadTemplate(nn.Module):
'cls_loss_func', 'cls_loss_func',
loss_utils.SigmoidFocalClassificationLoss(alpha=0.25, gamma=2.0) loss_utils.SigmoidFocalClassificationLoss(alpha=0.25, gamma=2.0)
) )
reg_loss_name = 'WeightedSmoothL1Loss' if losses_cfg.get('REG_LOSS_TYPE', None) is None \
else losses_cfg.REG_LOSS_TYPE
self.add_module( self.add_module(
'reg_loss_func', 'reg_loss_func',
loss_utils.WeightedSmoothL1Loss(code_weights=losses_cfg.LOSS_WEIGHTS['code_weights']) getattr(loss_utils, reg_loss_name)(code_weights=losses_cfg.LOSS_WEIGHTS['code_weights'])
) )
self.add_module( self.add_module(
'dir_loss_func', 'dir_loss_func',
...@@ -82,7 +94,7 @@ class AnchorHeadTemplate(nn.Module): ...@@ -82,7 +94,7 @@ class AnchorHeadTemplate(nn.Module):
""" """
targets_dict = self.target_assigner.assign_targets( targets_dict = self.target_assigner.assign_targets(
self.anchors, gt_boxes, self.use_multihead self.anchors, gt_boxes
) )
return targets_dict return targets_dict
...@@ -113,8 +125,6 @@ class AnchorHeadTemplate(nn.Module): ...@@ -113,8 +125,6 @@ class AnchorHeadTemplate(nn.Module):
one_hot_targets.scatter_(-1, cls_targets.unsqueeze(dim=-1).long(), 1.0) one_hot_targets.scatter_(-1, cls_targets.unsqueeze(dim=-1).long(), 1.0)
cls_preds = cls_preds.view(batch_size, -1, self.num_class) cls_preds = cls_preds.view(batch_size, -1, self.num_class)
one_hot_targets = one_hot_targets[..., 1:] one_hot_targets = one_hot_targets[..., 1:]
# import pdb
# pdb.set_trace()
cls_loss_src = self.cls_loss_func(cls_preds, one_hot_targets, weights=cls_weights) # [N, M] cls_loss_src = self.cls_loss_func(cls_preds, one_hot_targets, weights=cls_weights) # [N, M]
cls_loss = cls_loss_src.sum() / batch_size cls_loss = cls_loss_src.sum() / batch_size
...@@ -235,14 +245,17 @@ class AnchorHeadTemplate(nn.Module): ...@@ -235,14 +245,17 @@ class AnchorHeadTemplate(nn.Module):
anchors = self.anchors anchors = self.anchors
num_anchors = anchors.view(-1, anchors.shape[-1]).shape[0] num_anchors = anchors.view(-1, anchors.shape[-1]).shape[0]
batch_anchors = anchors.view(1, -1, anchors.shape[-1]).repeat(batch_size, 1, 1) batch_anchors = anchors.view(1, -1, anchors.shape[-1]).repeat(batch_size, 1, 1)
batch_cls_preds = cls_preds.view(batch_size, num_anchors, -1).float() batch_cls_preds = cls_preds.view(batch_size, num_anchors, -1).float() \
batch_box_preds = box_preds.view(batch_size, num_anchors, -1) if not isinstance(cls_preds, list) else cls_preds
batch_box_preds = box_preds.view(batch_size, num_anchors, -1) if not isinstance(box_preds, list) \
else torch.cat(box_preds, dim=1).view(batch_size, num_anchors, -1)
batch_box_preds = self.box_coder.decode_torch(batch_box_preds, batch_anchors) batch_box_preds = self.box_coder.decode_torch(batch_box_preds, batch_anchors)
if dir_cls_preds is not None: if dir_cls_preds is not None:
dir_offset = self.model_cfg.DIR_OFFSET dir_offset = self.model_cfg.DIR_OFFSET
dir_limit_offset = self.model_cfg.DIR_LIMIT_OFFSET dir_limit_offset = self.model_cfg.DIR_LIMIT_OFFSET
dir_cls_preds = dir_cls_preds.view(batch_size, num_anchors, -1) dir_cls_preds = dir_cls_preds.view(batch_size, num_anchors, -1) if not isinstance(dir_cls_preds, list) \
else torch.cat(dir_cls_preds, dim=1).view(batch_size, num_anchors, -1)
dir_labels = torch.max(dir_cls_preds, dim=-1)[1] dir_labels = torch.max(dir_cls_preds, dim=-1)[1]
period = (2 * np.pi / self.model_cfg.NUM_DIR_BINS) period = (2 * np.pi / self.model_cfg.NUM_DIR_BINS)
......
import torch
from ...utils import box_coder_utils, box_utils
from .point_head_template import PointHeadTemplate
class PointHeadBox(PointHeadTemplate):
"""
A simple point-based segmentation head, which are used for PointRCNN.
Reference Paper: https://arxiv.org/abs/1812.04244
PointRCNN: 3D Object Proposal Generation and Detection from Point Cloud
"""
def __init__(self, num_class, input_channels, model_cfg, predict_boxes_when_training=False, **kwargs):
super().__init__(model_cfg=model_cfg, num_class=num_class)
self.predict_boxes_when_training = predict_boxes_when_training
self.cls_layers = self.make_fc_layers(
fc_cfg=self.model_cfg.CLS_FC,
input_channels=input_channels,
output_channels=num_class
)
target_cfg = self.model_cfg.TARGET_CONFIG
self.box_coder = getattr(box_coder_utils, target_cfg.BOX_CODER)(
**target_cfg.BOX_CODER_CONFIG
)
self.box_layers = self.make_fc_layers(
fc_cfg=self.model_cfg.REG_FC,
input_channels=input_channels,
output_channels=self.box_coder.code_size
)
def assign_targets(self, input_dict):
"""
Args:
input_dict:
point_features: (N1 + N2 + N3 + ..., C)
batch_size:
point_coords: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]
gt_boxes (optional): (B, M, 8)
Returns:
point_cls_labels: (N1 + N2 + N3 + ...), long type, 0:background, -1:ignored
point_part_labels: (N1 + N2 + N3 + ..., 3)
"""
point_coords = input_dict['point_coords']
gt_boxes = input_dict['gt_boxes']
assert gt_boxes.shape.__len__() == 3, 'gt_boxes.shape=%s' % str(gt_boxes.shape)
assert point_coords.shape.__len__() in [2], 'points.shape=%s' % str(point_coords.shape)
batch_size = gt_boxes.shape[0]
extend_gt_boxes = box_utils.enlarge_box3d(
gt_boxes.view(-1, gt_boxes.shape[-1]), extra_width=self.model_cfg.TARGET_CONFIG.GT_EXTRA_WIDTH
).view(batch_size, -1, gt_boxes.shape[-1])
targets_dict = self.assign_stack_targets(
points=point_coords, gt_boxes=gt_boxes, extend_gt_boxes=extend_gt_boxes,
set_ignore_flag=True, use_ball_constraint=False,
ret_part_labels=False, ret_box_labels=True
)
return targets_dict
def get_loss(self, tb_dict=None):
tb_dict = {} if tb_dict is None else tb_dict
point_loss_cls, tb_dict_1 = self.get_cls_layer_loss()
point_loss_box, tb_dict_2 = self.get_box_layer_loss()
point_loss = point_loss_cls + point_loss_box
tb_dict.update(tb_dict_1)
tb_dict.update(tb_dict_2)
return point_loss, tb_dict
def forward(self, batch_dict):
"""
Args:
batch_dict:
batch_size:
point_features: (N1 + N2 + N3 + ..., C) or (B, N, C)
point_features_before_fusion: (N1 + N2 + N3 + ..., C)
point_coords: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]
point_labels (optional): (N1 + N2 + N3 + ...)
gt_boxes (optional): (B, M, 8)
Returns:
batch_dict:
point_cls_scores: (N1 + N2 + N3 + ..., 1)
point_part_offset: (N1 + N2 + N3 + ..., 3)
"""
if self.model_cfg.get('USE_POINT_FEATURES_BEFORE_FUSION', False):
point_features = batch_dict['point_features_before_fusion']
else:
point_features = batch_dict['point_features']
point_cls_preds = self.cls_layers(point_features) # (total_points, num_class)
point_box_preds = self.box_layers(point_features) # (total_points, box_code_size)
point_cls_preds_max, _ = point_cls_preds.max(dim=-1)
batch_dict['point_cls_scores'] = torch.sigmoid(point_cls_preds_max)
ret_dict = {'point_cls_preds': point_cls_preds,
'point_box_preds': point_box_preds}
if self.training:
targets_dict = self.assign_targets(batch_dict)
ret_dict['point_cls_labels'] = targets_dict['point_cls_labels']
ret_dict['point_box_labels'] = targets_dict['point_box_labels']
if not self.training or self.predict_boxes_when_training:
point_cls_preds, point_box_preds = self.generate_predicted_boxes(
points=batch_dict['point_coords'][:, 1:4],
point_cls_preds=point_cls_preds, point_box_preds=point_box_preds
)
batch_dict['batch_cls_preds'] = point_cls_preds
batch_dict['batch_box_preds'] = point_box_preds
batch_dict['batch_index'] = batch_dict['point_coords'][:, 0]
batch_dict['cls_preds_normalized'] = False
self.forward_ret_dict = ret_dict
return batch_dict
import torch import torch
from .point_head_template import PointHeadTemplate
from ...utils import box_utils from ...utils import box_utils
from .point_head_template import PointHeadTemplate
class PointHeadSimple(PointHeadTemplate): class PointHeadSimple(PointHeadTemplate):
......
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from ...utils import loss_utils, common_utils
from ...ops.roiaware_pool3d import roiaware_pool3d_utils from ...ops.roiaware_pool3d import roiaware_pool3d_utils
from ...utils import common_utils, loss_utils
class PointHeadTemplate(nn.Module): class PointHeadTemplate(nn.Module):
...@@ -19,7 +20,17 @@ class PointHeadTemplate(nn.Module): ...@@ -19,7 +20,17 @@ class PointHeadTemplate(nn.Module):
'cls_loss_func', 'cls_loss_func',
loss_utils.SigmoidFocalClassificationLoss(alpha=0.25, gamma=2.0) loss_utils.SigmoidFocalClassificationLoss(alpha=0.25, gamma=2.0)
) )
self.reg_loss_func = F.smooth_l1_loss if losses_cfg.get('LOSS_REG', None) == 'smooth-l1' else F.l1_loss reg_loss_type = losses_cfg.get('LOSS_REG', None)
if reg_loss_type == 'smooth-l1':
self.reg_loss_func = F.smooth_l1_loss
elif reg_loss_type == 'l1':
self.reg_loss_func = F.l1_loss
elif reg_loss_type == 'WeightedSmoothL1Loss':
self.reg_loss_func = loss_utils.WeightedSmoothL1Loss(
code_weights=losses_cfg.LOSS_WEIGHTS.get('code_weights', None)
)
else:
self.reg_loss_func = F.smooth_l1_loss
@staticmethod @staticmethod
def make_fc_layers(fc_cfg, input_channels, output_channels): def make_fc_layers(fc_cfg, input_channels, output_channels):
...@@ -88,11 +99,15 @@ class PointHeadTemplate(nn.Module): ...@@ -88,11 +99,15 @@ class PointHeadTemplate(nn.Module):
raise NotImplementedError raise NotImplementedError
gt_box_of_fg_points = gt_boxes[k][box_idxs_of_pts[fg_flag]] gt_box_of_fg_points = gt_boxes[k][box_idxs_of_pts[fg_flag]]
point_cls_labels_single[fg_flag] = 1 if self.num_class == 1 else gt_box_of_fg_points[:, 7].long() point_cls_labels_single[fg_flag] = 1 if self.num_class == 1 else gt_box_of_fg_points[:, -1].long()
point_cls_labels[bs_mask] = point_cls_labels_single point_cls_labels[bs_mask] = point_cls_labels_single
if ret_box_labels: if ret_box_labels:
point_box_labels_single = point_box_labels.new_zeros((bs_mask.sum(), 8)) point_box_labels_single = point_box_labels.new_zeros((bs_mask.sum(), 8))
fg_point_box_labels = self.box_coder.encode_torch(points_single[fg_flag], gt_box_of_fg_points) fg_point_box_labels = self.box_coder.encode_torch(
gt_boxes=gt_box_of_fg_points[:, :-1], points=points_single[fg_flag],
gt_classes=gt_box_of_fg_points[:, -1].long()
)
point_box_labels_single[fg_flag] = fg_point_box_labels point_box_labels_single[fg_flag] = fg_point_box_labels
point_box_labels[bs_mask] = point_box_labels_single point_box_labels[bs_mask] = point_box_labels_single
...@@ -113,7 +128,7 @@ class PointHeadTemplate(nn.Module): ...@@ -113,7 +128,7 @@ class PointHeadTemplate(nn.Module):
} }
return targets_dict return targets_dict
def get_cls_layer_loss(self): def get_cls_layer_loss(self, tb_dict=None):
point_cls_labels = self.forward_ret_dict['point_cls_labels'].view(-1) point_cls_labels = self.forward_ret_dict['point_cls_labels'].view(-1)
point_cls_preds = self.forward_ret_dict['point_cls_preds'].view(-1, self.num_class) point_cls_preds = self.forward_ret_dict['point_cls_preds'].view(-1, self.num_class)
...@@ -131,13 +146,15 @@ class PointHeadTemplate(nn.Module): ...@@ -131,13 +146,15 @@ class PointHeadTemplate(nn.Module):
loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
point_loss_cls = point_loss_cls * loss_weights_dict['point_cls_weight'] point_loss_cls = point_loss_cls * loss_weights_dict['point_cls_weight']
tb_dict = { if tb_dict is None:
tb_dict = {}
tb_dict.update({
'point_loss_cls': point_loss_cls.item(), 'point_loss_cls': point_loss_cls.item(),
'point_pos_num': pos_normalizer.item() 'point_pos_num': pos_normalizer.item()
} })
return point_loss_cls, tb_dict return point_loss_cls, tb_dict
def get_part_layer_loss(self): def get_part_layer_loss(self, tb_dict=None):
pos_mask = self.forward_ret_dict['point_cls_labels'] > 0 pos_mask = self.forward_ret_dict['point_cls_labels'] > 0
pos_normalizer = max(1, (pos_mask > 0).sum().item()) pos_normalizer = max(1, (pos_mask > 0).sum().item())
point_part_labels = self.forward_ret_dict['point_part_labels'] point_part_labels = self.forward_ret_dict['point_part_labels']
...@@ -147,7 +164,47 @@ class PointHeadTemplate(nn.Module): ...@@ -147,7 +164,47 @@ class PointHeadTemplate(nn.Module):
loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
point_loss_part = point_loss_part * loss_weights_dict['point_part_weight'] point_loss_part = point_loss_part * loss_weights_dict['point_part_weight']
return point_loss_part, {'point_loss_part': point_loss_part.item()} if tb_dict is None:
tb_dict = {}
tb_dict.update({'point_loss_part': point_loss_part.item()})
return point_loss_part, tb_dict
def get_box_layer_loss(self, tb_dict=None):
pos_mask = self.forward_ret_dict['point_cls_labels'] > 0
point_box_labels = self.forward_ret_dict['point_box_labels']
point_box_preds = self.forward_ret_dict['point_box_preds']
reg_weights = pos_mask.float()
pos_normalizer = pos_mask.sum().float()
reg_weights /= torch.clamp(pos_normalizer, min=1.0)
point_loss_box_src = self.reg_loss_func(
point_box_preds[None, ...], point_box_labels[None, ...], weights=reg_weights[None, ...]
)
point_loss_box = point_loss_box_src.sum()
loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
point_loss_box = point_loss_box * loss_weights_dict['point_box_weight']
if tb_dict is None:
tb_dict = {}
tb_dict.update({'point_loss_box': point_loss_box.item()})
return point_loss_box, tb_dict
def generate_predicted_boxes(self, points, point_cls_preds, point_box_preds):
"""
Args:
points: (N, 3)
point_cls_preds: (N, num_class)
point_box_preds: (N, box_code_size)
Returns:
point_cls_preds: (N, num_class)
point_box_preds: (N, box_code_size)
"""
_, pred_classes = point_cls_preds.max(dim=-1)
point_box_preds = self.box_coder.decode_torch(point_box_preds, points, pred_classes + 1)
return point_cls_preds, point_box_preds
def forward(self, **kwargs): def forward(self, **kwargs):
raise NotImplementedError raise NotImplementedError
import torch import torch
from ...utils import box_coder_utils, box_utils
from .point_head_template import PointHeadTemplate from .point_head_template import PointHeadTemplate
from ...utils import box_utils
class PointIntraPartOffsetHead(PointHeadTemplate): class PointIntraPartOffsetHead(PointHeadTemplate):
...@@ -9,8 +10,9 @@ class PointIntraPartOffsetHead(PointHeadTemplate): ...@@ -9,8 +10,9 @@ class PointIntraPartOffsetHead(PointHeadTemplate):
Reference Paper: https://arxiv.org/abs/1907.03670 Reference Paper: https://arxiv.org/abs/1907.03670
From Points to Parts: 3D Object Detection from Point Cloud with Part-aware and Part-aggregation Network From Points to Parts: 3D Object Detection from Point Cloud with Part-aware and Part-aggregation Network
""" """
def __init__(self, num_class, input_channels, model_cfg, **kwargs): def __init__(self, num_class, input_channels, model_cfg, predict_boxes_when_training=False, **kwargs):
super().__init__(model_cfg=model_cfg, num_class=num_class) super().__init__(model_cfg=model_cfg, num_class=num_class)
self.predict_boxes_when_training = predict_boxes_when_training
self.cls_layers = self.make_fc_layers( self.cls_layers = self.make_fc_layers(
fc_cfg=self.model_cfg.CLS_FC, fc_cfg=self.model_cfg.CLS_FC,
input_channels=input_channels, input_channels=input_channels,
...@@ -21,6 +23,18 @@ class PointIntraPartOffsetHead(PointHeadTemplate): ...@@ -21,6 +23,18 @@ class PointIntraPartOffsetHead(PointHeadTemplate):
input_channels=input_channels, input_channels=input_channels,
output_channels=3 output_channels=3
) )
target_cfg = self.model_cfg.TARGET_CONFIG
if target_cfg.get('BOX_CODER', None) is not None:
self.box_coder = getattr(box_coder_utils, target_cfg.BOX_CODER)(
**target_cfg.BOX_CODER_CONFIG
)
self.box_layers = self.make_fc_layers(
fc_cfg=self.model_cfg.REG_FC,
input_channels=input_channels,
output_channels=self.box_coder.code_size
)
else:
self.box_layers = None
def assign_targets(self, input_dict): def assign_targets(self, input_dict):
""" """
...@@ -46,19 +60,20 @@ class PointIntraPartOffsetHead(PointHeadTemplate): ...@@ -46,19 +60,20 @@ class PointIntraPartOffsetHead(PointHeadTemplate):
targets_dict = self.assign_stack_targets( targets_dict = self.assign_stack_targets(
points=point_coords, gt_boxes=gt_boxes, extend_gt_boxes=extend_gt_boxes, points=point_coords, gt_boxes=gt_boxes, extend_gt_boxes=extend_gt_boxes,
set_ignore_flag=True, use_ball_constraint=False, set_ignore_flag=True, use_ball_constraint=False,
ret_part_labels=True ret_part_labels=True, ret_box_labels=(self.box_layers is not None)
) )
return targets_dict return targets_dict
def get_loss(self, tb_dict=None): def get_loss(self, tb_dict=None):
tb_dict = {} if tb_dict is None else tb_dict tb_dict = {} if tb_dict is None else tb_dict
point_loss_cls, tb_dict_1 = self.get_cls_layer_loss() point_loss_cls, tb_dict = self.get_cls_layer_loss(tb_dict)
point_loss_part, tb_dict_2 = self.get_part_layer_loss() point_loss_part, tb_dict = self.get_part_layer_loss(tb_dict)
point_loss = point_loss_cls + point_loss_part point_loss = point_loss_cls + point_loss_part
tb_dict.update(tb_dict_1)
tb_dict.update(tb_dict_2) if self.box_layers is not None:
point_loss_box, tb_dict = self.get_box_layer_loss(tb_dict)
point_loss += point_loss_box
return point_loss, tb_dict return point_loss, tb_dict
def forward(self, batch_dict): def forward(self, batch_dict):
...@@ -83,6 +98,9 @@ class PointIntraPartOffsetHead(PointHeadTemplate): ...@@ -83,6 +98,9 @@ class PointIntraPartOffsetHead(PointHeadTemplate):
'point_cls_preds': point_cls_preds, 'point_cls_preds': point_cls_preds,
'point_part_preds': point_part_preds, 'point_part_preds': point_part_preds,
} }
if self.box_layers is not None:
point_box_preds = self.box_layers(point_features)
ret_dict['point_box_preds'] = point_box_preds
point_cls_scores = torch.sigmoid(point_cls_preds) point_cls_scores = torch.sigmoid(point_cls_preds)
point_part_offset = torch.sigmoid(point_part_preds) point_part_offset = torch.sigmoid(point_part_preds)
...@@ -93,6 +111,17 @@ class PointIntraPartOffsetHead(PointHeadTemplate): ...@@ -93,6 +111,17 @@ class PointIntraPartOffsetHead(PointHeadTemplate):
targets_dict = self.assign_targets(batch_dict) targets_dict = self.assign_targets(batch_dict)
ret_dict['point_cls_labels'] = targets_dict['point_cls_labels'] ret_dict['point_cls_labels'] = targets_dict['point_cls_labels']
ret_dict['point_part_labels'] = targets_dict.get('point_part_labels') ret_dict['point_part_labels'] = targets_dict.get('point_part_labels')
self.forward_ret_dict = ret_dict ret_dict['point_box_labels'] = targets_dict.get('point_box_labels')
if self.box_layers is not None and (not self.training or self.predict_boxes_when_training):
point_cls_preds, point_box_preds = self.generate_predicted_boxes(
points=batch_dict['point_coords'][:, 1:4],
point_cls_preds=point_cls_preds, point_box_preds=ret_dict['point_box_preds']
)
batch_dict['batch_cls_preds'] = point_cls_preds
batch_dict['batch_box_preds'] = point_box_preds
batch_dict['batch_index'] = batch_dict['point_coords'][:, 0]
batch_dict['cls_preds_normalized'] = False
self.forward_ret_dict = ret_dict
return batch_dict return batch_dict
import torch import torch
from ....utils import common_utils
from ....ops.iou3d_nms import iou3d_nms_utils from ....ops.iou3d_nms import iou3d_nms_utils
from ....utils import common_utils
class ATSSTargetAssigner(object): class ATSSTargetAssigner(object):
...@@ -28,8 +29,8 @@ class ATSSTargetAssigner(object): ...@@ -28,8 +29,8 @@ class ATSSTargetAssigner(object):
cls_labels_list, reg_targets_list, reg_weights_list = [], [], [] cls_labels_list, reg_targets_list, reg_weights_list = [], [], []
for anchors in anchors_list: for anchors in anchors_list:
batch_size = gt_boxes_with_classes.shape[0] batch_size = gt_boxes_with_classes.shape[0]
gt_classes = gt_boxes_with_classes[:, :, 7] gt_classes = gt_boxes_with_classes[:, :, -1]
gt_boxes = gt_boxes_with_classes[:, :, :7] gt_boxes = gt_boxes_with_classes[:, :, :-1]
if use_multihead: if use_multihead:
anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1]) anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1])
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment