Commit 4bfa571d authored by Owen Wang's avatar Owen Wang Committed by Facebook GitHub Bot
Browse files

Allow alternative gesture mappings in subclass fetcher

Summary: Add option to specify a custom subclass id mapping. Allows for flexibility when training models with different outputs needed.

Reviewed By: sanjeevk42

Differential Revision: D26826986

fbshipit-source-id: 9dba4f0f2f2afebd2f152ddd9aebd46cf4c86a0d
parent 8c3618d9
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from abc import ABC, abstractmethod
from typing import Any, Dict, List
import logging
import numpy as np
import torch
......@@ -9,6 +12,7 @@ from torch.nn import functional as F
from detectron2.layers import cat
from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads
from detectron2.utils.registry import Registry
from d2go.config import CfgNode as CN
from d2go.data.dataset_mappers import (
D2GO_DATA_MAPPER_REGISTRY,
......@@ -17,22 +21,47 @@ from d2go.data.dataset_mappers import (
from d2go.utils.helper import alias
logger = logging.getLogger(__name__)
SUBCLASS_FETCHER_REGISTRY = Registry("SUBCLASS_FETCHER")
def add_subclass_configs(cfg):
_C = cfg
_C.MODEL.SUBCLASS = CN()
_C.MODEL.SUBCLASS.SUBCLASS_ON = False
_C.MODEL.SUBCLASS.NUM_SUBCLASSES = 0 # must be set
_C.MODEL.SUBCLASS.NUM_LAYERS = 1
_C.MODEL.SUBCLASS.SUBCLASS_ID_FETCHER = "SubclassFetcher" # ABC, must be set
def fetch_subclass_from_extras(dataset_dict):
class SubclassFetcher(ABC):
""" Fetcher class to read subclass id annotations from dataset and prepare for train/eval.
Subclass this and register with `@SUBCLASS_FETCHER_REGISTRY.register()` decorator
to use with custom projects.
"""
Retrieve subclass (eg. hand gesture per RPN region) info from dataset dict.
@property
@abstractmethod
def subclass_names(self) -> List[str]:
""" Overwrite this member with any new mappings' subclass names, which
may be useful for specific evaluation purposes.
len(self.subclass_names) should be equal to the expected number
of subclass head outputs (cfg.MODEL.SUBCLASS.NUM_SUBCLASSES + 1).
"""
pass
def remap(self, subclass_id: int) -> int:
""" Map subclass ids read from dataset to new label id """
return subclass_id
def fetch_subclass_ids(self, dataset_dict: Dict[str, Any]) -> List[int]:
""" Get all the subclass_ids in a dataset dict """
extras_list = [anno.get("extras") for anno in dataset_dict["annotations"]]
subclass_ids = [extras["subclass_id"] for extras in extras_list]
return subclass_ids
@D2GO_DATA_MAPPER_REGISTRY.register()
class SubclassDatasetMapper(D2GoDatasetMapper):
"""
......@@ -40,7 +69,18 @@ class SubclassDatasetMapper(D2GoDatasetMapper):
"""
def __init__(self, cfg, is_train, tfm_gens=None, subclass_fetcher=None):
super().__init__(cfg, is_train=is_train, tfm_gens=tfm_gens)
self.subclass_fetcher = subclass_fetcher or fetch_subclass_from_extras
if subclass_fetcher is None:
fetcher_name = cfg.MODEL.SUBCLASS.SUBCLASS_ID_FETCHER
self.subclass_fetcher = SUBCLASS_FETCHER_REGISTRY.get(fetcher_name)()
logger.info(
f"Initialized {self.__class__.__name__} with "
f"subclass fetcher '{self.subclass_fetcher.__class__.__name__}'"
)
else:
assert isinstance(subclass_fetcher, SubclassFetcher), subclass_fetcher
self.subclass_fetcher = subclass_fetcher
logger.info(f"Set subclass fetcher to {self.subclass_fetcher}")
# NOTE: field doesn't exist when loading a (old) caffe2 model.
# self.subclass_on = cfg.MODEL.SUBCLASS.SUBCLASS_ON
self.subclass_on = True
......@@ -53,7 +93,7 @@ class SubclassDatasetMapper(D2GoDatasetMapper):
mapped_dataset_dict = super()._original_call(dataset_dict)
if (self.is_train and self.subclass_on):
subclass_ids = self.subclass_fetcher(dataset_dict)
subclass_ids = self.subclass_fetcher.fetch_subclass_ids(dataset_dict)
subclasses = torch.tensor(subclass_ids, dtype=torch.int64)
mapped_dataset_dict["instances"].gt_subclasses = subclasses
return mapped_dataset_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment