Unverified Commit 62ce67c0 authored by Wenwei Zhang's avatar Wenwei Zhang Committed by GitHub
Browse files

[Fix]: fix label type bug when using dbsampler (#111)

* [Fix]: fix label type bug when using dbsampler

* Unify cat_id for more general usage

* fix CI bugs

* keep astype np.long
parent ee801168
......@@ -93,7 +93,7 @@ jobs:
coverage report -m
# Only upload coverage report for python3.7 && pytorch1.5
- name: Upload coverage to Codecov
if: ${{matrix.torch == '1.5.0' && matrix.python-version == '3.7'}}
if: ${{matrix.torch == '1.5.0+cu101' && matrix.python-version == '3.7'}}
uses: codecov/codecov-action@v1.0.10
with:
file: ./coverage.xml
......
......@@ -57,6 +57,7 @@ class Custom3DDataset(Dataset):
self.box_type_3d, self.box_mode_3d = get_box_type(box_type_3d)
self.CLASSES = self.get_classes(classes)
self.cat2id = {name: i for i, name in enumerate(self.CLASSES)}
self.data_infos = self.load_annotations(self.ann_file)
if pipeline is not None:
......@@ -300,7 +301,7 @@ class Custom3DDataset(Dataset):
"""Set flag according to image aspect ratio.
Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0. In 3D datasets, they are all the same, thus
are all zeros.
otherwise group 0. In 3D datasets, they are all the same, thus are all
zeros.
"""
self.flag = np.zeros(len(self), dtype=np.uint8)
......@@ -18,6 +18,7 @@ class CBGSDataset(object):
def __init__(self, dataset):
self.dataset = dataset
self.CLASSES = dataset.CLASSES
self.cat2id = {name: i for i, name in enumerate(self.CLASSES)}
self.sample_indices = self._get_sample_indices()
# self.dataset.data_infos = self.data_infos
if hasattr(self.dataset, 'flag'):
......@@ -34,22 +35,23 @@ class CBGSDataset(object):
Returns:
list[dict]: List of annotations after class sampling.
"""
class_sample_idxs = {name: [] for name in self.CLASSES}
class_sample_idxs = {cat_id: [] for cat_id in self.cat2id.values()}
for idx in range(len(self.dataset)):
class_sample_idx = self.dataset.get_cat_ids(idx)
for key in class_sample_idxs.keys():
class_sample_idxs[key] += class_sample_idx[key]
duplicated_samples = sum([len(v) for _, v in class_sample_idx.items()])
sample_cat_ids = self.dataset.get_cat_ids(idx)
for cat_id in sample_cat_ids:
class_sample_idxs[cat_id].append(idx)
duplicated_samples = sum(
[len(v) for _, v in class_sample_idxs.items()])
class_distribution = {
k: len(v) / duplicated_samples
for k, v in class_sample_idx.items()
for k, v in class_sample_idxs.items()
}
sample_indices = []
frac = 1.0 / len(self.CLASSES)
ratios = [frac / v for v in class_distribution.values()]
for cls_inds, ratio in zip(list(class_sample_idx.values()), ratios):
for cls_inds, ratio in zip(list(class_sample_idxs.values()), ratios):
sample_indices += np.random.choice(cls_inds,
int(len(cls_inds) *
ratio)).tolist()
......
......@@ -150,17 +150,18 @@ class NuScenesDataset(Custom3DDataset):
contains such boxes, store a list containing idx,
otherwise, store empty list.
"""
class_sample_idx = {name: [] for name in self.CLASSES}
info = self.data_infos[idx]
if self.use_valid_flag:
mask = info['valid_flag']
gt_names = set(info['gt_names'][mask])
else:
gt_names = set(info['gt_names'])
cat_ids = []
for name in gt_names:
if name in self.CLASSES:
class_sample_idx[name].append(idx)
return class_sample_idx
cat_ids.append(self.cat2id[name])
return cat_ids
def load_annotations(self, ann_file):
"""Load annotations from ann_file.
......
......@@ -259,9 +259,9 @@ class DataBaseSampler(object):
count += 1
s_points_list.append(s_points)
# gt_names = np.array([s['name'] for s in sampled]),
# gt_labels = np.array([self.cat2label(s) for s in gt_names])
gt_labels = np.array([self.cat2label[s['name']] for s in sampled])
gt_labels = np.array([self.cat2label[s['name']] for s in sampled],
dtype=np.long)
ret = {
'gt_labels_3d':
gt_labels,
......
......@@ -198,7 +198,7 @@ class ObjectSample(object):
input_dict['img'] = sampled_dict['img']
input_dict['gt_bboxes_3d'] = gt_bboxes_3d
input_dict['gt_labels_3d'] = gt_labels_3d
input_dict['gt_labels_3d'] = gt_labels_3d.astype(np.long)
input_dict['points'] = points
return input_dict
......
......@@ -62,13 +62,13 @@ def test_getitem():
# and box_type_3d='Depth' in sunrgbd and scannet dataset.
box_type_3d='LiDAR'))
nus_dataset = build_dataset(dataset_cfg)
assert len(nus_dataset) == 10
assert len(nus_dataset) == 20
data = nus_dataset[0]
assert data['img_metas'].data['flip'] is True
assert data['img_metas'].data['pcd_horizontal_flip'] is True
assert data['points']._data.shape == (537, 5)
data = nus_dataset[1]
assert data['img_metas'].data['flip'] is False
assert data['img_metas'].data['pcd_horizontal_flip'] is False
assert data['points']._data.shape == (901, 5)
data = nus_dataset[1]
assert data['img_metas'].data['flip'] is True
assert data['img_metas'].data['pcd_horizontal_flip'] is True
assert data['points']._data.shape == (537, 5)
......@@ -78,6 +78,7 @@ def test_object_sample():
gt_labels.append(CLASSES.index(cat))
else:
gt_labels.append(-1)
gt_labels = np.array(gt_labels, dtype=np.long)
input_dict = dict(
points=points, gt_bboxes_3d=gt_bboxes_3d, gt_labels_3d=gt_labels)
input_dict = object_sample(input_dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment