Commit 20c3e2c8 authored by VVsssssk's avatar VVsssssk Committed by ZwwWayne
Browse files

[Fix]Fix a bug in StackQueryAndGroup (#2043)

* fix a bug

* fix a batch inference bug

* fix docs
parent 3b0ae48c
...@@ -57,10 +57,11 @@ class StackQueryAndGroup(BaseModule): ...@@ -57,10 +57,11 @@ class StackQueryAndGroup(BaseModule):
'new_xyz: str(new_xyz.shape), new_xyz_batch_cnt: ' \ 'new_xyz: str(new_xyz.shape), new_xyz_batch_cnt: ' \
'str(new_xyz_batch_cnt)' 'str(new_xyz_batch_cnt)'
# idx: (M1 + M2 ..., nsample), empty_ball_mask: (M1 + M2 ...) # idx: (M1 + M2 ..., nsample)
idx, empty_ball_mask = ball_query(0, self.radius, self.sample_nums, idx = ball_query(0, self.radius, self.sample_nums, xyz, new_xyz,
xyz, new_xyz, xyz_batch_cnt, xyz_batch_cnt, new_xyz_batch_cnt)
new_xyz_batch_cnt) empty_ball_mask = (idx[:, 0] == -1)
idx[empty_ball_mask] = 0
grouped_xyz = grouping_operation( grouped_xyz = grouping_operation(
xyz, idx, xyz_batch_cnt, xyz, idx, xyz_batch_cnt,
new_xyz_batch_cnt) # (M1 + M2, 3, nsample) new_xyz_batch_cnt) # (M1 + M2, 3, nsample)
......
...@@ -4,6 +4,8 @@ from typing import Dict, List, Optional, Tuple ...@@ -4,6 +4,8 @@ from typing import Dict, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.models.utils import multi_apply
from mmengine.model import BaseModule from mmengine.model import BaseModule
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
from torch import nn as nn from torch import nn as nn
...@@ -14,8 +16,6 @@ from mmdet3d.registry import MODELS, TASK_UTILS ...@@ -14,8 +16,6 @@ from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures.bbox_3d import (LiDARInstance3DBoxes, from mmdet3d.structures.bbox_3d import (LiDARInstance3DBoxes,
rotation_3d_in_axis, xywhr2xyxyr) rotation_3d_in_axis, xywhr2xyxyr)
from mmdet3d.utils import InstanceList from mmdet3d.utils import InstanceList
from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.models.utils import multi_apply
@MODELS.register_module() @MODELS.register_module()
...@@ -440,21 +440,21 @@ class PVRCNNBBoxHead(BaseModule): ...@@ -440,21 +440,21 @@ class PVRCNNBBoxHead(BaseModule):
# post processing # post processing
result_list = [] result_list = []
for batch_id in range(batch_size): for batch_id in range(batch_size):
cls_preds = cls_preds[roi_batch_id == batch_id] cur_cls_preds = cls_preds[roi_batch_id == batch_id]
box_preds = batch_box_preds[roi_batch_id == batch_id] box_preds = batch_box_preds[roi_batch_id == batch_id]
label_preds = class_labels[batch_id] label_preds = class_labels[batch_id]
cls_preds = cls_preds.sigmoid() cur_cls_preds = cur_cls_preds.sigmoid()
cls_preds, _ = torch.max(cls_preds, dim=-1) cur_cls_preds, _ = torch.max(cur_cls_preds, dim=-1)
selected = self.class_agnostic_nms( selected = self.class_agnostic_nms(
scores=cls_preds, scores=cur_cls_preds,
bbox_preds=box_preds, bbox_preds=box_preds,
input_meta=input_metas[batch_id], input_meta=input_metas[batch_id],
nms_cfg=test_cfg) nms_cfg=test_cfg)
selected_bboxes = box_preds[selected] selected_bboxes = box_preds[selected]
selected_label_preds = label_preds[selected] selected_label_preds = label_preds[selected]
selected_scores = cls_preds[selected] selected_scores = cur_cls_preds[selected]
results = InstanceData() results = InstanceData()
results.bboxes_3d = input_metas[batch_id]['box_type_3d']( results.bboxes_3d = input_metas[batch_id]['box_type_3d'](
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment