"src/vscode:/vscode.git/clone" did not exist on "9b2c0a7dbe5487e700a0039a09c277d73a17ccc2"
Commit 64c467e2 authored by Ajinkya Deogade's avatar Ajinkya Deogade Committed by Facebook GitHub Bot
Browse files

Move distillation config to runner default configs

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/557

To avoid circular dependencies, move the function `add_distillation_configs` that defines the default config for a `runner` making use of distillation from `mobile-vision/d2go/d2go/modeling/distillation.py` to `mobile-vision/d2go/d2go/runner/config_defaults.py`.

Reviewed By: wat3rBro

Differential Revision: D46096374

fbshipit-source-id: eb85d91b5239e7ab10809a9bf84c869d05d32401
parent f5072d01
......@@ -37,25 +37,6 @@ logger = logging.getLogger(__name__)
ModelOutput = Union[None, torch.Tensor, Iterable["ModelOutput"]]
def add_distillation_configs(_C: CN) -> None:
"""Add default parameters to config
The TEACHER.CONFIG field allows us to build a PyTorch model using an
existing config. We can build any model that is normally supported by
D2Go (e.g., FBNet) because we just use the same config
"""
_C.DISTILLATION = CN()
_C.DISTILLATION.ALGORITHM = "LabelDistillation"
_C.DISTILLATION.HELPER = "BaseDistillationHelper"
_C.DISTILLATION.TEACHER = CN()
_C.DISTILLATION.TEACHER.TORCHSCRIPT_FNAME = ""
_C.DISTILLATION.TEACHER.DEVICE = ""
_C.DISTILLATION.TEACHER.TYPE = "torchscript"
_C.DISTILLATION.TEACHER.CONFIG_FNAME = ""
_C.DISTILLATION.TEACHER.RUNNER_NAME = "d2go.runner.GeneralizedRCNNRunner"
_C.DISTILLATION.TEACHER.OVERWRITE_OPTS = []
@dataclass
class LayerLossMetadata:
loss: nn.Module
......
......@@ -10,7 +10,6 @@ from d2go.data.build import (
from d2go.data.config import add_d2go_data_default_configs
from d2go.modeling import ema, kmeans_anchors
from d2go.modeling.backbone.fbnet_cfg import add_fbnet_v2_default_configs
from d2go.modeling.distillation import add_distillation_configs
from d2go.modeling.meta_arch.fcos import add_fcos_configs
from d2go.modeling.model_freezing_utils import add_model_freezing_configs
from d2go.modeling.subclass import add_subclass_configs
......@@ -39,6 +38,25 @@ def _add_base_runner_default_fb_cfg(_C: CN) -> None:
pass
def add_distillation_configs(_C: CN) -> None:
"""Add default parameters to config
The TEACHER.CONFIG field allows us to build a PyTorch model using an
existing config. We can build any model that is normally supported by
D2Go (e.g., FBNet) because we just use the same config
"""
_C.DISTILLATION = CN()
_C.DISTILLATION.ALGORITHM = "LabelDistillation"
_C.DISTILLATION.HELPER = "BaseDistillationHelper"
_C.DISTILLATION.TEACHER = CN()
_C.DISTILLATION.TEACHER.TORCHSCRIPT_FNAME = ""
_C.DISTILLATION.TEACHER.DEVICE = ""
_C.DISTILLATION.TEACHER.TYPE = "torchscript"
_C.DISTILLATION.TEACHER.CONFIG_FNAME = ""
_C.DISTILLATION.TEACHER.RUNNER_NAME = "d2go.runner.GeneralizedRCNNRunner"
_C.DISTILLATION.TEACHER.OVERWRITE_OPTS = []
def _add_detectron2go_runner_default_cfg(_C: CN) -> None:
# _C.MODEL.FBNET_V2...
add_fbnet_v2_default_configs(_C)
......
......@@ -13,7 +13,6 @@ from d2go.modeling import modeling_hook as mh
from d2go.modeling.distillation import (
_build_teacher,
_set_device,
add_distillation_configs,
BaseDistillationHelper,
CachedLayer,
compute_layer_losses,
......@@ -38,6 +37,7 @@ from d2go.registry.builtin import (
DISTILLATION_HELPER_REGISTRY,
META_ARCH_REGISTRY,
)
from d2go.runner.config_defaults import add_distillation_configs
from d2go.runner.default_runner import BaseRunner
from d2go.utils.testing import helper
from detectron2.checkpoint import DetectionCheckpointer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment