"alphafold/model/structure_module.py" did not exist on "3d9c2de3f1857072be78ee80396553a5347bda57"
Commit c23c4208 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support class-balanced sampling for nuScenes dataset

parent ea6cf247
......@@ -16,6 +16,8 @@ class NuScenesDataset(DatasetTemplate):
)
self.infos = []
self.include_nuscenes_data(self.mode)
if self.training and self.dataset_cfg.get('BALANCED_RESAMPLING', False):
self.infos = self.balanced_infos_resampling(self.infos)
def include_nuscenes_data(self, mode):
self.logger.info('Loading NuScenes dataset')
......@@ -32,6 +34,40 @@ class NuScenesDataset(DatasetTemplate):
self.infos.extend(nuscenes_infos)
self.logger.info('Total samples for NuScenes dataset: %d' % (len(nuscenes_infos)))
def balanced_infos_resampling(self, infos):
"""
Class-balanced sampling of nuScenes dataset from https://arxiv.org/abs/1908.09492
"""
cls_infos = {name: [] for name in self.class_names}
for info in infos:
for name in set(info['gt_names']):
if name in self.class_names:
cls_infos[name].append(info)
duplicated_samples = sum([len(v) for _, v in cls_infos.items()])
cls_dist = {k: len(v) / duplicated_samples for k, v in cls_infos.items()}
sampled_infos = []
frac = 1.0 / len(self.class_names)
ratios = [frac / v for v in cls_dist.values()]
for cur_cls_infos, ratio in zip(list(cls_infos.values()), ratios):
sampled_infos += np.random.choice(
cur_cls_infos, int(len(cur_cls_infos) * ratio)
).tolist()
self.logger.info('Total samples after balanced resampling: %s' % (len(sampled_infos)))
cls_infos_new = {name: [] for name in self.class_names}
for info in sampled_infos:
for name in set(info['gt_names']):
if name in self.class_names:
cls_infos_new[name].append(info)
cls_dist_new = {k: len(v) / len(sampled_infos) for k, v in cls_infos_new.items()}
return sampled_infos
def get_sweep(self, sweep_info):
def remove_ego_points(points, center_radius=1.0):
mask = ~((np.abs(points[:, 0]) < center_radius) & (np.abs(points[:, 1]) < center_radius))
......
......@@ -17,6 +17,8 @@ INFO_PATH: {
POINT_CLOUD_RANGE: [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
BALANCED_RESAMPLING: True
DATA_AUGMENTOR:
DISABLE_AUG_LIST: ['placeholder']
AUG_CONFIG_LIST:
......
......@@ -221,8 +221,8 @@ MODEL:
MULTI_CLASSES_NMS: False
NMS_TYPE: nms_gpu
NMS_THRESH: 0.2
NMS_PRE_MAXSIZE: 4096
NMS_POST_MAXSIZE: 250
NMS_PRE_MAXSIZE: 1000
NMS_POST_MAXSIZE: 100
OPTIMIZATION:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment