Commit f23248c0 authored by facebook-github-bot

Initial commit

fbshipit-source-id: f4a8ba78691d8cf46e003ef0bd2e95f170932778
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import atexit
import contextlib
import json
import logging
import os
import shutil
import tempfile
from collections import defaultdict
import numpy as np
import torch.utils.data as data
from d2go.config import temp_defrost
from d2go.data.datasets import register_dataset_split, ANN_FN, IM_DIR
from detectron2.data import DatasetCatalog, MetadataCatalog
from fvcore.common.file_io import PathManager
logger = logging.getLogger(__name__)
class AdhocDatasetManager:
# mapping from the new dataset name to an AdhocDataset instance
_REGISTERED = {}
@staticmethod
def add(adhoc_ds):
assert isinstance(adhoc_ds, AdhocDataset)
if adhoc_ds.new_ds_name in AdhocDatasetManager._REGISTERED:
logger.warning(
"Adhoc dataset {} has already been added, skip adding it".format(
adhoc_ds.new_ds_name
)
)
else:
logger.info("Adding new adhoc dataset {} ...".format(adhoc_ds.new_ds_name))
AdhocDatasetManager._REGISTERED[adhoc_ds.new_ds_name] = adhoc_ds
adhoc_ds.register_catalog()
@staticmethod
def remove(adhoc_ds):
try:
assert isinstance(adhoc_ds, AdhocDataset)
if adhoc_ds.new_ds_name not in AdhocDatasetManager._REGISTERED:
logger.warning(
"Adhoc dataset {} has already been removed, skip removing it".format(
adhoc_ds.new_ds_name
)
)
else:
logger.info("Remove adhoc dataset {} ...".format(adhoc_ds.new_ds_name))
del AdhocDatasetManager._REGISTERED[adhoc_ds.new_ds_name]
finally:
adhoc_ds.cleanup()
@staticmethod
@atexit.register
def _atexit():
for ds in AdhocDatasetManager._REGISTERED.values():
logger.info("Remove remaining adhoc dataset: {}".format(ds.new_ds_name))
ds.cleanup()
class AdhocDataset(object):
def __init__(self, new_ds_name):
assert isinstance(new_ds_name, str)
self.new_ds_name = new_ds_name
def register_catalog(self):
raise NotImplementedError()
def cleanup(self):
raise NotImplementedError()
class CallFuncWithJsonFile(object):
"""
The instance of this class is a parameterless callable that calls its `func`
with its `json_file`. It can be registered in DatasetCatalog, which later
provides access to the json file.
"""
def __init__(self, func, json_file):
self.func = func
self.json_file = json_file
def __call__(self):
return self.func(self.json_file)
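# Example sketch: registering a json-backed load function through
# CallFuncWithJsonFile so that AdhocCOCODataset can later read `json_file` back
# from the DatasetCatalog entry. `_load_coco_dicts` and the paths are
# hypothetical.
#
#   def _load_coco_dicts(json_file):
#       with PathManager.open(json_file) as f:
#           return json.load(f)["images"]
#
#   DatasetCatalog.register(
#       "my_src_dataset",
#       CallFuncWithJsonFile(func=_load_coco_dicts, json_file="/path/to/coco.json"),
#   )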
class AdhocCOCODataset(AdhocDataset):
def __init__(self, src_ds_name, new_ds_name):
super().__init__(new_ds_name)
# NOTE: only support single source dataset now
assert isinstance(src_ds_name, str)
self.src_ds_name = src_ds_name
def new_json_dict(self, json_dict):
raise NotImplementedError()
def register_catalog(self):
"""
An adhoc COCO (json) dataset assumes the derived dataset can be created by only
changing the json file. Currently two sources are supported: 1) the dataset is
registered using standard COCO registering functions in D2 or
register_dataset_split from D2Go, in which case the `json_file` from the metadata
is used to access the json file; 2) the load func in DatasetCatalog is an instance
of CallFuncWithJsonFile, which gives access to the json_file. In both cases,
metadata will be the same except for the `name` and potentially `json_file`.
"""
logger.info("Register {} from {}".format(self.new_ds_name, self.src_ds_name))
metadata = MetadataCatalog.get(self.src_ds_name)
load_func = DatasetCatalog[self.src_ds_name]
src_json_file = (
load_func.json_file
if isinstance(load_func, CallFuncWithJsonFile)
else metadata.json_file
)
# TODO cache ?
with PathManager.open(src_json_file) as f:
json_dict = json.load(f)
assert "images" in json_dict, "Only support COCO-style json!"
json_dict = self.new_json_dict(json_dict)
self.tmp_dir = tempfile.mkdtemp(prefix="detectron2go_tmp_datasets")
tmp_file = os.path.join(self.tmp_dir, "{}.json".format(self.new_ds_name))
with open(tmp_file, "w") as f:
json.dump(json_dict, f)
# re-register DatasetCatalog
if isinstance(load_func, CallFuncWithJsonFile):
new_func = CallFuncWithJsonFile(func=load_func.func, json_file=tmp_file)
DatasetCatalog.register(self.new_ds_name, new_func)
else:
# NOTE: only supports COCODataset as DS_TYPE since we cannot reconstruct
# the split_dict
register_dataset_split(
self.new_ds_name,
split_dict={ANN_FN: tmp_file, IM_DIR: metadata.image_root},
)
# re-register MetadataCatalog
metadata_dict = metadata.as_dict()
metadata_dict["name"] = self.new_ds_name
if "json_file" in metadata_dict:
metadata_dict["json_file"] = tmp_file
MetadataCatalog.remove(self.new_ds_name)
MetadataCatalog.get(self.new_ds_name).set(**metadata_dict)
def cleanup(self):
# remove temporarily registered dataset and json file
DatasetCatalog.pop(self.new_ds_name, None)
MetadataCatalog.pop(self.new_ds_name, None)
if hasattr(self, "tmp_dir"):
shutil.rmtree(self.tmp_dir)
class COCOSubsetWithNImages(AdhocCOCODataset):
_SUPPORTED_SAMPLING = ["frontmost", "random"]
def __init__(self, src_ds_name, num_images, sampling):
super().__init__(
src_ds_name=src_ds_name,
new_ds_name="{}_{}{}".format(src_ds_name, sampling, num_images),
)
self.num_images = num_images
self.sampling = sampling
def new_json_dict(self, json_dict):
all_images = json_dict["images"]
if self.sampling == "frontmost":
new_images = all_images[: self.num_images]
elif self.sampling == "random":
# use fixed seed so results are repeatable
indices = np.random.RandomState(seed=42).permutation(len(all_images))
new_images = [all_images[i] for i in indices[: self.num_images]]
else:
raise NotImplementedError(
"COCOSubsetWithNImages doesn't support sampling method: {}".format(
self.sampling
)
)
new_image_ids = {im["id"] for im in new_images}
new_annotations = [
ann for ann in json_dict["annotations"] if ann["image_id"] in new_image_ids
]
json_dict["images"] = new_images
json_dict["annotations"] = new_annotations
return json_dict
class COCOSubsetWithGivenImages(AdhocCOCODataset):
def __init__(self, src_ds_name, file_names, prefix="given"):
super().__init__(
src_ds_name=src_ds_name,
new_ds_name="{}_{}{}".format(src_ds_name, prefix, len(file_names)),
)
self.file_names = file_names
def new_json_dict(self, json_dict):
all_images = json_dict["images"]
file_name_to_im = {im["file_name"]: im for im in all_images}
new_images = [file_name_to_im[file_name] for file_name in self.file_names]
# re-assign image id to keep the order (COCO loads images by id order)
old_id_to_new_id = {im["id"]: i for i, im in enumerate(new_images)}
new_annotations = [
ann
for ann in json_dict["annotations"]
if ann["image_id"] in old_id_to_new_id
]
# update image id
for im in new_images:
im["id"] = old_id_to_new_id[im["id"]]
for anno in new_annotations:
anno["image_id"] = old_id_to_new_id[anno["image_id"]]
json_dict["images"] = new_images
json_dict["annotations"] = new_annotations
return json_dict
class COCOWithClassesToUse(AdhocCOCODataset):
def __init__(self, src_ds_name, classes_to_use):
super().__init__(
src_ds_name=src_ds_name,
new_ds_name="{}@{}classes".format(src_ds_name, len(classes_to_use)),
)
self.classes_to_use = classes_to_use
def new_json_dict(self, json_dict):
categories = json_dict["categories"]
new_categories = [
cat for cat in categories if cat["name"] in self.classes_to_use
]
new_category_ids = {cat["id"] for cat in new_categories}
new_annotations = [
ann
for ann in json_dict["annotations"]
if ann["category_id"] in new_category_ids
]
json_dict["categories"] = new_categories
json_dict["annotations"] = new_annotations
return json_dict
class ClipLengthGroupedDataset(data.IterableDataset):
"""
Batch data that have the same clip length and similar aspect ratio.
In this implementation, images with the same clip length and whose aspect
ratio < 1 (or > 1) will be batched together.
This makes training with different clip lengths possible and improves
training speed because the images then need less padding to form a batch.
"""
def __init__(self, dataset, batch_size):
"""
Args:
dataset: an iterable. Each element must be a dict with keys
"width" and "height", which will be used to batch data.
batch_size (int):
"""
self.dataset = dataset
self.batch_size = batch_size
self._buckets = defaultdict(list)
def __iter__(self):
for d in self.dataset:
clip_length = len(d["frames"])
h, w = d["height"], d["width"]
aspect_ratio_bucket_id = 0 if h > w else 1
bucket = self._buckets[(clip_length, aspect_ratio_bucket_id)]
bucket.append(d)
if len(bucket) == self.batch_size:
yield bucket[:]
del bucket[:]
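# Example sketch: wrapping an iterable of clip dicts so that each yielded batch
# shares the same clip length and aspect-ratio bucket. `clip_dataset` is a
# hypothetical iterable whose elements have "frames", "height" and "width".
#
#   batched = ClipLengthGroupedDataset(clip_dataset, batch_size=4)
#   for batch in batched:
#       assert len(batch) == 4
#       assert len({len(d["frames"]) for d in batch}) == 1  # same clip length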
@contextlib.contextmanager
def register_sub_dataset_with_n_images(dataset_name, num_images, sampling):
"""
Temporarily register a sub-dataset created from `dataset_name`, with the first
`num_images` from it.
"""
# when `num_images` is not larger than 0, return original dataset
if num_images <= 0:
yield dataset_name
return
# only support coco for now
assert sampling in COCOSubsetWithNImages._SUPPORTED_SAMPLING
new_dataset = COCOSubsetWithNImages(dataset_name, num_images, sampling)
AdhocDatasetManager.add(new_dataset)
try:
yield new_dataset.new_ds_name
finally:
AdhocDatasetManager.remove(new_dataset)
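# Example sketch: temporarily evaluating on the first 100 images of a registered
# COCO-style dataset; "my_coco_train" is a hypothetical dataset name.
#
#   with register_sub_dataset_with_n_images("my_coco_train", 100, "frontmost") as name:
#       dicts = DatasetCatalog.get(name)  # at most 100 images
#   # the temporary dataset and its json file are cleaned up on exit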
@contextlib.contextmanager
def register_sub_dataset_with_given_images(*args, **kwargs):
new_dataset = COCOSubsetWithGivenImages(*args, **kwargs)
AdhocDatasetManager.add(new_dataset)
try:
yield new_dataset.new_ds_name
finally:
AdhocDatasetManager.remove(new_dataset)
@contextlib.contextmanager
def maybe_subsample_n_images(cfg, is_train=False):
"""
Create a new config whose train/test datasets only take a subsample of
`max_images` images. Use all images (no-op) when `max_images` <= 0.
"""
max_images = cfg.D2GO_DATA.TEST.MAX_IMAGES
sampling = cfg.D2GO_DATA.TEST.SUBSET_SAMPLING
with contextlib.ExitStack() as stack: # python 3.3+
new_splits = tuple(
stack.enter_context(
register_sub_dataset_with_n_images(ds, max_images, sampling)
)
for ds in (cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST)
)
new_cfg = cfg.clone()
with temp_defrost(new_cfg):
if is_train:
new_cfg.DATASETS.TRAIN = new_splits
else:
new_cfg.DATASETS.TEST = new_splits
yield new_cfg
def update_cfg_if_using_adhoc_dataset(cfg):
if cfg.D2GO_DATA.DATASETS.TRAIN_CATEGORIES:
new_train_datasets = [
COCOWithClassesToUse(name, cfg.D2GO_DATA.DATASETS.TRAIN_CATEGORIES)
for name in cfg.DATASETS.TRAIN
]
[AdhocDatasetManager.add(new_ds) for new_ds in new_train_datasets]
with temp_defrost(cfg):
cfg.DATASETS.TRAIN = tuple(ds.new_ds_name for ds in new_train_datasets)
if cfg.D2GO_DATA.DATASETS.TEST_CATEGORIES:
new_test_datasets = [
COCOWithClassesToUse(ds, cfg.D2GO_DATA.DATASETS.TEST_CATEGORIES)
for ds in cfg.DATASETS.TEST
]
[AdhocDatasetManager.add(new_ds) for new_ds in new_test_datasets]
with temp_defrost(cfg):
cfg.DATASETS.TEST = tuple(ds.new_ds_name for ds in new_test_datasets)
return cfg
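# Example sketch: restricting training to two categories via config, which makes
# the function above register adhoc "<name>@2classes" datasets; the category
# names and dataset name are hypothetical.
#
#   cfg.D2GO_DATA.DATASETS.TRAIN_CATEGORIES = ["person", "dog"]
#   cfg = update_cfg_if_using_adhoc_dataset(cfg)
#   # cfg.DATASETS.TRAIN is now e.g. ("my_coco_train@2classes",)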
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Similar to detectron2.engine.launch, with support for a few more things:
- support for get_local_rank.
- support other backends like GLOO.
"""
import logging
import tempfile
import detectron2.utils.comm as comm
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from d2go.config import CfgNode, temp_defrost
from d2go.utils.launch_environment import get_launch_environment
logger = logging.getLogger(__name__)
_LOCAL_RANK = 0
_NUM_PROCESSES_PER_MACHINE = 1
def _set_local_rank(local_rank):
global _LOCAL_RANK
_LOCAL_RANK = local_rank
def _set_num_processes_per_machine(num_processes):
global _NUM_PROCESSES_PER_MACHINE
_NUM_PROCESSES_PER_MACHINE = num_processes
def get_local_rank():
return _LOCAL_RANK
def get_num_processes_per_machine():
return _NUM_PROCESSES_PER_MACHINE
def launch(
main_func,
num_processes_per_machine,
num_machines=1,
machine_rank=0,
dist_url=None,
backend="NCCL",
always_spawn=False,
args=(),
):
logger.info(
f"Launch with num_processes_per_machine: {num_processes_per_machine},"
f" num_machines: {num_machines}, machine_rank: {machine_rank},"
f" dist_url: {dist_url}, backend: {backend}."
)
if get_launch_environment() == "local" and not torch.cuda.is_available():
assert len(args) > 0, args
cfg = args[0]
assert isinstance(cfg, CfgNode)
if cfg.MODEL.DEVICE == "cuda":
logger.warning(
"Detected that CUDA is not available on this machine, set MODEL.DEVICE"
" to cpu and backend to GLOO"
)
with temp_defrost(cfg):
cfg.MODEL.DEVICE = "cpu"
backend = "GLOO"
if backend == "NCCL":
assert (
num_processes_per_machine <= torch.cuda.device_count()
), "num_processes_per_machine is greater than device count: {} vs {}".format(
num_processes_per_machine, torch.cuda.device_count()
)
world_size = num_machines * num_processes_per_machine
if world_size > 1 or always_spawn:
# https://github.com/pytorch/pytorch/pull/14391
# TODO prctl in spawned processes
prefix = f"detectron2go_{main_func.__module__}.{main_func.__name__}_return"
with tempfile.NamedTemporaryFile(prefix=prefix, suffix=".pth") as f:
return_file = f.name
mp.spawn(
_distributed_worker,
nprocs=num_processes_per_machine,
args=(
main_func,
world_size,
num_processes_per_machine,
machine_rank,
dist_url,
backend,
return_file,
args,
),
daemon=False,
)
if machine_rank == 0:
return torch.load(return_file)
else:
return main_func(*args)
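# Example sketch: launching a 2-process CPU run with the GLOO backend; `main`
# and `cfg` are hypothetical, and the port in dist_url is arbitrary.
#
#   def main(cfg):
#       return {"rank": get_local_rank()}
#
#   result = launch(
#       main,
#       num_processes_per_machine=2,
#       num_machines=1,
#       machine_rank=0,
#       dist_url="tcp://127.0.0.1:12345",
#       backend="GLOO",
#       args=(cfg,),
#   )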
def _distributed_worker(
local_rank,
main_func,
world_size,
num_processes_per_machine,
machine_rank,
dist_url,
backend,
return_file,
args,
):
assert backend in ["NCCL", "GLOO"]
_set_local_rank(local_rank)
_set_num_processes_per_machine(num_processes_per_machine)
# NOTE: this is wrong if using a different number of processes across machines
global_rank = machine_rank * num_processes_per_machine + local_rank
try:
dist.init_process_group(
backend=backend,
init_method=dist_url,
world_size=world_size,
rank=global_rank,
)
except Exception as e:
logger.error("Process group URL: {}".format(dist_url))
raise e
# synchronize is needed here to prevent a possible timeout after calling
# init_process_group
# See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
comm.synchronize()
if backend in ["NCCL"]:
torch.cuda.set_device(local_rank)
# Setup the local process group (which contains ranks within the same machine)
assert comm._LOCAL_PROCESS_GROUP is None
num_machines = world_size // num_processes_per_machine
for i in range(num_machines):
ranks_on_i = list(
range(i * num_processes_per_machine, (i + 1) * num_processes_per_machine)
)
pg = dist.new_group(ranks_on_i)
if i == machine_rank:
comm._LOCAL_PROCESS_GROUP = pg
ret = main_func(*args)
if global_rank == 0:
logger.info(
"Save {}.{} return to: {}".format(
main_func.__module__, main_func.__name__, return_file
)
)
torch.save(ret, return_file)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .prediction_count_evaluation import PredictionCountEvaluator # noqa
__all__ = [k for k in globals().keys() if not k.startswith("_")]
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from collections import OrderedDict
import numpy as np
from detectron2.evaluation import DatasetEvaluator
logger = logging.getLogger(__name__)
class PredictionCountEvaluator(DatasetEvaluator):
"""
Custom Detectron2 evaluator class to simply count the number of predictions
e.g. on a dataset of hard negatives where there are no annotations, and
summarize results into interpretable metrics.
See class pattern from detectron2.evaluation.evaluator.py, especially
:func:`inference_on_dataset` to see how this class will be called.
"""
def reset(self):
self.prediction_counts = []
self.confidence_scores = []
def process(self, inputs, outputs):
"""
Params:
inputs: the inputs that are used to call the model.
outputs: the return value of `model(inputs)`.
"""
# outputs format:
# [{
# "instances": Instances(
# num_instances=88,
# fields=[scores = tensor([list of len num_instances])]
# ), ...
# },
# ... other dicts
# ]
for output_dict in outputs:
instances = output_dict["instances"]
self.prediction_counts.append(len(instances))
self.confidence_scores.extend(instances.get("scores").tolist())
def evaluate(self):
"""
Returns:
In detectron2.tools.train_net.py, the following format is expected:
dict:
* key: the name of the task (e.g., bbox)
* value: a dict of {metric name: score}, e.g.: {"AP50": 80}
"""
mpi = np.mean(self.prediction_counts)
mcp = np.mean(self.confidence_scores)
output_metrics = OrderedDict(
{
"false_positives": {
"predictions_per_image": mpi,
"confidence_per_prediction": mcp,
}
}
)
logger.info(f"mean predictions per image: {mpi}")
logger.info(f"mean confidence per prediction: {mcp}")
return output_metrics
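# Example sketch: counting predictions on a loader of un-annotated hard negatives;
# `model` and `hard_negative_loader` are hypothetical.
#
#   from detectron2.evaluation import inference_on_dataset
#
#   results = inference_on_dataset(
#       model, hard_negative_loader, PredictionCountEvaluator()
#   )
#   print(results["false_positives"]["predictions_per_image"])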
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import heapq
import itertools
import logging
from contextlib import contextmanager
from detectron2.data import MetadataCatalog
from detectron2.evaluation import DatasetEvaluator, SemSegEvaluator
from detectron2.utils.comm import all_gather, synchronize
logger = logging.getLogger(__name__)
class MultiSemSegEvaluator(DatasetEvaluator):
"""
Evaluate multiple results for the same target. SemSegEvaluator requires the
outputs of model to be like:
[
{"sem_seg": Tensor},
]
This evaluator allows evaluating multiple predictions; it may take outputs like:
[
{
"prediction_1": {"sem_seg": Tensor},
"prediction_2": {"sem_seg": Tensor},
}
]
"""
_DUMMY_KEY_PREFIX = "dummy_eval"
def __init__(self, dataset_name, *args, distributed, output_dir=None, **kwargs):
self._distributed = distributed
self._output_dir = output_dir
self.evaluators = {}
self.dataset_name = dataset_name
self.init_args = args
self.init_kwargs = kwargs
def _get_evaluator(self, key, superclass_name=None):
if key in self.evaluators:
return self.evaluators[key]
def create_evaluator_and_reset(dataset_name):
logger.info(
"Create an instance of SemSegEvaluator for {} on dataset {} ...".format(
key, dataset_name
)
)
evaluator = SemSegEvaluator(
dataset_name,
*self.init_args,
**self.init_kwargs,
distributed=self._distributed,
output_dir=self._output_dir,
)
evaluator.reset()
return evaluator
if superclass_name is None:
self.evaluators[key] = create_evaluator_and_reset(self.dataset_name)
else:
# NOTE: create temporary single-super-class dataset and use standard
# evaluator for the dataset
metadata = MetadataCatalog.get(self.dataset_name)
tmp_dataset_name = "__AUTOGEN__{}@{}".format(
self.dataset_name, superclass_name
)
from d2go.datasets.builtin_dataset_people_ai_person_segmentation import (
_register_person_sem_seg,
)
if tmp_dataset_name not in MetadataCatalog:
_register_person_sem_seg(
tmp_dataset_name,
metadata.mcs_metadata[superclass_name],
image_root=metadata.image_root,
sem_seg_root=metadata.sem_seg_root,
instances_json=metadata.json_file,
mask_dir="{}_mask".format(superclass_name),
)
self.evaluators[key] = create_evaluator_and_reset(tmp_dataset_name)
return self.evaluators[key]
def reset(self):
for evaluator in self.evaluators.values():
evaluator.reset()
def process(self, inputs, outputs):
if "sem_seg" in outputs[0].keys():
# normal eval is compatible with SemSegEvaluator
self._get_evaluator("sem_seg").process(inputs, outputs)
else:
# only the file_name of inputs is needed for SemSegEvaluator
inputs_ = [{"file_name": inp["file_name"]} for inp in inputs]
for frame_name in outputs[0].keys():
if isinstance(outputs[0]["detect"]["sem_seg"], dict): # multi-class
for superclass_name in outputs[0]["detect"]["sem_seg"]:
outputs_ = []
for outp in outputs:
x = outp[frame_name]
x = {"sem_seg": x["sem_seg"][superclass_name]}
outputs_.append(x)
self._get_evaluator(
"sem_seg-{}-{}".format(frame_name, superclass_name),
superclass_name=superclass_name,
).process(inputs_, outputs_)
else:
# convert the output to SemSegEvaluator's format
outputs_ = [outp[frame_name] for outp in outputs]
self._get_evaluator("sem_seg-{}".format(frame_name)).process(
inputs_, outputs_
)
def evaluate(self):
results = {}
# The evaluation will sometimes get stuck if the following code is not used.
# `SemSegEvaluator` synchronizes between processes when computing the metrics.
# In some cases the number of self.evaluators will not be the same across
# processes and the code will get stuck during synchronization.
# For example, when evaluating 10 images on 8 GPUs, only 5 GPUs will be used
# for evaluation, each with 2 images; the remaining 3 GPUs will have zero
# self.evaluators since they are constructed on-the-fly when calling
# self.process().
# We create additional evaluators so that all processes have the same number
# of evaluators and the synchronization does not get stuck.
evaluator_size = len(self.evaluators)
synchronize()
evaluator_size_list = all_gather(evaluator_size)
max_evaluator_size = max(evaluator_size_list)
if evaluator_size < max_evaluator_size:
# create additional evaluators so that all processes have the same
# size of evaluators
metadata = MetadataCatalog.get(self.dataset_name)
mcs_metadata = metadata.get("mcs_metadata")
for idx in range(max_evaluator_size - evaluator_size):
dummy_key = f"{self._DUMMY_KEY_PREFIX}_{idx}"
assert dummy_key not in self.evaluators
if mcs_metadata:
for k in mcs_metadata:
self._get_evaluator(dummy_key, superclass_name=k).reset()
else:
self._get_evaluator(dummy_key).reset()
for name, evaluator in self.evaluators.items():
result = evaluator.evaluate()
# NOTE: .evaluate() returns None for non-main process
if result is not None:
results[name] = result["sem_seg"]
return results
class MultiSemSegVidEvaluator(MultiSemSegEvaluator):
"""
Evaluate semantic segmentation results for video clips. MultiSemSegVidEvaluator
requires the inputs of the model to be like:
[
{"file_names": Tensor},
]
"""
def process(self, inputs, outputs):
assert "file_names" in inputs[0]
inputs_ = []
for batch_id in range(len(inputs)):
for frame_i in range(len(inputs[batch_id]["file_names"])):
inputs_.append(
{"file_name": inputs[batch_id]["file_names"][frame_i]}
)
for name in outputs[0].keys():
# convert the output to SemSegEvaluator's format
outputs_ = [outp[name] for outp in outputs]
self.evaluators["sem_seg_{}".format(name)].process(inputs_, outputs_)
@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
"""
A context manager that will prevent any logging messages
triggered during the body from being processed.
:param highest_level: the maximum logging level in use.
This would only need to be changed if a custom level greater than CRITICAL
is defined.
"""
# two kind-of hacks here:
# * can't get the highest logging level in effect => delegate to the user
# * can't get the current module-level override => use an undocumented
# (but non-private!) interface
previous_level = logging.root.manager.disable
logging.disable(highest_level)
try:
yield
finally:
logging.disable(previous_level)
class PerImageEvaluator(object):
def __init__(
self,
evaluator,
callback,
distributed=True,
playback_criterion=None,
playback_limit=0,
):
self._evaluator = evaluator
self._evaluator._distributed = False
self._evaluator._output_dir = None
self._distributed = distributed
self.callback = callback
self.results_per_image = []
# record the N most interesting results for playback
self.playback_heap = []
self.playback_criterion = playback_criterion
self.playback_limit = playback_limit
def reset(self):
self._evaluator.reset()
def process(self, inputs, outputs):
self._evaluator.process(inputs, outputs)
assert len(inputs) == 1
with all_logging_disabled():
result = self._evaluator.evaluate()
self.results_per_image.append((inputs[0], result))
if self.playback_criterion:
score = self.playback_criterion(result)
heapq.heappush(self.playback_heap, (score, inputs[0], outputs[0], result))
if len(self.playback_heap) > self.playback_limit:
heapq.heappop(self.playback_heap)
self._evaluator.reset()
def evaluate(self):
if self._distributed:
synchronize()
results_per_image = all_gather(self.results_per_image)
self.results_per_image = list(itertools.chain(*results_per_image))
playback_heap = all_gather(self.playback_heap)
playback_heap = list(itertools.chain(*playback_heap))
# each GPU has its local N minimums; sort and take the global minimums
playback_heap = sorted(playback_heap, key=lambda x: x[0])
self.playback_heap = playback_heap[: self.playback_limit]
self.callback(self)
return {}
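# Example sketch: wrapping a COCOEvaluator so the 10 lowest-AP images are kept
# for playback; the dataset name and callback are hypothetical.
#
#   from detectron2.evaluation import COCOEvaluator
#
#   def criterion(result):
#       return -result["bbox"]["AP"]  # larger score == lower AP == more interesting
#
#   evaluator = PerImageEvaluator(
#       COCOEvaluator("my_coco_val"),
#       callback=lambda e: print(len(e.playback_heap)),
#       playback_criterion=criterion,
#       playback_limit=10,
#   )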
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
API for exporting a pytorch model to a predictor. The predictor contains model(s) in
a deployable format and predefined functions as glue code. The exported predictor should
generate the same output as the original pytorch model. (See predictor/api.py for details
of the predictor.)
This API defines two customizable methods for the pytorch model:
prepare_for_export (required by the default export_predictor): returns a
PredictorExportConfig which describes how to export the predictor.
export_predictor (optional): the implementation of the export process. A default
implementation is provided to cover the majority of use cases where the
individual model(s) can be exported in a standard way.
NOTE:
1: There's a difference between predictor type and model type. Model type
refers to a predefined deployable format such as caffe2 or torchscript(_int8),
while the predictor type can be anything that "export_predictor" can
recognize.
2: The standard model exporting methods are provided by the library code; they're
meant to be modularized and can be used by a customized export_predictor as well.
"""
import json
import logging
import os
from typing import Any, Callable, Dict, NamedTuple, Optional, Union
import torch
import torch.nn as nn
import torch.quantization.quantize_fx
from d2go.modeling.quantization import post_training_quantize
from fvcore.common.file_io import PathManager
from mobile_cv.arch.utils import fuse_utils
from mobile_cv.predictor.api import FuncInfo, ModelInfo, PredictorInfo
from mobile_cv.predictor.builtin_functions import (
IdentityPostprocess,
IdentityPreprocess,
NaiveRunFunc,
)
logger = logging.getLogger(__name__)
class PredictorExportConfig(NamedTuple):
"""
Storing information for exporting a predictor.
Args:
model (any nested iterable structure of nn.Module): the model(s) to be exported
(via tracing/onnx or scripting). This can be sub-model(s) when the predictor
consists of multiple models in deployable format, and/or pre/post processing
is excluded due to the requirements of tracing or hardware incompatibility.
data_generator (Callable): a function to generate all data needed for tracing,
such that data = data_generator(x); the returned data has the same nested
structure as the model. The data for each model will be treated as positional
arguments, i.e. model(*data).
model_export_kwargs (Dict): additional kwargs when exporting each sub-model; it
follows the same nested structure as the model, and may contain information
such as scriptable.
preprocess_info (FuncInfo): info for predictor's preprocess
postprocess_info (FuncInfo): info for predictor's postprocess
run_func_info (FuncInfo): info for predictor's run_fun
"""
model: Union[nn.Module, Any]
# Shall we save data_generator in the predictor? This might be necessary when data
# is needed, eg. running benchmark for sub models
data_generator: Optional[Callable] = None
model_export_kwargs: Optional[Union[Dict, Any]] = None
preprocess_info: FuncInfo = FuncInfo.gen_func_info(IdentityPreprocess, params={})
postprocess_info: FuncInfo = FuncInfo.gen_func_info(IdentityPostprocess, params={})
run_func_info: FuncInfo = FuncInfo.gen_func_info(NaiveRunFunc, params={})
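# Example sketch: a minimal prepare_for_export on a meta-arch, returning the model
# itself together with a data_generator that turns a batch into positional args;
# MyMetaArch and the "image" key are hypothetical.
#
#   class MyMetaArch(nn.Module):
#       def prepare_for_export(self, cfg, inputs, export_scheme):
#           return PredictorExportConfig(
#               model=self,
#               data_generator=lambda batch: (batch[0]["image"],),
#           )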
def convert_and_export_predictor(
cfg, pytorch_model, predictor_type, output_dir, data_loader
):
"""
Entry point for converting and exporting a model. This involves two steps:
- convert: converting the given `pytorch_model` to another format, currently
mainly for quantizing the model.
- export: exporting the converted `pytorch_model` to predictor. This step
should not alter the behaviour of model.
"""
if "int8" in predictor_type:
if not cfg.QUANTIZATION.QAT.ENABLED:
logger.info(
"The model is not quantized during training, running post"
" training quantization ..."
)
pytorch_model = post_training_quantize(cfg, pytorch_model, data_loader)
# only check bn exists in ptq as qat still has bn inside fused ops
assert not fuse_utils.check_bn_exist(pytorch_model)
logger.info(f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")
if cfg.QUANTIZATION.EAGER_MODE:
# TODO(future diff): move this logic to prepare_for_quant_convert
pytorch_model = torch.quantization.convert(pytorch_model, inplace=False)
else: # FX graph mode quantization
if hasattr(pytorch_model, "prepare_for_quant_convert"):
pytorch_model = pytorch_model.prepare_for_quant_convert(cfg)
else:
# TODO(future diff): move this to a default function
pytorch_model = torch.quantization.quantize_fx.convert_fx(pytorch_model)
logger.info("Quantized Model:\n{}".format(pytorch_model))
return export_predictor(cfg, pytorch_model, predictor_type, output_dir, data_loader)
def export_predictor(cfg, pytorch_model, predictor_type, output_dir, data_loader):
"""
Interface for exporting a pytorch model to a predictor of a given type. This function
can be overridden to achieve a customized exporting procedure, e.g. using non-default
optimization passes, composing traced models, etc.
Args:
cfg (CfgNode): the config
pytorch_model (nn.Module): a pytorch model, mostly also a meta-arch
predictor_type (str): a string which specifies the type of predictor, note that
the definition of type is interpreted by "export_predictor", the default
implementation uses the deployable model format (eg. caffe2_fp32,
torchscript_int8) as predictor type.
output_dir (str): the parent directory where the predictor will be saved
data_loader: data loader for the pytorch model
Returns:
predictor_path (str): the directory of exported predictor, a sub-directory of
"output_dir"
"""
# predictor exporting can be customized by implementing "export_predictor" on the meta-arch
if hasattr(pytorch_model, "export_predictor"):
return pytorch_model.export_predictor(
cfg, predictor_type, output_dir, data_loader
)
else:
return default_export_predictor(
cfg, pytorch_model, predictor_type, output_dir, data_loader
)
def default_export_predictor(
cfg, pytorch_model, predictor_type, output_dir, data_loader
):
# The default implementation acts based on the PredictorExportConfig returned by
# calling "prepare_for_export". It'll export all sub models in standard way
# according to the "predictor_type".
assert hasattr(pytorch_model, "prepare_for_export"), pytorch_model
inputs = next(iter(data_loader))
export_config = pytorch_model.prepare_for_export(
cfg, inputs, export_scheme=predictor_type
)
predictor_path = os.path.join(output_dir, predictor_type)
PathManager.mkdirs(predictor_path)
# TODO: also support multiple models from nested dict in the default implementation
assert isinstance(export_config.model, nn.Module), "Currently support single model"
model = export_config.model
input_args = (
export_config.data_generator(inputs)
if export_config.data_generator is not None
else None
)
model_export_kwargs = export_config.model_export_kwargs or {}
# the default implementation assumes model type is the same as the predictor type
model_type = predictor_type
model_path = predictor_path  # may be a sub-directory for multiple models
standard_model_export(
model,
model_type=model_type,
save_path=model_path,
input_args=input_args,
**model_export_kwargs,
)
model_rel_path = os.path.relpath(model_path, predictor_path)
# assemble predictor
predictor_info = PredictorInfo(
model=ModelInfo(path=model_rel_path, type=model_type),
preprocess_info=export_config.preprocess_info,
postprocess_info=export_config.postprocess_info,
run_func_info=export_config.run_func_info,
)
with PathManager.open(
os.path.join(predictor_path, "predictor_info.json"), "w"
) as f:
json.dump(predictor_info.to_dict(), f, indent=4)
return predictor_path
# TODO: determine if saving data should be part of standard_model_export or not.
# TODO: determine how to support PTQ, option 1): do everything inside this function,
# drawback: needs data loader; no customization. option 2): do calibration outside,
# and only do tracing inside (same as fp32 torchscript model).
# TODO: define the supported model types, current caffe2/torchscript/torchscript_int8
# is not enough.
# TODO: determine if registry is needed (probably not since we only need to support
# a few known formats) as library code.
def standard_model_export(model, model_type, save_path, input_args, **kwargs):
if model_type.startswith("torchscript"):
from d2go.export.torchscript import trace_and_save_torchscript
trace_and_save_torchscript(model, input_args, save_path, **kwargs)
elif model_type == "caffe2":
from d2go.export.caffe2 import export_caffe2
# TODO: export_caffe2 depends on D2, need to make a copy of the implementation
# TODO: support specifying optimization pass via kwargs
export_caffe2(model, input_args[0], save_path, **kwargs)
else:
raise NotImplementedError("Incorrect model_type: {}".format(model_type))
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import torch
import os
from torch import nn
from typing import Dict, Tuple
from detectron2.export.api import Caffe2Model
from detectron2.export.caffe2_export import (
export_caffe2_detection_model,
run_and_save_graph,
)
from d2go.export.logfiledb import export_to_logfiledb
logger = logging.getLogger(__name__)
def export_caffe2(
caffe2_compatible_model: nn.Module,
tensor_inputs: Tuple[str, torch.Tensor],
output_dir: str,
save_pb: bool = True,
save_logdb: bool = False,
) -> Tuple[Caffe2Model, Dict[str, str]]:
predict_net, init_net = export_caffe2_detection_model(
caffe2_compatible_model,
# pyre-fixme[6]: Expected `List[torch.Tensor]` for 2nd param but got
# `Tuple[str, torch.Tensor]`.
tensor_inputs,
)
caffe2_model = Caffe2Model(predict_net, init_net)
caffe2_export_paths = {}
if save_pb:
caffe2_model.save_protobuf(output_dir)
caffe2_export_paths.update({
"predict_net_path": os.path.join(output_dir, "model.pb"),
"init_net_path": os.path.join(output_dir, "model_init.pb"),
})
graph_save_path = os.path.join(output_dir, "model_def.svg")
ws_blobs = run_and_save_graph(
predict_net,
init_net,
tensor_inputs,
graph_save_path=graph_save_path,
)
caffe2_export_paths.update({
"model_def_path": graph_save_path,
})
if save_logdb:
logfiledb_path = os.path.join(output_dir, "model.logfiledb")
export_to_logfiledb(predict_net, init_net, logfiledb_path, ws_blobs)
caffe2_export_paths.update({
"logfiledb_path": logfiledb_path if save_logdb else None,
})
return caffe2_model, caffe2_export_paths
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from detectron2.export.caffe2_inference import ProtobufDetectionModel
from d2go.config import temp_defrost
logger = logging.getLogger(__name__)
def infer_mask_on(model: ProtobufDetectionModel):
# the real self.assembler should tell about this; currently use a heuristic
possible_blob_names = {"mask_fcn_probs"}
return any(
possible_blob_names.intersection(op.output)
for op in model.protobuf_model.net.Proto().op
)
def infer_keypoint_on(model: ProtobufDetectionModel):
# the real self.assembler should tell about this; currently use a heuristic
possible_blob_names = {"kps_score"}
return any(
possible_blob_names.intersection(op.output)
for op in model.protobuf_model.net.Proto().op
)
def infer_densepose_on(model: ProtobufDetectionModel):
possible_blob_names = {"AnnIndex", "Index_UV", "U_estimated", "V_estimated"}
return any(
possible_blob_names.intersection(op.output)
for op in model.protobuf_model.net.Proto().op
)
def _update_if_true(cfg, key, value):
if not value:
return
keys = key.split(".")
ref_value = cfg
while len(keys):
ref_value = getattr(ref_value, keys.pop(0))
if ref_value != value:
logger.warning(
"There's conflict between cfg and model, overwrite config {} from {} to {}"
.format(key, ref_value, value)
)
cfg.merge_from_list([key, value])
def update_cfg_from_pb_model(cfg, model):
"""
Update cfg statically based on the given caffe2 model; in case of conflict
between the caffe2 model and the cfg, the caffe2 model has higher priority.
"""
with temp_defrost(cfg):
_update_if_true(cfg, "MODEL.MASK_ON", infer_mask_on(model))
_update_if_true(cfg, "MODEL.KEYPOINT_ON", infer_keypoint_on(model))
_update_if_true(cfg, "MODEL.DENSEPOSE_ON", infer_densepose_on(model))
return cfg
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from functools import lru_cache
import torch
from d2go.export.api import PredictorExportConfig
from detectron2.export.caffe2_modeling import (
META_ARCH_CAFFE2_EXPORT_TYPE_MAP,
convert_batched_inputs_to_c2_format,
)
from detectron2.export.shared import get_pb_arg_vali, get_pb_arg_vals
from detectron2.modeling import META_ARCH_REGISTRY, GeneralizedRCNN
from detectron2.modeling.postprocessing import detector_postprocess
from mobile_cv.arch.utils import fuse_utils
from mobile_cv.arch.utils.quantize_utils import (
wrap_non_quant_group_norm,
wrap_quant_subclass,
QuantWrapper,
)
from mobile_cv.predictor.api import FuncInfo
from d2go.utils.prepare_for_export import d2_meta_arch_prepare_for_export
logger = logging.getLogger(__name__)
@lru_cache() # only call once
def patch_d2_meta_arch():
# HACK: inject prepare_for_export for all D2's meta-arch
for cls_obj in META_ARCH_REGISTRY._obj_map.values():
if cls_obj.__module__.startswith("detectron2."):
if hasattr(cls_obj, "prepare_for_export"):
assert cls_obj.prepare_for_export == d2_meta_arch_prepare_for_export
else:
cls_obj.prepare_for_export = d2_meta_arch_prepare_for_export
if hasattr(cls_obj, "prepare_for_quant"):
assert cls_obj.prepare_for_quant == d2_meta_arch_prepare_for_quant
else:
cls_obj.prepare_for_quant = d2_meta_arch_prepare_for_quant
def _apply_eager_mode_quant(cfg, model):
if isinstance(model, GeneralizedRCNN):
""" Wrap each quantized part of the model to insert Quant and DeQuant in-place """
# Wrap backbone and proposal_generator
model.backbone = wrap_quant_subclass(
model.backbone, n_inputs=1, n_outputs=len(model.backbone._out_features)
)
model.proposal_generator.rpn_head = wrap_quant_subclass(
model.proposal_generator.rpn_head,
n_inputs=len(cfg.MODEL.RPN.IN_FEATURES),
n_outputs=len(cfg.MODEL.RPN.IN_FEATURES) * 2,
)
# Wrap the roi_heads, box_pooler is not quantized
model.roi_heads.box_head = wrap_quant_subclass(
model.roi_heads.box_head,
n_inputs=1,
n_outputs=1,
)
model.roi_heads.box_predictor = wrap_quant_subclass(
model.roi_heads.box_predictor, n_inputs=1, n_outputs=2
)
# Optionally wrap keypoint and mask heads, pools are not quantized
if hasattr(model.roi_heads, "keypoint_head"):
model.roi_heads.keypoint_head = wrap_quant_subclass(
model.roi_heads.keypoint_head,
n_inputs=1,
n_outputs=1,
wrapped_method_name="layers",
)
if hasattr(model.roi_heads, "mask_head"):
model.roi_heads.mask_head = wrap_quant_subclass(
model.roi_heads.mask_head,
n_inputs=1,
n_outputs=1,
wrapped_method_name="layers",
)
# StandardROIHeadsWithSubClass uses a subclass head
if hasattr(model.roi_heads, "subclass_head"):
q_subclass_head = QuantWrapper(model.roi_heads.subclass_head)
model.roi_heads.subclass_head = q_subclass_head
else:
raise NotImplementedError(
"Eager mode for {} is not supported".format(type(model))
)
# TODO: wrap the normalizer and make it quantizable
# NOTE: GN is not quantizable, assuming all GN follows a quantized conv,
# wrap them with dequant-quant
model = wrap_non_quant_group_norm(model)
return model
def d2_meta_arch_prepare_for_quant(self, cfg):
model = self
# Modify the model for eager mode
if cfg.QUANTIZATION.EAGER_MODE:
model = _apply_eager_mode_quant(cfg, model)
model = fuse_utils.fuse_model(model, inplace=True)
torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
model.qconfig = (
torch.quantization.get_default_qat_qconfig(cfg.QUANTIZATION.BACKEND)
if model.training
else torch.quantization.get_default_qconfig(cfg.QUANTIZATION.BACKEND)
)
logger.info("Setup the model with qconfig:\n{}".format(model.qconfig))
return model
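# Example sketch of the eager-mode PTQ flow this hook enables; `model` is a
# hypothetical GeneralizedRCNN in eval mode and cfg.QUANTIZATION.EAGER_MODE is True.
#
#   patch_d2_meta_arch()  # injects prepare_for_quant / prepare_for_export on D2 meta-archs
#   model = model.prepare_for_quant(cfg)
#   torch.quantization.prepare(model, inplace=True)
#   # ... run calibration batches through `model` ...
#   model = torch.quantization.convert(model)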
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import numpy as np
from mobile_cv.torch.utils_caffe2.ws_utils import ScopedWS
logger = logging.getLogger(__name__)
# NOTE: specific export_to_db for (data, im_info) dual inputs.
# modified from mobile-vision/common/utils/model_utils.py
def export_to_db(net, params, inputs, outputs, out_file, net_type=None, shapes=None):
# NOTE: special handling for im_info: by default the "predict_init_net"
# will zero_fill inputs/outputs (https://fburl.com/diffusion/nvksomrt),
# however the actual value of "im_info" also matters, so we need to use
# extra_init_net to handle this.
from caffe2.python import core
assert len(inputs) == 2
data_name, im_info_name = inputs
data_shape = shapes[data_name] # assume NCHW
extra_init_net = core.Net("extra_init_net")
im_info = np.array(
[[data_shape[2], data_shape[3], 1.0] for _ in range(data_shape[0])],
dtype=np.float32,
)
extra_init_net.GivenTensorFill(
[], im_info_name, shape=shapes[im_info_name], values=im_info
)
from caffe2.caffe2.fb.predictor import predictor_exporter # NOTE: slow import
predictor_export_meta = predictor_exporter.PredictorExportMeta(
predict_net=net,
parameters=params,
inputs=inputs,
outputs=outputs,
net_type=net_type,
shapes=shapes,
extra_init_net=extra_init_net,
)
logger.info("Writing logdb {} ...".format(out_file))
predictor_exporter.save_to_db(
db_type="log_file_db",
db_destination=out_file,
predictor_export_meta=predictor_export_meta,
)
def export_to_logfiledb(predict_net, init_net, outfile, ws_blobs):
logger.info("Exporting Caffe2 model to {}".format(outfile))
shapes = {
b: data.shape if isinstance(data, np.ndarray)
# provide a dummy shape if it could not be inferred
else [1]
for b, data in ws_blobs.items()
}
with ScopedWS("__ws_tmp__", is_reset=True) as ws:
ws.RunNetOnce(init_net)
initialized_blobs = set(ws.Blobs())
uninitialized = [
inp for inp in predict_net.external_input if inp not in initialized_blobs
]
params = list(initialized_blobs)
output_names = list(predict_net.external_output)
export_to_db(
predict_net, params, uninitialized, output_names, outfile, shapes=shapes
)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import os
from typing import Tuple, Optional, Dict
import torch
from fvcore.common.file_io import PathManager
from torch import nn
logger = logging.getLogger(__name__)
def trace_and_save_torchscript(
model: nn.Module,
inputs: Tuple[torch.Tensor],
output_path: str,
_extra_files: Optional[Dict[str, bytes]] = None,
):
logger.info("Tracing and saving TorchScript to {} ...".format(output_path))
# TODO: patch_builtin_len depends on D2, we should either copy the function or
# dynamically register D2's version.
from detectron2.export.torchscript_patch import patch_builtin_len
with torch.no_grad(), patch_builtin_len():
script_model = torch.jit.trace(model, inputs)
if _extra_files is None:
_extra_files = {}
model_file = os.path.join(output_path, "model.jit")
PathManager.mkdirs(output_path)
with PathManager.open(model_file, "wb") as f:
torch.jit.save(script_model, f, _extra_files=_extra_files)
data_file = os.path.join(output_path, "data.pth")
with PathManager.open(data_file, "wb") as f:
torch.save(inputs, f)
# NOTE: new API doesn't require return
return model_file
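# Example sketch: tracing a toy module and saving it under a hypothetical output
# directory; this produces model.jit and data.pth under that path.
#
#   model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()
#   trace_and_save_torchscript(
#       model, (torch.rand(1, 3, 32, 32),), "/tmp/my_predictor/torchscript"
#   )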
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import time
SETUP_ENV_TIME = []
REGISTER_D2_DATASETS_TIME = []
REGISTER_TIME = []
def _record_times(time_list):
def warp(f):
def timed_f(*args, **kwargs):
start = time.perf_counter()
ret = f(*args, **kwargs)
time_list.append(time.perf_counter() - start)
return ret
return timed_f
return warp
@_record_times(SETUP_ENV_TIME)
def _setup_env():
# Set up custom environment before nearly anything else is imported
# NOTE: this should be the first import (do not reorder)
from detectron2.utils.env import ( # noqa F401 isort:skip
setup_environment as d2_setup_environment,
)
@_record_times(REGISTER_D2_DATASETS_TIME)
def _register_d2_datasets():
# this will register D2 builtin datasets
import detectron2.data # noqa F401
@_record_times(REGISTER_TIME)
def _register():
from d2go.modeling.backbone import ( # NOQA
fbnet_v2,
)
from d2go.data import dataset_mappers # NOQA
from d2go.data.datasets import (
register_json_datasets,
register_builtin_datasets,
)
#register_json_datasets()
#register_builtin_datasets()
def initialize_all():
# exclude torch from timing
from torchvision.ops import nms # noqa
_setup_env()
_register_d2_datasets()
_register()
_INITIALIZED = False
if not _INITIALIZED:
initialize_all()
_INITIALIZED = True
# Copyright (c) Facebook, Inc. and its affiliates.
import os
from typing import Optional
import pkg_resources
import torch
from detectron2.checkpoint import DetectionCheckpointer
from d2go.runner import create_runner
class _ModelZooUrls(object):
"""
Mapping from names to officially released D2Go pre-trained models.
"""
S3_PREFIX = "https://mobile-cv.s3-us-west-2.amazonaws.com/d2go/models/"
CONFIG_PATH_TO_URL_SUFFIX = {
"faster_rcnn_fbnetv3a_C4.yaml": "246823121/model_0479999.pth",
"faster_rcnn_fbnetv3a_dsmask_C4.yaml": "250414811/model_0399999.pth",
"faster_rcnn_fbnetv3g_fpn.yaml": "250356938/model_0374999.pth",
"mask_rcnn_fbnetv3a_C4.yaml": "250355374/model_0479999.pth",
"mask_rcnn_fbnetv3a_dsmask_C4.yaml": "250414867/model_0399999.pth",
"mask_rcnn_fbnetv3g_fpn.yaml": "250376154/model_0404999.pth",
"keypoint_rcnn_fbnetv3a_dsmask_C4.yaml": "250430934/model_0389999.pth",
}
def get_checkpoint_url(config_path):
"""
Returns the URL to the model trained using the given config
Args:
config_path (str): config file name relative to d2go's "configs/"
directory, e.g., "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"
Returns:
str: a URL to the model
"""
name = config_path.replace(".yaml", "")
if config_path in _ModelZooUrls.CONFIG_PATH_TO_URL_SUFFIX:
suffix = _ModelZooUrls.CONFIG_PATH_TO_URL_SUFFIX[config_path]
return _ModelZooUrls.S3_PREFIX + suffix
raise RuntimeError("{} not available in Model Zoo!".format(name))
def get_config_file(config_path):
"""
Returns path to a builtin config file.
Args:
config_path (str): config file name relative to d2go's "configs/"
directory, e.g., "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"
Returns:
str: the real path to the config file.
"""
cfg_file = pkg_resources.resource_filename(
"d2go.model_zoo", os.path.join("configs", config_path)
)
if not os.path.exists(cfg_file):
raise RuntimeError("{} not available in Model Zoo!".format(config_path))
return cfg_file
def get_config(config_path, trained: bool = False, runner="d2go.runner.GeneralizedRCNNRunner"):
"""
Returns a config object for a model in model zoo.
Args:
config_path (str): config file name relative to d2go's "configs/"
directory, e.g., "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"
trained (bool): If True, will set ``MODEL.WEIGHTS`` to trained model zoo weights.
If False, the checkpoint specified in the config file's ``MODEL.WEIGHTS`` is used
instead; this will typically (though not always) initialize a subset of weights using
an ImageNet pre-trained model, while randomly initializing the other weights.
Returns:
CfgNode: a config object
"""
cfg_file = get_config_file(config_path)
runner = create_runner(runner)
cfg = runner.get_default_cfg()
cfg.merge_from_file(cfg_file)
if trained:
cfg.MODEL.WEIGHTS = get_checkpoint_url(config_path)
return cfg
def get(config_path, trained: bool = False, device: Optional[str] = None, runner="d2go.runner.GeneralizedRCNNRunner"):
"""
Get a model specified by relative path under Detectron2's official ``configs/`` directory.
Args:
config_path (str): config file name relative to d2go's "configs/"
directory, e.g., "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"
trained (bool): see :func:`get_config`.
device (str or None): overwrite the device in config, if given.
Returns:
nn.Module: a d2go model. Will be in training mode.
Example:
::
from d2go import model_zoo
model = model_zoo.get("faster_rcnn_fbnetv3a_C4.yaml", trained=True)
"""
cfg = get_config(config_path, trained)
if device is not None:
cfg.MODEL.DEVICE = device
elif not torch.cuda.is_available():
cfg.MODEL.DEVICE = "cpu"
runner = create_runner(runner)
model = runner.build_model(cfg)
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
return model
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# NOTE: making necessary imports to register with the Registry
from . import backbone # noqa
from . import modeldef # noqa
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from d2go.config import CfgNode as CN
def add_fbnet_default_configs(_C):
""" FBNet options and default values
"""
_C.MODEL.FBNET = CN()
_C.MODEL.FBNET.ARCH = "default"
# custom arch
_C.MODEL.FBNET.ARCH_DEF = ""
_C.MODEL.FBNET.BN_TYPE = "bn"
_C.MODEL.FBNET.NUM_GROUPS = 32 # for gn usage only
_C.MODEL.FBNET.SCALE_FACTOR = 1.0
# the output channels will be divisible by WIDTH_DIVISOR
_C.MODEL.FBNET.WIDTH_DIVISOR = 1
_C.MODEL.FBNET.DW_CONV_SKIP_BN = True
_C.MODEL.FBNET.DW_CONV_SKIP_RELU = True
# > 0 scale, == 0 skip, < 0 same dimension
_C.MODEL.FBNET.DET_HEAD_LAST_SCALE = 1.0
_C.MODEL.FBNET.DET_HEAD_BLOCKS = []
# overwrite the stride for the head, 0 to use original value
_C.MODEL.FBNET.DET_HEAD_STRIDE = 0
# > 0 scale, == 0 skip, < 0 same dimension
_C.MODEL.FBNET.KPTS_HEAD_LAST_SCALE = 0.0
_C.MODEL.FBNET.KPTS_HEAD_BLOCKS = []
# overwrite the stride for the head, 0 to use original value
_C.MODEL.FBNET.KPTS_HEAD_STRIDE = 0
# > 0 scale, == 0 skip, < 0 same dimension
_C.MODEL.FBNET.MASK_HEAD_LAST_SCALE = 0.0
_C.MODEL.FBNET.MASK_HEAD_BLOCKS = []
# overwrite the stride for the head, 0 to use original value
_C.MODEL.FBNET.MASK_HEAD_STRIDE = 0
# 0 to use all blocks defined in arch_def
_C.MODEL.FBNET.RPN_HEAD_BLOCKS = 0
_C.MODEL.FBNET.RPN_BN_TYPE = ""
# number of channels input to trunk
_C.MODEL.FBNET.STEM_IN_CHANNELS = 3
def add_fbnet_v2_default_configs(_C):
_C.MODEL.FBNET_V2 = CN()
_C.MODEL.FBNET_V2.ARCH = "default"
_C.MODEL.FBNET_V2.ARCH_DEF = []
# number of channels input to trunk
_C.MODEL.FBNET_V2.STEM_IN_CHANNELS = 3
_C.MODEL.FBNET_V2.SCALE_FACTOR = 1.0
# the output channels will be divisible by WIDTH_DIVISOR
_C.MODEL.FBNET_V2.WIDTH_DIVISOR = 1
# normalization configs
# name of norm such as "bn", "sync_bn", "gn"
_C.MODEL.FBNET_V2.NORM = "bn"
# for advanced use cases that require extra arguments, pass a list of
# dicts such as [{"num_groups": 8}, {"momentum": 0.1}] (merged in the given order).
# Note that strings written in .yaml will be evaluated by yacs, thus this
# node will become a normal python object.
# https://github.com/rbgirshick/yacs/blob/master/yacs/config.py#L410
_C.MODEL.FBNET_V2.NORM_ARGS = []
_C.MODEL.VT_FPN = CN()
_C.MODEL.VT_FPN.IN_FEATURES = ["res2", "res3", "res4", "res5"]
_C.MODEL.VT_FPN.OUT_CHANNELS = 256
_C.MODEL.VT_FPN.LAYERS = 3
_C.MODEL.VT_FPN.TOKEN_LS = [16, 16, 8, 8]
_C.MODEL.VT_FPN.TOKEN_C = 1024
_C.MODEL.VT_FPN.HEADS = 16
_C.MODEL.VT_FPN.MIN_GROUP_PLANES = 64
_C.MODEL.VT_FPN.NORM = "BN"
_C.MODEL.VT_FPN.POS_HWS = []
_C.MODEL.VT_FPN.POS_N_DOWNSAMPLE = []
def add_bifpn_default_configs(_C):
_C.MODEL.BIFPN = CN()
_C.MODEL.BIFPN.DEPTH_MULTIPLIER = 1
_C.MODEL.BIFPN.SCALE_FACTOR = 1
_C.MODEL.BIFPN.WIDTH_DIVISOR = 8
_C.MODEL.BIFPN.NORM = "bn"
_C.MODEL.BIFPN.NORM_ARGS = []
_C.MODEL.BIFPN.TOP_BLOCK_BEFORE_FPN = False
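# Example sketch: attaching these config nodes to a detectron2 config before
# merging a yaml file; the yaml path is hypothetical.
#
#   from detectron2.config import get_cfg
#
#   cfg = get_cfg()
#   add_fbnet_v2_default_configs(cfg)
#   add_bifpn_default_configs(cfg)
#   cfg.merge_from_file("configs/my_fbnet_model.yaml")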
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import itertools
import logging
from typing import List
import torch
import torch.nn as nn
from detectron2.layers import ShapeSpec
from detectron2.modeling import (
BACKBONE_REGISTRY,
RPN_HEAD_REGISTRY,
Backbone,
build_anchor_generator,
)
from detectron2.modeling.backbone.fpn import FPN, LastLevelMaxPool, LastLevelP6P7
from detectron2.modeling.roi_heads import box_head, keypoint_head, mask_head
from detectron2.utils.logger import log_first_n
from mobile_cv.arch.fbnet_v2 import fbnet_builder as mbuilder
from d2go.modeling.modeldef.fbnet_modeldef_registry import FBNetV2ModelArch
from mobile_cv.arch.utils.helper import format_dict_expanding_list_values
from .modules import (
KeypointRCNNPredictor,
KeypointRCNNPredictorNoUpscale,
KeypointRCNNIRFPredictorNoUpscale,
KeypointRCNNConvUpsamplePredictorNoUpscale,
MaskRCNNConv1x1Predictor,
RPNHeadConvRegressor,
)
logger = logging.getLogger(__name__)
FBNET_BUILDER_IDENTIFIER = "fbnetv2"
def _get_builder_norm_args(cfg):
norm_name = cfg.MODEL.FBNET_V2.NORM
norm_args = {"name": norm_name}
assert all(isinstance(x, dict) for x in cfg.MODEL.FBNET_V2.NORM_ARGS)
for dic in cfg.MODEL.FBNET_V2.NORM_ARGS:
norm_args.update(dic)
return norm_args
def _merge_fbnetv2_arch_def(cfg):
arch_def = {}
assert all(isinstance(x, dict) for x in cfg.MODEL.FBNET_V2.ARCH_DEF)
for dic in cfg.MODEL.FBNET_V2.ARCH_DEF:
arch_def.update(dic)
return arch_def
def _parse_arch_def(cfg):
arch = cfg.MODEL.FBNET_V2.ARCH
arch_def = cfg.MODEL.FBNET_V2.ARCH_DEF
assert (arch != "" and not arch_def) ^ (not arch and arch_def != []), (
"Only allow one unset node between MODEL.FBNET_V2.ARCH ({}) and MODEL.FBNET_V2.ARCH_DEF ({})"
.format(arch, arch_def)
)
arch_def = FBNetV2ModelArch.get(arch) if arch else _merge_fbnetv2_arch_def(cfg)
# NOTE: arch_def is a dictionary describing the CNN architecture for creating
# the detection model. It can describe a wide range of models including the
# original FBNet. Each key-value pair expresses either a sub part of the model
# like trunk or head, or stores other meta information.
message = "Using un-unified arch_def for ARCH \"{}\" (without scaling):\n{}".format(
arch, format_dict_expanding_list_values(arch_def)
)
log_first_n(logging.INFO, message, n=1, key="message")
return arch_def
def _get_fbnet_builder_and_arch_def(cfg):
arch_def = _parse_arch_def(cfg)
# NOTE: one can store extra information in arch_def to configure FBNetBuilder,
# after this point, builder and arch_def will become independent.
basic_args = arch_def.pop("basic_args", {})
builder = mbuilder.FBNetBuilder(
width_ratio=cfg.MODEL.FBNET_V2.SCALE_FACTOR,
width_divisor=cfg.MODEL.FBNET_V2.WIDTH_DIVISOR,
bn_args=_get_builder_norm_args(cfg),
)
builder.add_basic_args(**basic_args)
return builder, arch_def
def _get_stride_per_stage(blocks):
"""
Count the accumulated stride per stage given a list of blocks. The mbuilder
provides an API for counting the per-block accumulated stride; this function
leverages it to count the per-stage accumulated stride.
Input: a list of blocks from the unified arch_def. Note that the stage_idx
must be contiguous (not necessarily starting from 0), and can be
non-ascending (not tested).
Output: a list of accumulated stride per stage, starting from lowest stage_idx.
"""
stride_per_block = mbuilder.count_stride_each_block(blocks)
assert len(stride_per_block) == len(blocks)
stage_idx_set = {s["stage_idx"] for s in blocks}
# assume stage idx are contiguous, eg. 1, 2, 3, ...
assert max(stage_idx_set) - min(stage_idx_set) + 1 == len(stage_idx_set)
start_stage_id = min(stage_idx_set)
ids_per_stage = [
[i for i, s in enumerate(blocks) if s["stage_idx"] == stage_idx]
for stage_idx in range(start_stage_id, start_stage_id + len(stage_idx_set))
] # eg. [[0], [1, 2], [3, 4, 5, 6], ...]
block_stride_per_stage = [
[stride_per_block[i] for i in ids] for ids in ids_per_stage
] # eg. [[1], [2, 1], [2, 1, 1, 1], ...]
stride_per_stage = [
list(itertools.accumulate(s, lambda x, y: x * y))[-1]
for s in block_stride_per_stage
] # eg. [1, 2, 2, ...]
accum_stride_per_stage = list(
itertools.accumulate(stride_per_stage, lambda x, y: x * y)
) # eg. [first*1, first*2, first*4, ...]
assert accum_stride_per_stage[-1] == mbuilder.count_strides(blocks)
return accum_stride_per_stage
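# Worked example: for blocks with stage_idx [0, 1, 1, 2] and per-block strides
# [2, 2, 1, 2], the per-stage strides are [2, 2, 2], so the accumulated strides
# per stage returned above are [2, 4, 8].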
def fbnet_identifier_checker(func):
""" Can be used to decorate _load_from_state_dict """
def wrapper(self, state_dict, prefix, *args, **kwargs):
possible_keys = [k for k in state_dict.keys() if k.startswith(prefix)]
if not all(FBNET_BUILDER_IDENTIFIER in k for k in possible_keys):
logger.warning(
"Couldn't match FBNetV2 pattern given prefix {}, possible keys: \n{}"
.format(prefix, "\n".join(possible_keys))
)
if any("xif" in k for k in possible_keys):
raise RuntimeError(
"Seems a FBNetV1 trained checkpoint is loaded by FBNetV2 model,"
" which is not supported. Please consider re-train your model"
" using the same setup as before (it will be FBNetV2). If you"
" need to run the old FBNetV1 models, those configs can be"
" still found, see D19477651 as example."
)
return func(self, state_dict, prefix, *args, **kwargs)
return wrapper
# pyre-fixme[11]: Annotation `Sequential` is not defined as a type.
class FBNetModule(nn.Sequential):
@fbnet_identifier_checker
def _load_from_state_dict(self, *args, **kwargs):
return super()._load_from_state_dict(*args, **kwargs)
def build_fbnet(cfg, name, in_channels):
"""
Create a FBNet module using FBNet V2 builder.
Args:
cfg (CfgNode): the config that contains MODEL.FBNET_V2.
name (str): the key in arch_def that represents a sub-part of the network.
in_channels (int): input channel size.
Returns:
nn.Sequential: the first return value is an nn.Sequential in which each
element corresponds to a stage in arch_def.
List[ShapeSpec]: the second return value is a list of ShapeSpec describing
the output channels and accumulated stride of each stage.
"""
builder, raw_arch_def = _get_fbnet_builder_and_arch_def(cfg)
# Reset the last_depth for this builder (it might have been cached); this is
# the only mutable member variable.
builder.last_depth = in_channels
# NOTE: Each sub-part of the model consists of several stages and each stage
# has several blocks. "Raw" arch_def (Dict[str, List[List[Tuple]]]) uses a
# list of stages to describe the architecture, which is more compact and
# thus written as builtin metadata (inside FBNetV2ModelArch) or config
# (MODEL.FBNET_V2.ARCH_DEF). "Unified" arch_def (Dict[str, List[Dict]])
# uses a flat list of blocks from all stages instead, which is the format
# the builder consumes.
arch_def = mbuilder.unify_arch_def(raw_arch_def, [name])
arch_def = {name: arch_def[name]}
logger.info(
"Build FBNet using unified arch_def:\n{}"
.format(format_dict_expanding_list_values(arch_def))
)
arch_def_blocks = arch_def[name]
stages = []
trunk_stride_per_stage = _get_stride_per_stage(arch_def_blocks)
shape_spec_per_stage = []
for i, stride_i in enumerate(trunk_stride_per_stage):
stages.append(builder.build_blocks(
arch_def_blocks,
stage_indices=[i],
prefix_name=FBNET_BUILDER_IDENTIFIER + "_",
))
shape_spec_per_stage.append(ShapeSpec(
channels=builder.last_depth,
stride=stride_i,
))
return FBNetModule(*stages), shape_spec_per_stage
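# Usage sketch (comments only; assumes `cfg` is a D2Go config with
# MODEL.FBNET_V2 populated):
#   module, shape_specs = build_fbnet(cfg, name="trunk", in_channels=3)
#   # `module` is an FBNetModule (nn.Sequential) with one child per stage;
#   # shape_specs[i].channels / shape_specs[i].stride describe stage i's output.
# FBNetV2Backbone below consumes the two return values in exactly this way.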
class FBNetV2Backbone(Backbone):
"""
Backbone (bottom-up) for FBNet.
Hierarchy:
trunk0:
xif0_0
xif0_1
...
trunk1:
xif1_0
xif1_1
...
...
Output features:
The outputs from each "stage", i.e. trunkX.
"""
def __init__(self, cfg):
super(FBNetV2Backbone, self).__init__()
stages, shape_specs = build_fbnet(
cfg,
name="trunk",
in_channels=cfg.MODEL.FBNET_V2.STEM_IN_CHANNELS
)
self._trunk_stage_names = []
self._trunk_stages = []
self._out_feature_channels = {}
self._out_feature_strides = {}
for i, (stage, shape_spec) in enumerate(zip(stages, shape_specs)):
name = "trunk{}".format(i)
self.add_module(name, stage)
self._trunk_stage_names.append(name)
self._trunk_stages.append(stage)
self._out_feature_channels[name] = shape_spec.channels
self._out_feature_strides[name] = shape_spec.stride
# returned features are the final output of each stage
self._out_features = self._trunk_stage_names
self._trunk_stage_names = tuple(self._trunk_stage_names)
def __prepare_scriptable__(self):
ret = copy.deepcopy(self)
ret._trunk_stages = nn.ModuleList(ret._trunk_stages)
for k in self._trunk_stage_names:
delattr(ret, k)
return ret
@fbnet_identifier_checker
def _load_from_state_dict(self, *args, **kwargs):
return super()._load_from_state_dict(*args, **kwargs)
# return features for each stage
def forward(self, x):
features = {}
for name, stage in zip(self._trunk_stage_names, self._trunk_stages):
x = stage(x)
features[name] = x
return features
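# Example of the returned feature dict (comments only; shapes are hypothetical):
# for a 4-stage trunk with accumulated strides [2, 4, 8, 16] and an input of
# shape (N, C, 224, 224), forward() returns
#   {"trunk0": (N, c0, 112, 112), ..., "trunk3": (N, c3, 14, 14)},
# where cX is the per-stage channel count recorded in _out_feature_channels.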
class FBNetV2FPN(FPN):
"""
FPN module for FBNet.
"""
pass
def build_fbnet_backbone(cfg):
return FBNetV2Backbone(cfg)
@BACKBONE_REGISTRY.register()
class FBNetV2C4Backbone(Backbone):
def __init__(self, cfg, _):
super(FBNetV2C4Backbone, self).__init__()
self.body = build_fbnet_backbone(cfg)
self._out_features = self.body._out_features
self._out_feature_strides = self.body._out_feature_strides
self._out_feature_channels = self.body._out_feature_channels
def forward(self, x):
return self.body(x)
@BACKBONE_REGISTRY.register()
def FBNetV2FpnBackbone(cfg, _):
backbone = FBNetV2FPN(
bottom_up=build_fbnet_backbone(cfg),
in_features=cfg.MODEL.FPN.IN_FEATURES,
out_channels=cfg.MODEL.FPN.OUT_CHANNELS,
norm=cfg.MODEL.FPN.NORM,
top_block=LastLevelMaxPool(),
)
return backbone
@BACKBONE_REGISTRY.register()
def FBNetV2RetinaNetBackbone(cfg, _):
bottom_up = build_fbnet_backbone(cfg)
in_channels_p6p7 = bottom_up.output_shape()[cfg.MODEL.FPN.IN_FEATURES[-1]].channels
top_block = LastLevelP6P7(in_channels_p6p7, cfg.MODEL.FPN.OUT_CHANNELS)
top_block.in_feature = cfg.MODEL.FPN.IN_FEATURES[-1]
backbone = FBNetV2FPN(
bottom_up=bottom_up,
in_features=cfg.MODEL.FPN.IN_FEATURES,
out_channels=cfg.MODEL.FPN.OUT_CHANNELS,
norm=cfg.MODEL.FPN.NORM,
top_block=top_block,
)
return backbone
@RPN_HEAD_REGISTRY.register()
class FBNetV2RpnHead(nn.Module):
def __init__(self, cfg, input_shape: List[ShapeSpec]):
super(FBNetV2RpnHead, self).__init__()
in_channels = [x.channels for x in input_shape]
assert len(set(in_channels)) == 1
in_channels = in_channels[0]
anchor_generator = build_anchor_generator(cfg, input_shape)
num_cell_anchors = anchor_generator.num_cell_anchors
box_dim = anchor_generator.box_dim
assert len(set(num_cell_anchors)) == 1
num_cell_anchors = num_cell_anchors[0]
self.rpn_feature, shape_specs = build_fbnet(
cfg,
name="rpn",
in_channels=in_channels
)
self.rpn_regressor = RPNHeadConvRegressor(
in_channels=shape_specs[-1].channels,
num_anchors=num_cell_anchors,
box_dim=box_dim,
)
def forward(self, x: List[torch.Tensor]):
x = [self.rpn_feature(y) for y in x]
return self.rpn_regressor(x)
@box_head.ROI_BOX_HEAD_REGISTRY.register()
class FBNetV2RoIBoxHead(nn.Module):
def __init__(self, cfg, input_shape: ShapeSpec):
super(FBNetV2RoIBoxHead, self).__init__()
self.roi_box_conv, shape_specs = build_fbnet(
cfg,
name="bbox",
in_channels=input_shape.channels
)
self._out_channels = shape_specs[-1].channels
self.avgpool = nn.AdaptiveAvgPool2d(1)
def forward(self, x):
x = self.roi_box_conv(x)
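# If the FBNet blocks left any spatial extent (> 1x1), average-pool it away so
# the box predictor always receives a 1x1 feature map.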
if len(x.shape) == 4 and (x.shape[2] > 1 or x.shape[3] > 1):
x = self.avgpool(x)
return x
@property
@torch.jit.unused
def output_shape(self):
return ShapeSpec(channels=self._out_channels)
@keypoint_head.ROI_KEYPOINT_HEAD_REGISTRY.register()
class FBNetV2RoIKeypointHead(keypoint_head.BaseKeypointRCNNHead):
def __init__(self, cfg, input_shape: ShapeSpec):
super(FBNetV2RoIKeypointHead, self).__init__(
cfg=cfg,
input_shape=input_shape,
)
self.feature_extractor, shape_specs = build_fbnet(
cfg,
name="kpts",
in_channels=input_shape.channels
)
self.predictor = KeypointRCNNPredictor(
in_channels=shape_specs[-1].channels,
num_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS,
)
def layers(self, x):
x = self.feature_extractor(x)
x = self.predictor(x)
return x
@keypoint_head.ROI_KEYPOINT_HEAD_REGISTRY.register()
class FBNetV2RoIKeypointHeadKRCNNPredictorNoUpscale(keypoint_head.BaseKeypointRCNNHead):
def __init__(self, cfg, input_shape: ShapeSpec):
super(FBNetV2RoIKeypointHeadKRCNNPredictorNoUpscale, self).__init__(
cfg=cfg,
input_shape=input_shape,
)
self.feature_extractor, shape_specs = build_fbnet(
cfg,
name="kpts",
in_channels=input_shape.channels,
)
self.predictor = KeypointRCNNPredictorNoUpscale(
in_channels=shape_specs[-1].channels,
num_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS,
)
def layers(self, x):
x = self.feature_extractor(x)
x = self.predictor(x)
return x
@keypoint_head.ROI_KEYPOINT_HEAD_REGISTRY.register()
class FBNetV2RoIKeypointHeadKPRCNNIRFPredictorNoUpscale(
keypoint_head.BaseKeypointRCNNHead,
):
def __init__(self, cfg, input_shape: ShapeSpec):
super(FBNetV2RoIKeypointHeadKPRCNNIRFPredictorNoUpscale, self).__init__(
cfg=cfg,
input_shape=input_shape,
)
self.feature_extractor, shape_specs = build_fbnet(
cfg,
name="kpts",
in_channels=input_shape.channels,
)
self.predictor = KeypointRCNNIRFPredictorNoUpscale(
cfg,
in_channels=shape_specs[-1].channels,
num_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS,
)
def layers(self, x):
x = self.feature_extractor(x)
x = self.predictor(x)
return x
@keypoint_head.ROI_KEYPOINT_HEAD_REGISTRY.register()
class FBNetV2RoIKeypointHeadKPRCNNConvUpsamplePredictorNoUpscale(
keypoint_head.BaseKeypointRCNNHead,
):
def __init__(self, cfg, input_shape: ShapeSpec):
super(FBNetV2RoIKeypointHeadKPRCNNConvUpsamplePredictorNoUpscale, self).__init__(
cfg=cfg,
input_shape=input_shape,
)
self.feature_extractor, shape_specs = build_fbnet(
cfg,
name="kpts",
in_channels=input_shape.channels,
)
self.predictor = KeypointRCNNConvUpsamplePredictorNoUpscale(
cfg,
in_channels=shape_specs[-1].channels,
num_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS,
)
def layers(self, x):
x = self.feature_extractor(x)
x = self.predictor(x)
return x
@mask_head.ROI_MASK_HEAD_REGISTRY.register()
class FBNetV2RoIMaskHead(mask_head.BaseMaskRCNNHead):
def __init__(self, cfg, input_shape: ShapeSpec):
super(FBNetV2RoIMaskHead, self).__init__(
cfg=cfg,
input_shape=input_shape,
)
self.feature_extractor, shape_specs = build_fbnet(
cfg,
name="mask",
in_channels=input_shape.channels,
)
num_classes = cfg.MODEL.ROI_HEADS.NUM_CLASSES
self.predictor = MaskRCNNConv1x1Predictor(shape_specs[-1].channels, num_classes)
def layers(self, x):
x = self.feature_extractor(x)
x = self.predictor(x)
return x
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List
import torch
import torch.nn as nn
from detectron2 import layers
from mobile_cv.arch.fbnet_v2.irf_block import IRFBlock
class RPNHeadConvRegressor(nn.Module):
"""
A simple RPN Head for classification and bbox regression
"""
def __init__(self, in_channels, num_anchors, box_dim=4):
"""
Arguments:
in_channels (int): number of channels of the input feature
num_anchors (int): number of anchors to be predicted
box_dim (int): dimension of bbox
"""
super(RPNHeadConvRegressor, self).__init__()
self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
self.bbox_pred = nn.Conv2d(
in_channels, num_anchors * box_dim, kernel_size=1, stride=1
)
for l in [self.cls_logits, self.bbox_pred]:
torch.nn.init.normal_(l.weight, std=0.01)
torch.nn.init.constant_(l.bias, 0)
def forward(self, x: List[torch.Tensor]):
assert isinstance(x, (list, tuple))
logits = [self.cls_logits(y) for y in x]
bbox_reg = [self.bbox_pred(y) for y in x]
return logits, bbox_reg
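# Shape note (comments only): for each feature map y of shape
# (N, in_channels, H, W), cls_logits(y) is (N, num_anchors, H, W) and
# bbox_pred(y) is (N, num_anchors * box_dim, H, W); the returned lists keep
# the per-level ordering of the input list.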
class MaskRCNNConv1x1Predictor(nn.Module):
def __init__(self, in_channels, out_channels):
super(MaskRCNNConv1x1Predictor, self).__init__()
num_classes = out_channels
num_inputs = in_channels
self.mask_fcn_logits = nn.Conv2d(num_inputs, num_classes, 1, 1, 0)
for name, param in self.named_parameters():
if "bias" in name:
nn.init.constant_(param, 0)
elif "weight" in name:
# Caffe2 implementation uses MSRAFill, which in fact
# corresponds to kaiming_normal_ in PyTorch
nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")
def forward(self, x):
return self.mask_fcn_logits(x)
class KeypointRCNNPredictor(nn.Module):
def __init__(self, in_channels, num_keypoints):
super(KeypointRCNNPredictor, self).__init__()
input_features = in_channels
deconv_kernel = 4
self.kps_score_lowres = nn.ConvTranspose2d(
input_features,
num_keypoints,
deconv_kernel,
stride=2,
padding=deconv_kernel // 2 - 1,
)
nn.init.kaiming_normal_(
self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu"
)
nn.init.constant_(self.kps_score_lowres.bias, 0)
self.up_scale = 2
self.out_channels = num_keypoints
def forward(self, x):
x = self.kps_score_lowres(x)
x = layers.interpolate(
x, scale_factor=self.up_scale, mode="bilinear", align_corners=False
)
return x
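# Spatial-size note (comments only): with deconv_kernel=4, stride=2 and
# padding=1, the ConvTranspose2d doubles H and W, and the bilinear interpolate
# doubles them again, so the keypoint heatmaps end up at 4x the input RoI
# feature resolution.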
class KeypointRCNNPredictorNoUpscale(nn.Module):
def __init__(self, in_channels, num_keypoints):
super(KeypointRCNNPredictorNoUpscale, self).__init__()
input_features = in_channels
deconv_kernel = 4
self.kps_score_lowres = nn.ConvTranspose2d(
input_features,
num_keypoints,
deconv_kernel,
stride=2,
padding=deconv_kernel // 2 - 1,
)
nn.init.kaiming_normal_(
self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu"
)
nn.init.constant_(self.kps_score_lowres.bias, 0)
self.out_channels = num_keypoints
def forward(self, x):
x = self.kps_score_lowres(x)
return x
class KeypointRCNNIRFPredictorNoUpscale(nn.Module):
def __init__(self, cfg, in_channels, num_keypoints):
super(KeypointRCNNIRFPredictorNoUpscale, self).__init__()
input_features = in_channels
self.kps_score_lowres = IRFBlock(
input_features,
num_keypoints,
stride=-2,
expansion=3,
bn_args="none",
dw_skip_bnrelu=True,
)
self.out_channels = num_keypoints
def forward(self, x):
x = self.kps_score_lowres(x)
return x
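# Note (assumption about mobile_cv behavior): a negative stride in FBNet V2
# blocks is interpreted as upsampling, so stride=-2 above upsamples the RoI
# feature by 2x rather than downsampling it.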
class KeypointRCNNConvUpsamplePredictorNoUpscale(nn.Module):
def __init__(self, cfg, in_channels, num_keypoints):
super(KeypointRCNNConvUpsamplePredictorNoUpscale, self).__init__()
input_features = in_channels
self.kps_score_lowres = nn.Conv2d(
input_features,
num_keypoints,
kernel_size=3,
stride=1,
padding=1,
)
self.out_channels = num_keypoints
def forward(self, x):
# NOTE: align_corners is only valid for interpolating modes (e.g. bilinear)
# and makes F.interpolate raise for mode="nearest", so it is omitted here.
x = layers.interpolate(x, scale_factor=(2, 2), mode="nearest")
x = self.kps_score_lowres(x)
return x