Commit 20c3e2c8 authored by VVsssssk's avatar VVsssssk Committed by ZwwWayne
Browse files

[Fix]Fix a bug in StackQueryAndGroup (#2043)

* fix a bug

* fix a batch inference bug

* fix docs
parent 3b0ae48c
......@@ -57,10 +57,11 @@ class StackQueryAndGroup(BaseModule):
'new_xyz: str(new_xyz.shape), new_xyz_batch_cnt: ' \
'str(new_xyz_batch_cnt)'
# idx: (M1 + M2 ..., nsample), empty_ball_mask: (M1 + M2 ...)
idx, empty_ball_mask = ball_query(0, self.radius, self.sample_nums,
xyz, new_xyz, xyz_batch_cnt,
new_xyz_batch_cnt)
# idx: (M1 + M2 ..., nsample)
idx = ball_query(0, self.radius, self.sample_nums, xyz, new_xyz,
xyz_batch_cnt, new_xyz_batch_cnt)
empty_ball_mask = (idx[:, 0] == -1)
idx[empty_ball_mask] = 0
grouped_xyz = grouping_operation(
xyz, idx, xyz_batch_cnt,
new_xyz_batch_cnt) # (M1 + M2, 3, nsample)
......
......@@ -4,6 +4,8 @@ from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from mmcv.cnn import ConvModule
from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.models.utils import multi_apply
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import nn as nn
......@@ -14,8 +16,6 @@ from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures.bbox_3d import (LiDARInstance3DBoxes,
rotation_3d_in_axis, xywhr2xyxyr)
from mmdet3d.utils import InstanceList
from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.models.utils import multi_apply
@MODELS.register_module()
......@@ -440,21 +440,21 @@ class PVRCNNBBoxHead(BaseModule):
# post processing
result_list = []
for batch_id in range(batch_size):
cls_preds = cls_preds[roi_batch_id == batch_id]
cur_cls_preds = cls_preds[roi_batch_id == batch_id]
box_preds = batch_box_preds[roi_batch_id == batch_id]
label_preds = class_labels[batch_id]
cls_preds = cls_preds.sigmoid()
cls_preds, _ = torch.max(cls_preds, dim=-1)
cur_cls_preds = cur_cls_preds.sigmoid()
cur_cls_preds, _ = torch.max(cur_cls_preds, dim=-1)
selected = self.class_agnostic_nms(
scores=cls_preds,
scores=cur_cls_preds,
bbox_preds=box_preds,
input_meta=input_metas[batch_id],
nms_cfg=test_cfg)
selected_bboxes = box_preds[selected]
selected_label_preds = label_preds[selected]
selected_scores = cls_preds[selected]
selected_scores = cur_cls_preds[selected]
results = InstanceData()
results.bboxes_3d = input_metas[batch_id]['box_type_3d'](
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment