"vscode:/vscode.git/clone" did not exist on "98fa5a3b4e5a703a6c0676c2caf1880ddde821d8"
Commit 2a14897e authored by zhangwenwei's avatar zhangwenwei
Browse files

Fix kitti evaluation bug

parent aec41c7f
......@@ -21,9 +21,9 @@ def corners_nd(dims, origin=0.5):
where x0 < x1, y0 < y1, z0 < z1
"""
ndim = int(dims.shape[1])
corners_norm = np.stack(
np.unravel_index(np.arange(2**ndim), [2] * ndim),
axis=1).astype(dims.dtype)
corners_norm = torch.from_numpy(
np.stack(np.unravel_index(np.arange(2**ndim), [2] * ndim), axis=1)).to(
device=dims.device, dtype=dims.dtype)
# now corners_norm has format: (2d) x0y0, x0y1, x1y0, x1y1
# (3d) x0y0z0, x0y0z1, x0y1z0, x0y1z1, x1y0z0, x1y0z1, x1y1z0, x1y1z1
# so need to convert to a format which is convenient to do other computing.
......@@ -34,7 +34,7 @@ def corners_nd(dims, origin=0.5):
corners_norm = corners_norm[[0, 1, 3, 2]]
elif ndim == 3:
corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]]
corners_norm = corners_norm - np.array(origin, dtype=dims.dtype)
corners_norm = corners_norm - dims.new_tensor(origin)
corners = dims.reshape([-1, 1, ndim]) * corners_norm.reshape(
[1, 2**ndim, ndim])
return corners
......
......@@ -44,8 +44,7 @@ class KittiDataset(torch_data.Dataset):
self.pcd_limit_range = [0, -40, -3, 70.4, 40, 0.0]
self.ann_file = ann_file
with open(ann_file, 'rb') as f:
self.kitti_infos = mmcv.load(f)
self.kitti_infos = mmcv.load(ann_file)
# set group flag for the sampler
if not self.test_mode:
......@@ -284,7 +283,7 @@ class KittiDataset(torch_data.Dataset):
result_files = self.bbox2result_kitti(outputs, self.class_names,
pklfile_prefix,
submission_prefix)
return result_files
return result_files, tmp_dir
def evaluate(self,
results,
......@@ -321,7 +320,7 @@ class KittiDataset(torch_data.Dataset):
if tmp_dir is not None:
tmp_dir.cleanup()
return ap_dict
return ap_dict, tmp_dir
def bbox2result_kitti(self,
net_outputs,
......@@ -332,7 +331,7 @@ class KittiDataset(torch_data.Dataset):
mmcv.mkdir_or_exist(submission_prefix)
det_annos = []
print('Converting prediction to KITTI format')
print('\nConverting prediction to KITTI format')
for idx, pred_dicts in enumerate(
mmcv.track_iter_progress(net_outputs)):
annos = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment