Commit 64c467e2 authored by Ajinkya Deogade, committed by Facebook GitHub Bot

Move distillation config to runner default configs

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/557

To avoid circular dependencies, move `add_distillation_configs`, the function that defines the default config for a runner that uses distillation, from `mobile-vision/d2go/d2go/modeling/distillation.py` to `mobile-vision/d2go/d2go/runner/config_defaults.py`.
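
For illustration, a minimal sketch of the new import path and what the defaults look like once applied (the call site below is illustrative only, and assumes the `CN` alias refers to detectron2's `CfgNode`, as in the modules touched by this diff):

```python
# Illustrative sketch only: add_distillation_configs now lives in
# d2go.runner.config_defaults (it was previously imported from
# d2go.modeling.distillation).
from detectron2.config import CfgNode as CN  # assumed alias, as used in the diff below

from d2go.runner.config_defaults import add_distillation_configs

cfg = CN()
add_distillation_configs(cfg)  # populates the cfg.DISTILLATION.* defaults
assert cfg.DISTILLATION.ALGORITHM == "LabelDistillation"
assert cfg.DISTILLATION.TEACHER.TYPE == "torchscript"
```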

Reviewed By: wat3rBro

Differential Revision: D46096374

fbshipit-source-id: eb85d91b5239e7ab10809a9bf84c869d05d32401
parent f5072d01
@@ -37,25 +37,6 @@ logger = logging.getLogger(__name__)
ModelOutput = Union[None, torch.Tensor, Iterable["ModelOutput"]]
def add_distillation_configs(_C: CN) -> None:
    """Add default parameters to config
    The TEACHER.CONFIG field allows us to build a PyTorch model using an
    existing config. We can build any model that is normally supported by
    D2Go (e.g., FBNet) because we just use the same config
    """
    _C.DISTILLATION = CN()
    _C.DISTILLATION.ALGORITHM = "LabelDistillation"
    _C.DISTILLATION.HELPER = "BaseDistillationHelper"
    _C.DISTILLATION.TEACHER = CN()
    _C.DISTILLATION.TEACHER.TORCHSCRIPT_FNAME = ""
    _C.DISTILLATION.TEACHER.DEVICE = ""
    _C.DISTILLATION.TEACHER.TYPE = "torchscript"
    _C.DISTILLATION.TEACHER.CONFIG_FNAME = ""
    _C.DISTILLATION.TEACHER.RUNNER_NAME = "d2go.runner.GeneralizedRCNNRunner"
    _C.DISTILLATION.TEACHER.OVERWRITE_OPTS = []
@dataclass
class LayerLossMetadata:
    loss: nn.Module
...
@@ -10,7 +10,6 @@ from d2go.data.build import (
from d2go.data.config import add_d2go_data_default_configs
from d2go.modeling import ema, kmeans_anchors
from d2go.modeling.backbone.fbnet_cfg import add_fbnet_v2_default_configs
from d2go.modeling.distillation import add_distillation_configs
from d2go.modeling.meta_arch.fcos import add_fcos_configs
from d2go.modeling.model_freezing_utils import add_model_freezing_configs
from d2go.modeling.subclass import add_subclass_configs
@@ -39,6 +38,25 @@ def _add_base_runner_default_fb_cfg(_C: CN) -> None:
    pass
def add_distillation_configs(_C: CN) -> None:
    """Add default parameters to config
    The TEACHER.CONFIG field allows us to build a PyTorch model using an
    existing config. We can build any model that is normally supported by
    D2Go (e.g., FBNet) because we just use the same config
    """
    _C.DISTILLATION = CN()
    _C.DISTILLATION.ALGORITHM = "LabelDistillation"
    _C.DISTILLATION.HELPER = "BaseDistillationHelper"
    _C.DISTILLATION.TEACHER = CN()
    _C.DISTILLATION.TEACHER.TORCHSCRIPT_FNAME = ""
    _C.DISTILLATION.TEACHER.DEVICE = ""
    _C.DISTILLATION.TEACHER.TYPE = "torchscript"
    _C.DISTILLATION.TEACHER.CONFIG_FNAME = ""
    _C.DISTILLATION.TEACHER.RUNNER_NAME = "d2go.runner.GeneralizedRCNNRunner"
    _C.DISTILLATION.TEACHER.OVERWRITE_OPTS = []
def _add_detectron2go_runner_default_cfg(_C: CN) -> None:
    # _C.MODEL.FBNET_V2...
    add_fbnet_v2_default_configs(_C)
...
@@ -13,7 +13,6 @@ from d2go.modeling import modeling_hook as mh
from d2go.modeling.distillation import (
    _build_teacher,
    _set_device,
    add_distillation_configs,
    BaseDistillationHelper,
    CachedLayer,
    compute_layer_losses,
@@ -38,6 +37,7 @@ from d2go.registry.builtin import (
    DISTILLATION_HELPER_REGISTRY,
    META_ARCH_REGISTRY,
)
from d2go.runner.config_defaults import add_distillation_configs
from d2go.runner.default_runner import BaseRunner
from d2go.utils.testing import helper
from detectron2.checkpoint import DetectionCheckpointer
...
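
As a further hedged illustration of the `DISTILLATION.TEACHER` fields added in the hunk above, the sketch below overrides the defaults for a TorchScript teacher; the file path and device string are placeholders, not values taken from the repo:

```python
from detectron2.config import CfgNode as CN  # assumed alias, as in config_defaults.py

from d2go.runner.config_defaults import add_distillation_configs

cfg = CN()
add_distillation_configs(cfg)

# Point the teacher at a serialized TorchScript model (placeholder path/device).
cfg.merge_from_list(
    [
        "DISTILLATION.TEACHER.TYPE", "torchscript",
        "DISTILLATION.TEACHER.TORCHSCRIPT_FNAME", "/tmp/teacher.jit",
        "DISTILLATION.TEACHER.DEVICE", "cuda",
    ]
)
```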