import numpy as np

from mmdet.datasets.builder import PIPELINES


@PIPELINES.register_module()
class PointSegClassMapping(object):
    """Map original semantic class to valid category ids.

    Map valid classes as 0~len(valid_cat_ids)-1 and
    others as len(valid_cat_ids).

    Args:
        valid_cat_ids (tuple[int]): A tuple of valid category.
    """

    def __init__(self, valid_cat_ids):
        self.valid_cat_ids = valid_cat_ids

    def __call__(self, results):
        """Remap ``results['pts_semantic_mask']`` to contiguous class ids.

        Every label equal to ``valid_cat_ids[k]`` becomes ``k``; every other
        label becomes ``len(valid_cat_ids)``.

        Args:
            results (dict): Pipeline dict that must contain the key
                ``'pts_semantic_mask'`` holding the per-point labels.

        Returns:
            dict: The same ``results`` dict, with ``'pts_semantic_mask'``
                replaced by a new array of the same shape and dtype holding
                the remapped labels. The input array is left untouched.
        """
        assert 'pts_semantic_mask' in results
        pts_semantic_mask = np.asarray(results['pts_semantic_mask'])
        neg_cls = len(self.valid_cat_ids)

        # Start with everything marked as "other" and then overwrite the
        # points that belong to a valid category. Writing into a fresh array
        # (rather than in place) keeps the comparison below working on the
        # original labels and avoids corrupting the caller's array if the
        # transform is applied to it more than once.
        converted = np.full(
            pts_semantic_mask.shape, neg_cls, dtype=pts_semantic_mask.dtype)

        # Iterate in reverse so that, should ``valid_cat_ids`` contain a
        # duplicate id, the first occurrence wins (same as ``tuple.index``).
        for cls_idx in reversed(range(neg_cls)):
            cat_id = self.valid_cat_ids[cls_idx]
            converted[pts_semantic_mask == cat_id] = cls_idx

        results['pts_semantic_mask'] = converted
        return results

    def __repr__(self):
        """Return the class name followed by its ``valid_cat_ids``."""
        repr_str = self.__class__.__name__
        repr_str += '(valid_cat_ids={})'.format(self.valid_cat_ids)
        return repr_str