Commit 403a5321 authored by Miquel Jubert Hermoso's avatar Miquel Jubert Hermoso Committed by Facebook GitHub Bot
Browse files

Refactor registry and setup backend to separate file

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/233

*This diff is part of a stack which has the goal of "buckifying" D2 (https://github.com/facebookresearch/d2go/commit/87374efb134e539090e0b5c476809dc35bf6aedb)Go core and enabling autodeps and other tooling. The last diff in the stack introduces the TARGETS. The diffs earlier in the stack are resolving circular dependencies and other issues which prevent the buckification from occurring.*

Again, a circular dependency between quantization/modeling -> meta_arch->rcnn. Solved in a similar way, they are caused by functions which only interact with the global registry. Separate the cause of the circular dependency into its own file, which will later have a separate TARGET.

Reviewed By: wat3rBro

Differential Revision: D35966052

fbshipit-source-id: c118a2af989e6c387641fe055a6734f9f0ab1db5
parent 8d58b499
......@@ -40,6 +40,7 @@ class GeneralizedRCNNPatch:
"prepare_for_export",
"prepare_for_quant",
"prepare_for_quant_convert",
"_cast_model_to_device",
]
def prepare_for_export(self, cfg, *args, **kwargs):
......@@ -56,6 +57,9 @@ class GeneralizedRCNNPatch:
)
return func(self, cfg, *args, **kwargs)
def _cast_model_to_device(self, device):
return _cast_detection_model(self, device)
@RCNN_PREPARE_FOR_EXPORT_REGISTRY.register()
def default_rcnn_prepare_for_export(self, cfg, inputs, predictor_type):
......
......@@ -139,11 +139,17 @@ def add_quantization_default_configs(_C):
# TODO: model.to(device) might not work for detection meta-arch, this function is the
# workaround, in general, we might need a meta-arch API for this if needed.
def _cast_model_to_device(model, device):
from d2go.modeling.meta_arch.rcnn import _cast_detection_model
from detectron2.modeling import GeneralizedRCNN
assert isinstance(model, GeneralizedRCNN), "Currently only availabe for RCNN"
return _cast_detection_model(model, device)
if hasattr(
model, "_cast_model_to_device"
): # we can make this formal by removing "_"
return model._cast_model_to_device(device)
else:
logger.warning(
"model.to(device) doesn't guarentee moving the entire model, "
"if customization is needed, please implement _cast_model_to_device "
"for the MetaArch"
)
return model.to(device)
def add_d2_quant_mapping(mappings):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment