Unverified Commit a03100ea authored by Ziyi Wu's avatar Ziyi Wu Committed by GitHub
Browse files

[Feature] Support ScanNet test results submission to online benchmark (#418)

* add format_results function for ScanNet test set results submission

* support ScanNet test set data pre-processing

* replace is_train with test_mode for better consistency

* fix link error in docstring
parent e21e61e0
### Prepare ScanNet Data for Indoor Detection or Segmentation Task ### Prepare ScanNet Data for Indoor Detection or Segmentation Task
We follow the procedure in [votenet](https://github.com/facebookresearch/votenet/). We follow the procedure in [votenet](https://github.com/facebookresearch/votenet/).
1. Download ScanNet v2 data [HERE](https://github.com/ScanNet/ScanNet). Link or move the 'scans' folder to this level of directory. 1. Download ScanNet v2 data [HERE](https://github.com/ScanNet/ScanNet). Link or move the 'scans' folder to this level of directory. If you are performing segmentation tasks and want to upload the results to its official [benchmark](http://kaldir.vc.in.tum.de/scannet_benchmark/), please also link or move the 'scans_test' folder to this directory.
2. In this directory, extract point clouds and annotations by running `python batch_load_scannet_data.py`. Add the `--max_num_point 50000` flag if you only use the ScanNet data for the detection task. It will downsample the scenes to less points. 2. In this directory, extract point clouds and annotations by running `python batch_load_scannet_data.py`. Add the `--max_num_point 50000` flag if you only use the ScanNet data for the detection task. It will downsample the scenes to less points.
...@@ -26,7 +26,8 @@ scannet ...@@ -26,7 +26,8 @@ scannet
├── scannet_utils.py ├── scannet_utils.py
├── README.md ├── README.md
├── scans ├── scans
├── scannet_train_instance_data ├── scans_test
├── scannet_instance_data
├── points ├── points
│ ├── xxxxx.bin │ ├── xxxxx.bin
├── instance_mask ├── instance_mask
...@@ -40,5 +41,6 @@ scannet ...@@ -40,5 +41,6 @@ scannet
│ ├── val_resampled_scene_idxs.npy │ ├── val_resampled_scene_idxs.npy
├── scannet_infos_train.pkl ├── scannet_infos_train.pkl
├── scannet_infos_val.pkl ├── scannet_infos_val.pkl
├── scannet_infos_test.pkl
``` ```
...@@ -16,14 +16,17 @@ import os ...@@ -16,14 +16,17 @@ import os
from load_scannet_data import export from load_scannet_data import export
from os import path as osp from os import path as osp
SCANNET_DIR = 'scans'
DONOTCARE_CLASS_IDS = np.array([]) DONOTCARE_CLASS_IDS = np.array([])
OBJ_CLASS_IDS = np.array( OBJ_CLASS_IDS = np.array(
[3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39]) [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39])
def export_one_scan(scan_name, output_filename_prefix, max_num_point, def export_one_scan(scan_name,
label_map_file, scannet_dir): output_filename_prefix,
max_num_point,
label_map_file,
scannet_dir,
test_mode=False):
mesh_file = osp.join(scannet_dir, scan_name, scan_name + '_vh_clean_2.ply') mesh_file = osp.join(scannet_dir, scan_name, scan_name + '_vh_clean_2.ply')
agg_file = osp.join(scannet_dir, scan_name, agg_file = osp.join(scannet_dir, scan_name,
scan_name + '.aggregation.json') scan_name + '.aggregation.json')
...@@ -33,8 +36,9 @@ def export_one_scan(scan_name, output_filename_prefix, max_num_point, ...@@ -33,8 +36,9 @@ def export_one_scan(scan_name, output_filename_prefix, max_num_point,
meta_file = osp.join(scannet_dir, scan_name, f'{scan_name}.txt') meta_file = osp.join(scannet_dir, scan_name, f'{scan_name}.txt')
mesh_vertices, semantic_labels, instance_labels, instance_bboxes, \ mesh_vertices, semantic_labels, instance_labels, instance_bboxes, \
instance2semantic = export(mesh_file, agg_file, seg_file, instance2semantic = export(mesh_file, agg_file, seg_file,
meta_file, label_map_file, None) meta_file, label_map_file, None, test_mode)
if not test_mode:
mask = np.logical_not(np.in1d(semantic_labels, DONOTCARE_CLASS_IDS)) mask = np.logical_not(np.in1d(semantic_labels, DONOTCARE_CLASS_IDS))
mesh_vertices = mesh_vertices[mask, :] mesh_vertices = mesh_vertices[mask, :]
semantic_labels = semantic_labels[mask] semantic_labels = semantic_labels[mask]
...@@ -52,23 +56,32 @@ def export_one_scan(scan_name, output_filename_prefix, max_num_point, ...@@ -52,23 +56,32 @@ def export_one_scan(scan_name, output_filename_prefix, max_num_point,
if N > max_num_point: if N > max_num_point:
choices = np.random.choice(N, max_num_point, replace=False) choices = np.random.choice(N, max_num_point, replace=False)
mesh_vertices = mesh_vertices[choices, :] mesh_vertices = mesh_vertices[choices, :]
if not test_mode:
semantic_labels = semantic_labels[choices] semantic_labels = semantic_labels[choices]
instance_labels = instance_labels[choices] instance_labels = instance_labels[choices]
np.save(f'{output_filename_prefix}_vert.npy', mesh_vertices) np.save(f'{output_filename_prefix}_vert.npy', mesh_vertices)
if not test_mode:
np.save(f'{output_filename_prefix}_sem_label.npy', semantic_labels) np.save(f'{output_filename_prefix}_sem_label.npy', semantic_labels)
np.save(f'{output_filename_prefix}_ins_label.npy', instance_labels) np.save(f'{output_filename_prefix}_ins_label.npy', instance_labels)
np.save(f'{output_filename_prefix}_bbox.npy', instance_bboxes) np.save(f'{output_filename_prefix}_bbox.npy', instance_bboxes)
def batch_export(max_num_point, output_folder, train_scan_names_file, def batch_export(max_num_point,
label_map_file, scannet_dir): output_folder,
scan_names_file,
label_map_file,
scannet_dir,
test_mode=False):
if test_mode and not os.path.exists(scannet_dir):
# test data preparation is optional
return
if not os.path.exists(output_folder): if not os.path.exists(output_folder):
print(f'Creating new data folder: {output_folder}') print(f'Creating new data folder: {output_folder}')
os.mkdir(output_folder) os.mkdir(output_folder)
train_scan_names = [line.rstrip() for line in open(train_scan_names_file)] scan_names = [line.rstrip() for line in open(scan_names_file)]
for scan_name in train_scan_names: for scan_name in scan_names:
print('-' * 20 + 'begin') print('-' * 20 + 'begin')
print(datetime.datetime.now()) print(datetime.datetime.now())
print(scan_name) print(scan_name)
...@@ -79,7 +92,7 @@ def batch_export(max_num_point, output_folder, train_scan_names_file, ...@@ -79,7 +92,7 @@ def batch_export(max_num_point, output_folder, train_scan_names_file,
continue continue
try: try:
export_one_scan(scan_name, output_filename_prefix, max_num_point, export_one_scan(scan_name, output_filename_prefix, max_num_point,
label_map_file, scannet_dir) label_map_file, scannet_dir, test_mode)
except Exception: except Exception:
print(f'Failed export scan: {scan_name}') print(f'Failed export scan: {scan_name}')
print('-' * 20 + 'done') print('-' * 20 + 'done')
...@@ -93,10 +106,14 @@ def main(): ...@@ -93,10 +106,14 @@ def main():
help='The maximum number of the points.') help='The maximum number of the points.')
parser.add_argument( parser.add_argument(
'--output_folder', '--output_folder',
default='./scannet_train_instance_data', default='./scannet_instance_data',
help='output folder of the result.') help='output folder of the result.')
parser.add_argument( parser.add_argument(
'--scannet_dir', default='scans', help='scannet data directory.') '--train_scannet_dir', default='scans', help='scannet data directory.')
parser.add_argument(
'--test_scannet_dir',
default='scans_test',
help='scannet data directory.')
parser.add_argument( parser.add_argument(
'--label_map_file', '--label_map_file',
default='meta_data/scannetv2-labels.combined.tsv', default='meta_data/scannetv2-labels.combined.tsv',
...@@ -105,10 +122,25 @@ def main(): ...@@ -105,10 +122,25 @@ def main():
'--train_scan_names_file', '--train_scan_names_file',
default='meta_data/scannet_train.txt', default='meta_data/scannet_train.txt',
help='The path of the file that stores the scan names.') help='The path of the file that stores the scan names.')
parser.add_argument(
'--test_scan_names_file',
default='meta_data/scannetv2_test.txt',
help='The path of the file that stores the scan names.')
args = parser.parse_args() args = parser.parse_args()
batch_export(args.max_num_point, args.output_folder, batch_export(
args.train_scan_names_file, args.label_map_file, args.max_num_point,
args.scannet_dir) args.output_folder,
args.train_scan_names_file,
args.label_map_file,
args.train_scannet_dir,
test_mode=False)
batch_export(
args.max_num_point,
args.output_folder,
args.test_scan_names_file,
args.label_map_file,
args.test_scannet_dir,
test_mode=True)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -57,7 +57,8 @@ def export(mesh_file, ...@@ -57,7 +57,8 @@ def export(mesh_file,
seg_file, seg_file,
meta_file, meta_file,
label_map_file, label_map_file,
output_file=None): output_file=None,
is_train=True):
"""Export original files to vert, ins_label, sem_label and bbox file. """Export original files to vert, ins_label, sem_label and bbox file.
Args: Args:
...@@ -68,6 +69,8 @@ def export(mesh_file, ...@@ -68,6 +69,8 @@ def export(mesh_file,
label_map_file (str): Path of the label_map_file. label_map_file (str): Path of the label_map_file.
output_file (str): Path of the output folder. output_file (str): Path of the output folder.
Default: None. Default: None.
is_train (bool): Whether is generating training data with labels.
Default: True.
It returns a tuple, which containts the the following things: It returns a tuple, which containts the the following things:
np.ndarray: Vertices of points data. np.ndarray: Vertices of points data.
...@@ -83,6 +86,9 @@ def export(mesh_file, ...@@ -83,6 +86,9 @@ def export(mesh_file,
# Load scene axis alignment matrix # Load scene axis alignment matrix
lines = open(meta_file).readlines() lines = open(meta_file).readlines()
# TODO: test set data doesn't have align_matrix!
# TODO: save align_matrix and move align step to pipeline in the future
axis_align_matrix = np.eye(4)
for line in lines: for line in lines:
if 'axisAlignment' in line: if 'axisAlignment' in line:
axis_align_matrix = [ axis_align_matrix = [
...@@ -97,6 +103,7 @@ def export(mesh_file, ...@@ -97,6 +103,7 @@ def export(mesh_file,
mesh_vertices[:, 0:3] = pts[:, 0:3] mesh_vertices[:, 0:3] = pts[:, 0:3]
# Load semantic and instance labels # Load semantic and instance labels
if is_train:
object_id_to_segs, label_to_segs = read_aggregation(agg_file) object_id_to_segs, label_to_segs = read_aggregation(agg_file)
seg_to_verts, num_verts = read_segmentation(seg_file) seg_to_verts, num_verts = read_segmentation(seg_file)
label_ids = np.zeros(shape=(num_verts), dtype=np.uint32) label_ids = np.zeros(shape=(num_verts), dtype=np.uint32)
...@@ -132,14 +139,20 @@ def export(mesh_file, ...@@ -132,14 +139,20 @@ def export(mesh_file,
zmax - zmin, label_id]) zmax - zmin, label_id])
# NOTE: this assumes obj_id is in 1,2,3,.,,,.NUM_INSTANCES # NOTE: this assumes obj_id is in 1,2,3,.,,,.NUM_INSTANCES
instance_bboxes[obj_id - 1, :] = bbox instance_bboxes[obj_id - 1, :] = bbox
else:
label_ids = None
instance_ids = None
instance_bboxes = None
object_id_to_label_id = None
if output_file is not None: if output_file is not None:
np.save(output_file + '_vert.npy', mesh_vertices) np.save(output_file + '_vert.npy', mesh_vertices)
if is_train:
np.save(output_file + '_sem_label.npy', label_ids) np.save(output_file + '_sem_label.npy', label_ids)
np.save(output_file + '_ins_label.npy', instance_ids) np.save(output_file + '_ins_label.npy', instance_ids)
np.save(output_file + '_bbox.npy', instance_bboxes) np.save(output_file + '_bbox.npy', instance_bboxes)
return mesh_vertices, label_ids, instance_ids,\ return mesh_vertices, label_ids, instance_ids, \
instance_bboxes, object_id_to_label_id instance_bboxes, object_id_to_label_id
......
import numpy as np import numpy as np
import tempfile
from os import path as osp from os import path as osp
from mmdet3d.core import show_result, show_seg_result from mmdet3d.core import show_result, show_seg_result
...@@ -277,3 +278,45 @@ class ScanNetSegDataset(Custom3DSegDataset): ...@@ -277,3 +278,45 @@ class ScanNetSegDataset(Custom3DSegDataset):
return super().get_scene_idxs_and_label_weight(scene_idxs, return super().get_scene_idxs_and_label_weight(scene_idxs,
label_weight) label_weight)
def format_results(self, results, txtfile_prefix=None):
r"""Format the results to txt file. Refer to `ScanNet documentation
<http://kaldir.vc.in.tum.de/scannet_benchmark/documentation>`_.
Args:
outputs (list[dict]): Testing results of the dataset.
txtfile_prefix (str | None): The prefix of saved files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None.
Returns:
tuple: (outputs, tmp_dir), outputs is the detection results,
tmp_dir is the temporal directory created for saving submission
files when ``submission_prefix`` is not specified.
"""
import mmcv
if txtfile_prefix is None:
tmp_dir = tempfile.TemporaryDirectory()
txtfile_prefix = osp.join(tmp_dir.name, 'results')
else:
tmp_dir = None
mmcv.mkdir_or_exist(txtfile_prefix)
# need to map network output to original label idx
pred2label = np.zeros(len(self.VALID_CLASS_IDS)).astype(np.int)
for original_label, output_idx in self.label_map.items():
if output_idx != self.ignore_index:
pred2label[output_idx] = original_label
outputs = []
for i, result in enumerate(results):
info = self.data_infos[i]
sample_idx = info['point_cloud']['lidar_idx']
pred_sem_mask = result['semantic_mask'].numpy().astype(np.int)
pred_label = pred2label[pred_sem_mask]
curr_file = f'{txtfile_prefix}/{sample_idx}.txt'
np.savetxt(curr_file, pred_label, fmt='%d')
outputs.append(dict(seg_mask=pred_label))
return outputs, tmp_dir
...@@ -525,3 +525,36 @@ def test_seg_show(): ...@@ -525,3 +525,36 @@ def test_seg_show():
mmcv.check_file_exist(gt_file_path) mmcv.check_file_exist(gt_file_path)
mmcv.check_file_exist(pred_file_path) mmcv.check_file_exist(pred_file_path)
tmp_dir.cleanup() tmp_dir.cleanup()
def test_seg_format_results():
import mmcv
from os import path as osp
root_path = './tests/data/scannet'
ann_file = './tests/data/scannet/scannet_infos.pkl'
scannet_dataset = ScanNetSegDataset(
data_root=root_path, ann_file=ann_file, test_mode=True)
results = []
pred_sem_mask = dict(
semantic_mask=torch.tensor([
13, 5, 1, 2, 6, 2, 13, 1, 14, 2, 0, 0, 5, 5, 3, 0, 1, 14, 0, 0, 0,
18, 6, 15, 13, 0, 2, 4, 0, 3, 16, 6, 13, 5, 13, 0, 0, 0, 0, 1, 7,
3, 19, 12, 8, 0, 11, 0, 0, 1, 2, 13, 17, 1, 1, 1, 6, 2, 13, 19, 4,
17, 0, 14, 1, 7, 2, 1, 7, 2, 0, 5, 17, 5, 0, 0, 3, 6, 5, 11, 1, 13,
13, 2, 3, 1, 0, 13, 19, 1, 14, 5, 3, 1, 13, 1, 2, 3, 2, 1
]).long())
results.append(pred_sem_mask)
result_files, tmp_dir = scannet_dataset.format_results(results)
expected_label = np.array([
16, 6, 2, 3, 7, 3, 16, 2, 24, 3, 1, 1, 6, 6, 4, 1, 2, 24, 1, 1, 1, 36,
7, 28, 16, 1, 3, 5, 1, 4, 33, 7, 16, 6, 16, 1, 1, 1, 1, 2, 8, 4, 39,
14, 9, 1, 12, 1, 1, 2, 3, 16, 34, 2, 2, 2, 7, 3, 16, 39, 5, 34, 1, 24,
2, 8, 3, 2, 8, 3, 1, 6, 34, 6, 1, 1, 4, 7, 6, 12, 2, 16, 16, 3, 4, 2,
1, 16, 39, 2, 24, 6, 4, 2, 16, 2, 3, 4, 3, 2
])
expected_txt_path = osp.join(tmp_dir.name, 'results', 'scene0000_00.txt')
assert np.all(result_files[0]['seg_mask'] == expected_label)
mmcv.check_file_exist(expected_txt_path)
tmp_dir.cleanup()
...@@ -37,6 +37,8 @@ def create_indoor_info_file(data_path, ...@@ -37,6 +37,8 @@ def create_indoor_info_file(data_path,
else: else:
train_dataset = ScanNetData(root_path=data_path, split='train') train_dataset = ScanNetData(root_path=data_path, split='train')
val_dataset = ScanNetData(root_path=data_path, split='val') val_dataset = ScanNetData(root_path=data_path, split='val')
test_dataset = ScanNetData(root_path=data_path, split='test')
test_filename = os.path.join(save_path, f'{pkl_prefix}_infos_test.pkl')
infos_train = train_dataset.get_infos(num_workers=workers, has_label=True) infos_train = train_dataset.get_infos(num_workers=workers, has_label=True)
mmcv.dump(infos_train, train_filename, 'pkl') mmcv.dump(infos_train, train_filename, 'pkl')
...@@ -46,6 +48,12 @@ def create_indoor_info_file(data_path, ...@@ -46,6 +48,12 @@ def create_indoor_info_file(data_path,
mmcv.dump(infos_val, val_filename, 'pkl') mmcv.dump(infos_val, val_filename, 'pkl')
print(f'{pkl_prefix} info val file is saved to {val_filename}') print(f'{pkl_prefix} info val file is saved to {val_filename}')
if pkl_prefix == 'scannet':
infos_test = test_dataset.get_infos(
num_workers=workers, has_label=False)
mmcv.dump(infos_test, test_filename, 'pkl')
print(f'{pkl_prefix} info test file is saved to {test_filename}')
# generate infos for the semantic segmentation task # generate infos for the semantic segmentation task
# e.g. re-sampled scene indexes and label weights # e.g. re-sampled scene indexes and label weights
if pkl_prefix == 'scannet': if pkl_prefix == 'scannet':
...@@ -64,6 +72,7 @@ def create_indoor_info_file(data_path, ...@@ -64,6 +72,7 @@ def create_indoor_info_file(data_path,
split='val', split='val',
num_points=8192, num_points=8192,
label_weight_func=lambda x: 1.0 / np.log(1.2 + x)) label_weight_func=lambda x: 1.0 / np.log(1.2 + x))
# no need to generate for test set
train_dataset.get_seg_infos() train_dataset.get_seg_infos()
val_dataset.get_seg_infos() val_dataset.get_seg_infos()
...@@ -37,12 +37,13 @@ class ScanNetData(object): ...@@ -37,12 +37,13 @@ class ScanNetData(object):
f'scannetv2_{split}.txt') f'scannetv2_{split}.txt')
mmcv.check_file_exist(split_file) mmcv.check_file_exist(split_file)
self.sample_id_list = mmcv.list_from_file(split_file) self.sample_id_list = mmcv.list_from_file(split_file)
self.test_mode = (split == 'test')
def __len__(self): def __len__(self):
return len(self.sample_id_list) return len(self.sample_id_list)
def get_box_label(self, idx): def get_box_label(self, idx):
box_file = osp.join(self.root_dir, 'scannet_train_instance_data', box_file = osp.join(self.root_dir, 'scannet_instance_data',
f'{idx}_bbox.npy') f'{idx}_bbox.npy')
mmcv.check_file_exist(box_file) mmcv.check_file_exist(box_file)
return np.load(box_file) return np.load(box_file)
...@@ -67,36 +68,41 @@ class ScanNetData(object): ...@@ -67,36 +68,41 @@ class ScanNetData(object):
info = dict() info = dict()
pc_info = {'num_features': 6, 'lidar_idx': sample_idx} pc_info = {'num_features': 6, 'lidar_idx': sample_idx}
info['point_cloud'] = pc_info info['point_cloud'] = pc_info
pts_filename = osp.join(self.root_dir, pts_filename = osp.join(self.root_dir, 'scannet_instance_data',
'scannet_train_instance_data',
f'{sample_idx}_vert.npy') f'{sample_idx}_vert.npy')
pts_instance_mask_path = osp.join(self.root_dir, points = np.load(pts_filename)
'scannet_train_instance_data', mmcv.mkdir_or_exist(osp.join(self.root_dir, 'points'))
points.tofile(
osp.join(self.root_dir, 'points', f'{sample_idx}.bin'))
info['pts_path'] = osp.join('points', f'{sample_idx}.bin')
if not self.test_mode:
pts_instance_mask_path = osp.join(
self.root_dir, 'scannet_instance_data',
f'{sample_idx}_ins_label.npy') f'{sample_idx}_ins_label.npy')
pts_semantic_mask_path = osp.join(self.root_dir, pts_semantic_mask_path = osp.join(
'scannet_train_instance_data', self.root_dir, 'scannet_instance_data',
f'{sample_idx}_sem_label.npy') f'{sample_idx}_sem_label.npy')
points = np.load(pts_filename) pts_instance_mask = np.load(pts_instance_mask_path).astype(
pts_instance_mask = np.load(pts_instance_mask_path).astype(np.long) np.long)
pts_semantic_mask = np.load(pts_semantic_mask_path).astype(np.long) pts_semantic_mask = np.load(pts_semantic_mask_path).astype(
np.long)
mmcv.mkdir_or_exist(osp.join(self.root_dir, 'points'))
mmcv.mkdir_or_exist(osp.join(self.root_dir, 'instance_mask')) mmcv.mkdir_or_exist(osp.join(self.root_dir, 'instance_mask'))
mmcv.mkdir_or_exist(osp.join(self.root_dir, 'semantic_mask')) mmcv.mkdir_or_exist(osp.join(self.root_dir, 'semantic_mask'))
points.tofile(
osp.join(self.root_dir, 'points', f'{sample_idx}.bin'))
pts_instance_mask.tofile( pts_instance_mask.tofile(
osp.join(self.root_dir, 'instance_mask', f'{sample_idx}.bin')) osp.join(self.root_dir, 'instance_mask',
f'{sample_idx}.bin'))
pts_semantic_mask.tofile( pts_semantic_mask.tofile(
osp.join(self.root_dir, 'semantic_mask', f'{sample_idx}.bin')) osp.join(self.root_dir, 'semantic_mask',
f'{sample_idx}.bin'))
info['pts_path'] = osp.join('points', f'{sample_idx}.bin') info['pts_instance_mask_path'] = osp.join(
info['pts_instance_mask_path'] = osp.join('instance_mask', 'instance_mask', f'{sample_idx}.bin')
f'{sample_idx}.bin') info['pts_semantic_mask_path'] = osp.join(
info['pts_semantic_mask_path'] = osp.join('semantic_mask', 'semantic_mask', f'{sample_idx}.bin')
f'{sample_idx}.bin')
if has_label: if has_label:
annotations = {} annotations = {}
...@@ -150,6 +156,7 @@ class ScanNetSegData(object): ...@@ -150,6 +156,7 @@ class ScanNetSegData(object):
self.data_root = data_root self.data_root = data_root
self.data_infos = mmcv.load(ann_file) self.data_infos = mmcv.load(ann_file)
self.split = split self.split = split
assert split in ['train', 'val', 'test']
self.num_points = num_points self.num_points = num_points
self.all_ids = np.arange(41) # all possible ids self.all_ids = np.arange(41) # all possible ids
...@@ -170,6 +177,8 @@ class ScanNetSegData(object): ...@@ -170,6 +177,8 @@ class ScanNetSegData(object):
label_weight_func is None else label_weight_func label_weight_func is None else label_weight_func
def get_seg_infos(self): def get_seg_infos(self):
if self.split == 'test':
return
scene_idxs, label_weight = self.get_scene_idxs_and_label_weight() scene_idxs, label_weight = self.get_scene_idxs_and_label_weight()
save_folder = osp.join(self.data_root, 'seg_info') save_folder = osp.join(self.data_root, 'seg_info')
mmcv.mkdir_or_exist(save_folder) mmcv.mkdir_or_exist(save_folder)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment