".github/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "ff52be33043421fd0a09c6a3d0aa342702bb281d"
Commit 2a14897e authored by zhangwenwei's avatar zhangwenwei
Browse files

Fix kitti evaluation bug

parent aec41c7f
...@@ -21,9 +21,9 @@ def corners_nd(dims, origin=0.5): ...@@ -21,9 +21,9 @@ def corners_nd(dims, origin=0.5):
where x0 < x1, y0 < y1, z0 < z1 where x0 < x1, y0 < y1, z0 < z1
""" """
ndim = int(dims.shape[1]) ndim = int(dims.shape[1])
corners_norm = np.stack( corners_norm = torch.from_numpy(
np.unravel_index(np.arange(2**ndim), [2] * ndim), np.stack(np.unravel_index(np.arange(2**ndim), [2] * ndim), axis=1)).to(
axis=1).astype(dims.dtype) device=dims.device, dtype=dims.dtype)
# now corners_norm has format: (2d) x0y0, x0y1, x1y0, x1y1 # now corners_norm has format: (2d) x0y0, x0y1, x1y0, x1y1
# (3d) x0y0z0, x0y0z1, x0y1z0, x0y1z1, x1y0z0, x1y0z1, x1y1z0, x1y1z1 # (3d) x0y0z0, x0y0z1, x0y1z0, x0y1z1, x1y0z0, x1y0z1, x1y1z0, x1y1z1
# so need to convert to a format which is convenient to do other computing. # so need to convert to a format which is convenient to do other computing.
...@@ -34,7 +34,7 @@ def corners_nd(dims, origin=0.5): ...@@ -34,7 +34,7 @@ def corners_nd(dims, origin=0.5):
corners_norm = corners_norm[[0, 1, 3, 2]] corners_norm = corners_norm[[0, 1, 3, 2]]
elif ndim == 3: elif ndim == 3:
corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]] corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]]
corners_norm = corners_norm - np.array(origin, dtype=dims.dtype) corners_norm = corners_norm - dims.new_tensor(origin)
corners = dims.reshape([-1, 1, ndim]) * corners_norm.reshape( corners = dims.reshape([-1, 1, ndim]) * corners_norm.reshape(
[1, 2**ndim, ndim]) [1, 2**ndim, ndim])
return corners return corners
......
...@@ -44,8 +44,7 @@ class KittiDataset(torch_data.Dataset): ...@@ -44,8 +44,7 @@ class KittiDataset(torch_data.Dataset):
self.pcd_limit_range = [0, -40, -3, 70.4, 40, 0.0] self.pcd_limit_range = [0, -40, -3, 70.4, 40, 0.0]
self.ann_file = ann_file self.ann_file = ann_file
with open(ann_file, 'rb') as f: self.kitti_infos = mmcv.load(ann_file)
self.kitti_infos = mmcv.load(f)
# set group flag for the sampler # set group flag for the sampler
if not self.test_mode: if not self.test_mode:
...@@ -284,7 +283,7 @@ class KittiDataset(torch_data.Dataset): ...@@ -284,7 +283,7 @@ class KittiDataset(torch_data.Dataset):
result_files = self.bbox2result_kitti(outputs, self.class_names, result_files = self.bbox2result_kitti(outputs, self.class_names,
pklfile_prefix, pklfile_prefix,
submission_prefix) submission_prefix)
return result_files return result_files, tmp_dir
def evaluate(self, def evaluate(self,
results, results,
...@@ -321,7 +320,7 @@ class KittiDataset(torch_data.Dataset): ...@@ -321,7 +320,7 @@ class KittiDataset(torch_data.Dataset):
if tmp_dir is not None: if tmp_dir is not None:
tmp_dir.cleanup() tmp_dir.cleanup()
return ap_dict return ap_dict, tmp_dir
def bbox2result_kitti(self, def bbox2result_kitti(self,
net_outputs, net_outputs,
...@@ -332,7 +331,7 @@ class KittiDataset(torch_data.Dataset): ...@@ -332,7 +331,7 @@ class KittiDataset(torch_data.Dataset):
mmcv.mkdir_or_exist(submission_prefix) mmcv.mkdir_or_exist(submission_prefix)
det_annos = [] det_annos = []
print('Converting prediction to KITTI format') print('\nConverting prediction to KITTI format')
for idx, pred_dicts in enumerate( for idx, pred_dicts in enumerate(
mmcv.track_iter_progress(net_outputs)): mmcv.track_iter_progress(net_outputs)):
annos = [] annos = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment