Commit 422d3154 authored by liyinhao's avatar liyinhao
Browse files

add scannet dataset

parent 7cd3060e
...@@ -134,7 +134,7 @@ class ScannetDataset(torch_data.Dataset): ...@@ -134,7 +134,7 @@ class ScannetDataset(torch_data.Dataset):
return input_dict return input_dict
def _get_pts_filename(self, sample_idx): def _get_pts_filename(self, sample_idx):
pts_filename = os.path.join(self.data_path, sample_idx + '_vert.npy') pts_filename = os.path.join(self.data_path, f'{sample_idx}_vert.npy')
mmcv.check_file_exist(pts_filename) mmcv.check_file_exist(pts_filename)
return pts_filename return pts_filename
...@@ -150,9 +150,9 @@ class ScannetDataset(torch_data.Dataset): ...@@ -150,9 +150,9 @@ class ScannetDataset(torch_data.Dataset):
gt_labels = np.zeros(1, ).astype(np.bool) gt_labels = np.zeros(1, ).astype(np.bool)
gt_bboxes_3d_mask = np.zeros(1, ).astype(np.bool) gt_bboxes_3d_mask = np.zeros(1, ).astype(np.bool)
pts_instance_mask_path = osp.join(self.data_path, pts_instance_mask_path = osp.join(self.data_path,
sample_idx + '_ins_label.npy') f'{sample_idx}_ins_label.npy')
pts_semantic_mask_path = osp.join(self.data_path, pts_semantic_mask_path = osp.join(self.data_path,
sample_idx + '_sem_label.npy') f'{sample_idx}_sem_label.npy')
anns_results = dict( anns_results = dict(
gt_bboxes_3d=gt_bboxes_3d, gt_bboxes_3d=gt_bboxes_3d,
...@@ -167,19 +167,13 @@ class ScannetDataset(torch_data.Dataset): ...@@ -167,19 +167,13 @@ class ScannetDataset(torch_data.Dataset):
return np.random.choice(pool) return np.random.choice(pool)
def _generate_annotations(self, output): def _generate_annotations(self, output):
''' """Generate Annotations.
transfer input_dict & pred_dicts to anno format
which is needed by AP calculator Transform results of the model to the form of the evaluation.
return annos: a tuple (batch_pred_map_cls,batch_gt_map_cls)
batch_pred_map_cls is a list: i=0,1..bs-1 Args:
pred_list_i:[(pred_sem_cls, output (List): The output of the model.
box_params, box_score)_j] """
j=0,1..num_pred_obj -1
batch_gt_map_cls is a list: i=0,1..bs-1
gt_list_i: [(sem_cls_label, box_params)_j]
j=0,1..num_gt_obj -1
'''
result = [] result = []
bs = len(output) bs = len(output)
for i in range(bs): for i in range(bs):
...@@ -209,7 +203,15 @@ class ScannetDataset(torch_data.Dataset): ...@@ -209,7 +203,15 @@ class ScannetDataset(torch_data.Dataset):
results.append(result) results.append(result)
return results return results
def evaluate(self, results, metric=None, logger=None, pklfile_prefix=None): def evaluate(self, results, metric=None):
"""Evaluate.
Evaluation in indoor protocol.
Args:
results (List): List of result.
metric (dict): AP_IOU_THRESHHOLDS.
"""
results = self._format_results(results) results = self._format_results(results)
from mmdet3d.core.evaluation import indoor_eval from mmdet3d.core.evaluation import indoor_eval
assert ('AP_IOU_THRESHHOLDS' in metric) assert ('AP_IOU_THRESHHOLDS' in metric)
......
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
import pytest import pytest
import torch import torch
from mmdet3d.datasets.scannet_dataset import ScannetDataset from mmdet3d.datasets import ScannetDataset
def test_getitem(): def test_getitem():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment