Commit c8fa96f8 authored by liyinhao's avatar liyinhao
Browse files

change create sunrgbd and scannet related file based on comment

parent 24cde1eb
......@@ -14,17 +14,13 @@ def create_scannet_info_file(data_path, pkl_prefix='scannet', save_path=None):
assert os.path.exists(save_path)
train_filename = save_path / f'{pkl_prefix}_infos_train.pkl'
val_filename = save_path / f'{pkl_prefix}_infos_val.pkl'
dataset = ScannetObject(root_path=data_path, split='train')
train_split, val_split = 'train', 'val'
dataset.set_split(train_split)
scannet_infos_train = dataset.get_scannet_infos(has_label=True)
train_dataset = ScannetObject(root_path=data_path, split='train')
val_dataset = ScannetObject(root_path=data_path, split='val')
scannet_infos_train = train_dataset.get_scannet_infos(has_label=True)
with open(train_filename, 'wb') as f:
pickle.dump(scannet_infos_train, f)
print('Scannet info train file is saved to %s' % train_filename)
dataset.set_split(val_split)
scannet_infos_val = dataset.get_scannet_infos(has_label=True)
scannet_infos_val = val_dataset.get_scannet_infos(has_label=True)
with open(val_filename, 'wb') as f:
pickle.dump(scannet_infos_val, f)
print('Scannet info val file is saved to %s' % val_filename)
......
......@@ -10,32 +10,19 @@ class ScannetObject(object):
self.root_dir = root_path
self.split = split
self.split_dir = os.path.join(root_path)
self.type2class = {
'cabinet': 0,
'bed': 1,
'chair': 2,
'sofa': 3,
'table': 4,
'door': 5,
'window': 6,
'bookshelf': 7,
'picture': 8,
'counter': 9,
'desk': 10,
'curtain': 11,
'refrigerator': 12,
'showercurtrain': 13,
'toilet': 14,
'sink': 15,
'bathtub': 16,
'garbagebin': 17
}
self.class2type = {self.type2class[t]: t for t in self.type2class}
self.nyu40ids = np.array(
self.classes = [
'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
'bookshelf', 'picture', 'counter', 'desk', 'curtain',
'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
'garbagebin'
]
self.cat2label = {cat: self.classes.index(cat) for cat in self.classes}
self.label2cat = {self.cat2label[t]: t for t in self.cat2label}
self.cat_ids = np.array(
[3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39])
self.nyu40id2class = {
self.cat_ids2class = {
nyu40id: i
for i, nyu40id in enumerate(list(self.nyu40ids))
for i, nyu40id in enumerate(list(self.cat_ids))
}
assert split in ['train', 'val', 'test']
split_dir = os.path.join(self.root_dir, 'meta_data',
......@@ -46,9 +33,6 @@ class ScannetObject(object):
def __len__(self):
return len(self.sample_id_list)
def set_split(self, split):
self.__init__(self.root_dir, split)
def get_box_label(self, idx):
box_file = os.path.join(self.root_dir, 'scannet_train_instance_data',
'%s_bbox.npy' % idx)
......@@ -76,7 +60,7 @@ class ScannetObject(object):
minmax_boxes3d = boxes_with_classes[:, :-1] # k, 6
classes = boxes_with_classes[:, -1] # k, 1
annotations['name'] = np.array([
self.class2type[self.nyu40id2class[classes[i]]]
self.label2cat[self.cat_ids2class[classes[i]]]
for i in range(annotations['gt_num'])
])
annotations['location'] = minmax_boxes3d[:, :3]
......@@ -85,7 +69,7 @@ class ScannetObject(object):
annotations['index'] = np.arange(
annotations['gt_num'], dtype=np.int32)
annotations['class'] = np.array([
self.nyu40id2class[classes[i]]
self.cat_ids2class[classes[i]]
for i in range(annotations['gt_num'])
])
info['annos'] = annotations
......
......@@ -17,17 +17,15 @@ def create_sunrgbd_info_file(data_path,
assert os.path.exists(save_path)
train_filename = save_path / f'{pkl_prefix}_infos_train.pkl'
val_filename = save_path / f'{pkl_prefix}_infos_val.pkl'
dataset = SUNRGBDObject(root_path=data_path, split='train', use_v1=use_v1)
train_split, val_split = 'train', 'val'
dataset.set_split(train_split)
sunrgbd_infos_train = dataset.get_sunrgbd_infos(has_label=True)
train_dataset = SUNRGBDObject(
root_path=data_path, split='train', use_v1=use_v1)
val_dataset = SUNRGBDObject(
root_path=data_path, split='val', use_v1=use_v1)
sunrgbd_infos_train = train_dataset.get_sunrgbd_infos(has_label=True)
with open(train_filename, 'wb') as f:
pickle.dump(sunrgbd_infos_train, f)
print('Sunrgbd info train file is saved to %s' % train_filename)
dataset.set_split(val_split)
sunrgbd_infos_val = dataset.get_sunrgbd_infos(has_label=True)
sunrgbd_infos_val = val_dataset.get_sunrgbd_infos(has_label=True)
with open(val_filename, 'wb') as f:
pickle.dump(sunrgbd_infos_val, f)
print('Sunrgbd info val file is saved to %s' % val_filename)
......
......@@ -50,29 +50,14 @@ class SUNRGBDObject(object):
self.root_dir = root_path
self.split = split
self.split_dir = os.path.join(root_path)
self.type2class = {
'bed': 0,
'table': 1,
'sofa': 2,
'chair': 3,
'toilet': 4,
'desk': 5,
'dresser': 6,
'night_stand': 7,
'bookshelf': 8,
'bathtub': 9
}
self.class2type = {
0: 'bed',
1: 'table',
2: 'sofa',
3: 'chair',
4: 'toilet',
5: 'desk',
6: 'dresser',
7: 'night_stand',
8: 'bookshelf',
9: 'bathtub'
self.classes = [
'bed', 'table', 'sofa', 'chair', 'toilet', 'desk', 'dresser',
'night_stand', 'bookshelf', 'bathtub'
]
self.cat2label = {cat: self.classes.index(cat) for cat in self.classes}
self.label2cat = {
label: self.classes[label]
for label in len(self.classes)
}
assert split in ['train', 'val', 'test']
split_dir = os.path.join(self.root_dir, '%s_data_idx.txt' % split)
......@@ -91,9 +76,6 @@ class SUNRGBDObject(object):
def __len__(self):
return len(self.sample_id_list)
def set_split(self, split):
self.__init__(self.root_dir, split)
def get_image(self, idx):
img_filename = os.path.join(self.image_dir, '%06d.jpg' % (idx))
return cv2.imread(img_filename)
......@@ -132,6 +114,7 @@ class SUNRGBDObject(object):
# convert depth to points
SAMPLE_NUM = 50000
pc_upright_depth = self.get_depth(sample_idx)
# TODO : sample points in loading process and test
pc_upright_depth_subsampled = random_sampling(
pc_upright_depth, SAMPLE_NUM)
np.savez_compressed(
......@@ -159,41 +142,41 @@ class SUNRGBDObject(object):
annotations = {}
annotations['gt_num'] = len([
obj.classname for obj in obj_list
if obj.classname in self.type2class.keys()
if obj.classname in self.cat2label.keys()
])
if annotations['gt_num'] != 0:
annotations['name'] = np.array([
obj.classname for obj in obj_list
if obj.classname in self.type2class.keys()
if obj.classname in self.cat2label.keys()
])
annotations['bbox'] = np.concatenate([
obj.box2d.reshape(1, 4) for obj in obj_list
if obj.classname in self.type2class.keys()
if obj.classname in self.cat2label.keys()
],
axis=0)
annotations['location'] = np.concatenate([
obj.centroid.reshape(1, 3) for obj in obj_list
if obj.classname in self.type2class.keys()
if obj.classname in self.cat2label.keys()
],
axis=0)
annotations['dimensions'] = 2 * np.array([
[obj.l, obj.h, obj.w] for obj in obj_list
if obj.classname in self.type2class.keys()
if obj.classname in self.cat2label.keys()
]) # lhw(depth) format
annotations['rotation_y'] = np.array([
obj.heading_angle for obj in obj_list
if obj.classname in self.type2class.keys()
if obj.classname in self.cat2label.keys()
])
annotations['index'] = np.arange(
len(obj_list), dtype=np.int32)
annotations['class'] = np.array([
self.type2class[obj.classname] for obj in obj_list
if obj.classname in self.type2class.keys()
self.cat2label[obj.classname] for obj in obj_list
if obj.classname in self.cat2label.keys()
])
annotations['gt_boxes_upright_depth'] = np.stack(
[
obj.box3d for obj in obj_list
if obj.classname in self.type2class.keys()
if obj.classname in self.cat2label.keys()
],
axis=0) # (K,8)
info['annos'] = annotations
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment