Unverified Commit e90e1767 authored by Sun Jiahao's avatar Sun Jiahao Committed by GitHub
Browse files

[Fix] Fix UT of Seg TTA & p2v_map bug when gt is none (#2466)

* fix p2v_map when no gt & unit name

* fix gt bug

* fix max_voxels
parent f8e3ce89
......@@ -9,15 +9,14 @@ model = dict(
point_cloud_range=[-100, -100, -20, 100, 100, 20],
voxel_size=[0.05, 0.05, 0.05],
max_voxels=(-1, -1)),
),
max_voxels=80000),
backbone=dict(
type='MinkUNetBackbone',
in_channels=4,
base_channels=32,
encoder_channels=[32, 64, 128, 256],
decoder_channels=[256, 128, 96, 96],
num_stages=4,
init_cfg=None),
encoder_channels=[32, 64, 128, 256],
decoder_channels=[256, 128, 96, 96]),
decode_head=dict(
type='MinkUNetHead',
channels=96,
......
......@@ -9,14 +9,14 @@ model = dict(
point_cloud_range=[-100, -100, -20, 100, 100, 20],
voxel_size=[0.05, 0.05, 0.05],
max_voxels=(-1, -1)),
),
max_voxels=80000),
backbone=dict(
type='SPVCNNBackbone',
in_channels=4,
base_channels=32,
num_stages=4,
encoder_channels=[32, 64, 128, 256],
decoder_channels=[256, 128, 96, 96],
num_stages=4,
drop_ratio=0.3),
decode_head=dict(
type='MinkUNetHead',
......
......@@ -49,6 +49,8 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
voxelization and dynamic voxelization. Defaults to 'hard'.
voxel_layer (dict or :obj:`ConfigDict`, optional): Voxelization layer
config. Defaults to None.
        max_voxels (int, optional): Maximum number of voxels in each voxel
            grid. Defaults to None.
mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
Defaults to None.
std (Sequence[Number], optional): The pixel standard deviation of
......@@ -77,6 +79,7 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
voxel: bool = False,
voxel_type: str = 'hard',
voxel_layer: OptConfigType = None,
max_voxels: Optional[int] = None,
mean: Sequence[Number] = None,
std: Sequence[Number] = None,
pad_size_divisor: int = 1,
......@@ -103,6 +106,7 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
batch_augments=batch_augments)
self.voxel = voxel
self.voxel_type = voxel_type
self.max_voxels = max_voxels
if voxel:
self.voxel_layer = VoxelizationByGridShape(**voxel_layer)
......@@ -423,20 +427,22 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
res_coors -= res_coors.min(0)[0]
res_coors_numpy = res_coors.cpu().numpy()
inds, voxel2point_map = self.sparse_quantize(
inds, point2voxel_map = self.sparse_quantize(
res_coors_numpy, return_index=True, return_inverse=True)
voxel2point_map = torch.from_numpy(voxel2point_map).cuda()
if self.training:
if len(inds) > 80000:
inds = np.random.choice(inds, 80000, replace=False)
point2voxel_map = torch.from_numpy(point2voxel_map).cuda()
if self.training and self.max_voxels is not None:
if len(inds) > self.max_voxels:
inds = np.random.choice(
inds, self.max_voxels, replace=False)
inds = torch.from_numpy(inds).cuda()
if hasattr(data_sample.gt_pts_seg, 'pts_semantic_mask'):
data_sample.gt_pts_seg.voxel_semantic_mask \
= data_sample.gt_pts_seg.pts_semantic_mask[inds]
res_voxel_coors = res_coors[inds]
res_voxels = res[inds]
res_voxel_coors = F.pad(
res_voxel_coors, (0, 1), mode='constant', value=i)
data_sample.voxel2point_map = voxel2point_map.long()
data_sample.point2voxel_map = point2voxel_map.long()
voxels.append(res_voxels)
coors.append(res_voxel_coors)
voxels = torch.cat(voxels, dim=0)
......@@ -466,12 +472,12 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
True)
voxel_semantic_mask = torch.argmax(voxel_semantic_mask, dim=-1)
data_sample.gt_pts_seg.voxel_semantic_mask = voxel_semantic_mask
data_sample.gt_pts_seg.point2voxel_map = point2voxel_map
data_sample.point2voxel_map = point2voxel_map
else:
pseudo_tensor = res_coors.new_ones([res_coors.shape[0], 1]).float()
_, _, point2voxel_map = dynamic_scatter_3d(pseudo_tensor,
res_coors, 'mean', True)
data_sample.gt_pts_seg.point2voxel_map = point2voxel_map
data_sample.point2voxel_map = point2voxel_map
def ravel_hash(self, x: np.ndarray) -> np.ndarray:
"""Get voxel coordinates hash for np.unique().
......
......@@ -151,7 +151,7 @@ class Cylinder3DHead(Base3DDecodeHead):
for batch_idx in range(len(batch_data_samples)):
seg_logits_sample = seg_logits[coors[:, 0] == batch_idx]
point2voxel_map = batch_data_samples[
batch_idx].gt_pts_seg.point2voxel_map.long()
batch_idx].point2voxel_map.long()
point_seg_predicts = seg_logits_sample[point2voxel_map]
seg_pred_list.append(point_seg_predicts)
......
......@@ -61,7 +61,7 @@ class MinkUNetHead(Base3DDecodeHead):
seg_logit_list = []
for i, data_sample in enumerate(batch_data_samples):
seg_logit = seg_logits[batch_idx == i]
seg_logit = seg_logit[data_sample.voxel2point_map]
seg_logit = seg_logit[data_sample.point2voxel_map]
seg_logit_list.append(seg_logit)
return seg_logit_list
......
......@@ -60,8 +60,7 @@ class TestCylinder3DHead(TestCase):
self.assertGreater(loss_lovasz, 0, 'lovasz loss should be positive')
batch_inputs_dict = dict(voxels=dict(voxel_coors=coors))
datasample.gt_pts_seg.point2voxel_map = torch.randint(
0, 50, (100, )).int().cuda()
datasample.point2voxel_map = torch.randint(0, 50, (100, )).int().cuda()
point_logits = cylinder3d_head.predict(sparse_voxels,
batch_inputs_dict, [datasample])
assert point_logits[0].shape == torch.Size([100, 20])
......@@ -36,5 +36,5 @@ class TestSeg3DTTAModel(TestCase):
pcd_vertical_flip=pcd_vertical_flip_list[i]))
])
if torch.cuda.is_available():
model.eval()
model.eval().cuda()
model.test_step(dict(inputs=points, data_samples=data_samples))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment