Commit 4bfa571d authored by Owen Wang's avatar Owen Wang Committed by Facebook GitHub Bot
Browse files

Allow alternative gesture mappings in subclass fetcher

Summary: Add option to specify a custom subclass id mapping. Allows for flexibility when training models with different outputs needed.

Reviewed By: sanjeevk42

Differential Revision: D26826986

fbshipit-source-id: 9dba4f0f2f2afebd2f152ddd9aebd46cf4c86a0d
parent 8c3618d9
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from abc import ABC, abstractmethod
from typing import Any, Dict, List
import logging
import numpy as np import numpy as np
import torch import torch
...@@ -9,6 +12,7 @@ from torch.nn import functional as F ...@@ -9,6 +12,7 @@ from torch.nn import functional as F
from detectron2.layers import cat from detectron2.layers import cat
from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads
from detectron2.utils.registry import Registry
from d2go.config import CfgNode as CN from d2go.config import CfgNode as CN
from d2go.data.dataset_mappers import ( from d2go.data.dataset_mappers import (
D2GO_DATA_MAPPER_REGISTRY, D2GO_DATA_MAPPER_REGISTRY,
...@@ -17,21 +21,46 @@ from d2go.data.dataset_mappers import ( ...@@ -17,21 +21,46 @@ from d2go.data.dataset_mappers import (
from d2go.utils.helper import alias from d2go.utils.helper import alias
logger = logging.getLogger(__name__)
SUBCLASS_FETCHER_REGISTRY = Registry("SUBCLASS_FETCHER")
def add_subclass_configs(cfg): def add_subclass_configs(cfg):
_C = cfg _C = cfg
_C.MODEL.SUBCLASS = CN() _C.MODEL.SUBCLASS = CN()
_C.MODEL.SUBCLASS.SUBCLASS_ON = False _C.MODEL.SUBCLASS.SUBCLASS_ON = False
_C.MODEL.SUBCLASS.NUM_SUBCLASSES = 0 # must be set _C.MODEL.SUBCLASS.NUM_SUBCLASSES = 0 # must be set
_C.MODEL.SUBCLASS.NUM_LAYERS = 1 _C.MODEL.SUBCLASS.NUM_LAYERS = 1
_C.MODEL.SUBCLASS.SUBCLASS_ID_FETCHER = "SubclassFetcher" # ABC, must be set
def fetch_subclass_from_extras(dataset_dict): class SubclassFetcher(ABC):
""" """ Fetcher class to read subclass id annotations from dataset and prepare for train/eval.
Retrieve subclass (eg. hand gesture per RPN region) info from dataset dict. Subclass this and register with `@SUBCLASS_FETCHER_REGISTRY.register()` decorator
to use with custom projects.
""" """
extras_list = [anno.get("extras") for anno in dataset_dict["annotations"]]
subclass_ids = [extras["subclass_id"] for extras in extras_list] @property
return subclass_ids @abstractmethod
def subclass_names(self) -> List[str]:
""" Overwrite this member with any new mappings' subclass names, which
may be useful for specific evaluation purposes.
len(self.subclass_names) should be equal to the expected number
of subclass head outputs (cfg.MODEL.SUBCLASS.NUM_SUBCLASSES + 1).
"""
pass
def remap(self, subclass_id: int) -> int:
""" Map subclass ids read from dataset to new label id """
return subclass_id
def fetch_subclass_ids(self, dataset_dict: Dict[str, Any]) -> List[int]:
""" Get all the subclass_ids in a dataset dict """
extras_list = [anno.get("extras") for anno in dataset_dict["annotations"]]
subclass_ids = [extras["subclass_id"] for extras in extras_list]
return subclass_ids
@D2GO_DATA_MAPPER_REGISTRY.register() @D2GO_DATA_MAPPER_REGISTRY.register()
class SubclassDatasetMapper(D2GoDatasetMapper): class SubclassDatasetMapper(D2GoDatasetMapper):
...@@ -40,7 +69,18 @@ class SubclassDatasetMapper(D2GoDatasetMapper): ...@@ -40,7 +69,18 @@ class SubclassDatasetMapper(D2GoDatasetMapper):
""" """
def __init__(self, cfg, is_train, tfm_gens=None, subclass_fetcher=None): def __init__(self, cfg, is_train, tfm_gens=None, subclass_fetcher=None):
super().__init__(cfg, is_train=is_train, tfm_gens=tfm_gens) super().__init__(cfg, is_train=is_train, tfm_gens=tfm_gens)
self.subclass_fetcher = subclass_fetcher or fetch_subclass_from_extras if subclass_fetcher is None:
fetcher_name = cfg.MODEL.SUBCLASS.SUBCLASS_ID_FETCHER
self.subclass_fetcher = SUBCLASS_FETCHER_REGISTRY.get(fetcher_name)()
logger.info(
f"Initialized {self.__class__.__name__} with "
f"subclass fetcher '{self.subclass_fetcher.__class__.__name__}'"
)
else:
assert isinstance(subclass_fetcher, SubclassFetcher), subclass_fetcher
self.subclass_fetcher = subclass_fetcher
logger.info(f"Set subclass fetcher to {self.subclass_fetcher}")
# NOTE: field doesn't exist when loading a (old) caffe2 model. # NOTE: field doesn't exist when loading a (old) caffe2 model.
# self.subclass_on = cfg.MODEL.SUBCLASS.SUBCLASS_ON # self.subclass_on = cfg.MODEL.SUBCLASS.SUBCLASS_ON
self.subclass_on = True self.subclass_on = True
...@@ -53,7 +93,7 @@ class SubclassDatasetMapper(D2GoDatasetMapper): ...@@ -53,7 +93,7 @@ class SubclassDatasetMapper(D2GoDatasetMapper):
mapped_dataset_dict = super()._original_call(dataset_dict) mapped_dataset_dict = super()._original_call(dataset_dict)
if (self.is_train and self.subclass_on): if (self.is_train and self.subclass_on):
subclass_ids = self.subclass_fetcher(dataset_dict) subclass_ids = self.subclass_fetcher.fetch_subclass_ids(dataset_dict)
subclasses = torch.tensor(subclass_ids, dtype=torch.int64) subclasses = torch.tensor(subclass_ids, dtype=torch.int64)
mapped_dataset_dict["instances"].gt_subclasses = subclasses mapped_dataset_dict["instances"].gt_subclasses = subclasses
return mapped_dataset_dict return mapped_dataset_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment