Unverified Commit 62ce67c0 authored by Wenwei Zhang, committed by GitHub

[Fix]: fix label type bug when using dbsampler (#111)

* [Fix]: fix label type bug when using dbsampler

* Unify cat_id for more general usage

* fix CI bugs

* keep astype np.long
parent ee801168
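In short, the patch gives the datasets a shared name-to-id mapping (cat2id) and keeps ground-truth labels as integer arrays end to end. A minimal standalone sketch of the idea, with a made-up class list rather than the real dataset code:

import numpy as np

# Hypothetical class list; each real dataset defines its own CLASSES tuple.
CLASSES = ('car', 'pedestrian', 'cyclist')

# The unified mapping added to the datasets: category name -> integer id.
cat2id = {name: i for i, name in enumerate(CLASSES)}

# Labels built from sampled ground-truth names must stay integral;
# np.long is an alias of Python's int (int64 on most platforms) in the
# NumPy versions this commit targets.
sampled_names = ['car', 'car', 'cyclist']
gt_labels = np.array([cat2id[name] for name in sampled_names], dtype=np.long)
print(gt_labels)        # [0 0 2]
print(gt_labels.dtype)  # int64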
@@ -93,7 +93,7 @@ jobs:
         coverage report -m
     # Only upload coverage report for python3.7 && pytorch1.5
     - name: Upload coverage to Codecov
-      if: ${{matrix.torch == '1.5.0' && matrix.python-version == '3.7'}}
+      if: ${{matrix.torch == '1.5.0+cu101' && matrix.python-version == '3.7'}}
       uses: codecov/codecov-action@v1.0.10
       with:
         file: ./coverage.xml
......
@@ -57,6 +57,7 @@ class Custom3DDataset(Dataset):
         self.box_type_3d, self.box_mode_3d = get_box_type(box_type_3d)
         self.CLASSES = self.get_classes(classes)
+        self.cat2id = {name: i for i, name in enumerate(self.CLASSES)}
         self.data_infos = self.load_annotations(self.ann_file)
         if pipeline is not None:
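The single added line above gives every Custom3DDataset subclass a consistent name-to-id lookup. A rough illustration of how such a mapping gets used; the helper below is hypothetical and not part of the patch:

import numpy as np

def names_to_labels(cat2id, gt_names):
    # Map annotation class names to contiguous integer ids; names that are
    # not in the dataset's CLASSES get -1, the conventional "ignore" label.
    return np.array([cat2id.get(name, -1) for name in gt_names],
                    dtype=np.long)

cat2id = {'car': 0, 'truck': 1, 'pedestrian': 2}
print(names_to_labels(cat2id, ['car', 'barrier', 'pedestrian']))  # [ 0 -1  2]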
@@ -300,7 +301,7 @@ class Custom3DDataset(Dataset):
         """Set flag according to image aspect ratio.
         Images with aspect ratio greater than 1 will be set as group 1,
-        otherwise group 0. In 3D datasets, they are all the same, thus
-        are all zeros.
+        otherwise group 0. In 3D datasets, they are all the same, thus are all
+        zeros.
         """
         self.flag = np.zeros(len(self), dtype=np.uint8)
@@ -18,6 +18,7 @@ class CBGSDataset(object):
     def __init__(self, dataset):
         self.dataset = dataset
         self.CLASSES = dataset.CLASSES
+        self.cat2id = {name: i for i, name in enumerate(self.CLASSES)}
         self.sample_indices = self._get_sample_indices()
         # self.dataset.data_infos = self.data_infos
         if hasattr(self.dataset, 'flag'):
@@ -34,22 +35,23 @@ class CBGSDataset(object):
         Returns:
             list[dict]: List of annotations after class sampling.
         """
-        class_sample_idxs = {name: [] for name in self.CLASSES}
+        class_sample_idxs = {cat_id: [] for cat_id in self.cat2id.values()}
         for idx in range(len(self.dataset)):
-            class_sample_idx = self.dataset.get_cat_ids(idx)
-            for key in class_sample_idxs.keys():
-                class_sample_idxs[key] += class_sample_idx[key]
-        duplicated_samples = sum([len(v) for _, v in class_sample_idx.items()])
+            sample_cat_ids = self.dataset.get_cat_ids(idx)
+            for cat_id in sample_cat_ids:
+                class_sample_idxs[cat_id].append(idx)
+        duplicated_samples = sum(
+            [len(v) for _, v in class_sample_idxs.items()])
         class_distribution = {
             k: len(v) / duplicated_samples
-            for k, v in class_sample_idx.items()
+            for k, v in class_sample_idxs.items()
         }
         sample_indices = []
         frac = 1.0 / len(self.CLASSES)
         ratios = [frac / v for v in class_distribution.values()]
-        for cls_inds, ratio in zip(list(class_sample_idx.values()), ratios):
+        for cls_inds, ratio in zip(list(class_sample_idxs.values()), ratios):
             sample_indices += np.random.choice(cls_inds,
                                                int(len(cls_inds) *
                                                    ratio)).tolist()
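The rewritten loop builds, for every category id, the list of sample indices that contain that category, then re-draws indices so each class ends up with roughly a 1/num_classes share. A standalone sketch of that balancing logic on toy data (the sample contents and numbers are made up):

import numpy as np

# Toy output of get_cat_ids for five samples, in the new list-of-ids format.
samples_cat_ids = [[0], [0, 1], [1], [0], [0, 2]]
num_classes = 3

# For each category id, collect the indices of samples that contain it.
class_sample_idxs = {cat_id: [] for cat_id in range(num_classes)}
for idx, cat_ids in enumerate(samples_cat_ids):
    for cat_id in cat_ids:
        class_sample_idxs[cat_id].append(idx)

# Per-class share of all (sample, class) occurrences.
duplicated_samples = sum(len(v) for v in class_sample_idxs.values())
class_distribution = {
    k: len(v) / duplicated_samples for k, v in class_sample_idxs.items()
}

# Re-draw indices so every class approaches a uniform 1 / num_classes share.
frac = 1.0 / num_classes
ratios = [frac / v for v in class_distribution.values()]
sample_indices = []
for cls_inds, ratio in zip(class_sample_idxs.values(), ratios):
    sample_indices += np.random.choice(cls_inds,
                                       int(len(cls_inds) * ratio)).tolist()
print(sorted(sample_indices))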
......
@@ -150,17 +150,18 @@ class NuScenesDataset(Custom3DDataset):
                 contains such boxes, store a list containing idx,
                 otherwise, store empty list.
         """
-        class_sample_idx = {name: [] for name in self.CLASSES}
         info = self.data_infos[idx]
         if self.use_valid_flag:
            mask = info['valid_flag']
            gt_names = set(info['gt_names'][mask])
         else:
            gt_names = set(info['gt_names'])
+        cat_ids = []
         for name in gt_names:
             if name in self.CLASSES:
-                class_sample_idx[name].append(idx)
-        return class_sample_idx
+                cat_ids.append(self.cat2id[name])
+        return cat_ids

     def load_annotations(self, ann_file):
         """Load annotations from ann_file.
......
@@ -259,9 +259,9 @@ class DataBaseSampler(object):
                     count += 1
                     s_points_list.append(s_points)
-            # gt_names = np.array([s['name'] for s in sampled]),
-            # gt_labels = np.array([self.cat2label(s) for s in gt_names])
-            gt_labels = np.array([self.cat2label[s['name']] for s in sampled])
+            gt_labels = np.array([self.cat2label[s['name']] for s in sampled],
+                                 dtype=np.long)
             ret = {
                 'gt_labels_3d':
                 gt_labels,
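The functional change in this hunk is only the explicit integer dtype on the sampled labels. One way such a label-type bug can show up, sketched standalone (this failure mode is an illustration, not quoted from the commit): an empty NumPy array defaults to float64, and merging it with integer labels silently promotes everything to float.

import numpy as np

# Without an explicit dtype, an empty label array comes out as float64 ...
sampled_labels = np.array([])
print(sampled_labels.dtype)                     # float64

# ... and concatenating it with integer labels promotes the result to float.
gt_labels = np.array([0, 2, 1], dtype=np.long)  # np.long == Python int here
merged = np.concatenate([gt_labels, sampled_labels])
print(merged.dtype)                             # float64

# Building sampled labels with dtype=np.long (or casting afterwards, as the
# ObjectSample hunk below does with .astype(np.long)) keeps labels integral.
merged = merged.astype(np.long)
print(merged.dtype)                             # int64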
......
@@ -198,7 +198,7 @@ class ObjectSample(object):
             input_dict['img'] = sampled_dict['img']
         input_dict['gt_bboxes_3d'] = gt_bboxes_3d
-        input_dict['gt_labels_3d'] = gt_labels_3d
+        input_dict['gt_labels_3d'] = gt_labels_3d.astype(np.long)
         input_dict['points'] = points
         return input_dict
......
@@ -62,13 +62,13 @@ def test_getitem():
         # and box_type_3d='Depth' in sunrgbd and scannet dataset.
         box_type_3d='LiDAR'))
     nus_dataset = build_dataset(dataset_cfg)
-    assert len(nus_dataset) == 10
+    assert len(nus_dataset) == 20
     data = nus_dataset[0]
-    assert data['img_metas'].data['flip'] is True
-    assert data['img_metas'].data['pcd_horizontal_flip'] is True
-    assert data['points']._data.shape == (537, 5)
-    data = nus_dataset[1]
     assert data['img_metas'].data['flip'] is False
     assert data['img_metas'].data['pcd_horizontal_flip'] is False
+    assert data['points']._data.shape == (901, 5)
+    data = nus_dataset[1]
+    assert data['img_metas'].data['flip'] is True
+    assert data['img_metas'].data['pcd_horizontal_flip'] is True
     assert data['points']._data.shape == (537, 5)
@@ -78,6 +78,7 @@ def test_object_sample():
             gt_labels.append(CLASSES.index(cat))
         else:
             gt_labels.append(-1)
+    gt_labels = np.array(gt_labels, dtype=np.long)
     input_dict = dict(
         points=points, gt_bboxes_3d=gt_bboxes_3d, gt_labels_3d=gt_labels)
     input_dict = object_sample(input_dict)
......