Commit 1f313907 authored by Mircea Cimpoi's avatar Mircea Cimpoi Committed by Facebook GitHub Bot
Browse files

modeling_hook_registry move

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/324

Fixes bug introduced in D37600026 (https://github.com/facebookresearch/d2go/commit/1a18ba3420e3c823accf731a11c0a91dc3babd85) -- forgot to fix imports after moving modelinghook to registry/builtin.py

Differential Revision: D37646330

fbshipit-source-id: cb763d65e7bbfd07eea6eff61727a42a6fcfbc88
parent 1a18ba34
...@@ -23,6 +23,7 @@ from d2go.modeling import modeling_hook as mh ...@@ -23,6 +23,7 @@ from d2go.modeling import modeling_hook as mh
from d2go.registry.builtin import ( from d2go.registry.builtin import (
DISTILLATION_ALGORITHM_REGISTRY, DISTILLATION_ALGORITHM_REGISTRY,
DISTILLATION_HELPER_REGISTRY, DISTILLATION_HELPER_REGISTRY,
MODELING_HOOK_REGISTRY,
) )
from detectron2.utils.file_io import PathManager from detectron2.utils.file_io import PathManager
from mobile_cv.common.misc.mixin import dynamic_mixin, remove_dynamic_mixin from mobile_cv.common.misc.mixin import dynamic_mixin, remove_dynamic_mixin
...@@ -225,7 +226,7 @@ class LabelDistillation(BaseDistillationAlgorithm): ...@@ -225,7 +226,7 @@ class LabelDistillation(BaseDistillationAlgorithm):
return super().forward(new_batched_inputs) return super().forward(new_batched_inputs)
@mh.MODELING_HOOK_REGISTRY.register() @MODELING_HOOK_REGISTRY.register()
class DistillationModelingHook(mh.ModelingHook): class DistillationModelingHook(mh.ModelingHook):
"""Wrapper hook that allows us to apply different distillation algorithms """Wrapper hook that allows us to apply different distillation algorithms
based on config based on config
......
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
from d2go.config import CfgNode from d2go.config import CfgNode
from d2go.modeling import modeling_hook as mh from d2go.modeling import modeling_hook as mh
from d2go.modeling.api import build_d2go_model from d2go.modeling.api import build_d2go_model
from d2go.registry.builtin import META_ARCH_REGISTRY from d2go.registry.builtin import META_ARCH_REGISTRY, MODELING_HOOK_REGISTRY
@META_ARCH_REGISTRY.register() @META_ARCH_REGISTRY.register()
...@@ -32,7 +32,7 @@ class PlusOneWrapper(torch.nn.Module): ...@@ -32,7 +32,7 @@ class PlusOneWrapper(torch.nn.Module):
return self.model(x) + 1 return self.model(x) + 1
@mh.MODELING_HOOK_REGISTRY.register() @MODELING_HOOK_REGISTRY.register()
class PlusOneHook(mh.ModelingHook): class PlusOneHook(mh.ModelingHook):
def __init__(self, cfg): def __init__(self, cfg):
super().__init__(cfg) super().__init__(cfg)
...@@ -55,7 +55,7 @@ class TimesTwoWrapper(torch.nn.Module): ...@@ -55,7 +55,7 @@ class TimesTwoWrapper(torch.nn.Module):
return self.model(x) * 2 return self.model(x) * 2
@mh.MODELING_HOOK_REGISTRY.register() @MODELING_HOOK_REGISTRY.register()
class TimesTwoHook(mh.ModelingHook): class TimesTwoHook(mh.ModelingHook):
def __init__(self, cfg): def __init__(self, cfg):
super().__init__(cfg) super().__init__(cfg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment