Commit dc5ed4f5 authored by Peizhao Zhang's avatar Peizhao Zhang Committed by Facebook GitHub Bot
Browse files

added basic modeling hook.

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/205

Added basic modeling hook.

Reviewed By: tglik

Differential Revision: D35535213

fbshipit-source-id: 662b08a905dd45f09737ca9c2d275b0324bcc134
parent fd317950
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from abc import abstractmethod
import torch
class ModelingHook(object):
"""Modeling hooks provide a way to modify the model during the model building
process. It is simple but allows users to modify the model by creating wrapper,
override member functions, adding additional components, and loss etc.. It
could be used to implement features such as QAT, model transformation for training,
distillation/semi-supervised learning, and customization for loading pre-trained
weights.
"""
def __init__(self, cfg):
self.cfg = cfg
@abstractmethod
def apply(self, model: torch.nn.Module) -> torch.nn.Module:
"""This function will called during the model building process to modify
the behavior of the input model.
The created model will be
model == create meta arch -> model_hook_1.apply(model) ->
model_hook_2.apply(model) -> ...
"""
pass
@abstractmethod
def unapply(self, model: torch.nn.Module) -> torch.nn.Module:
"""This function will be called when the users called model.get_exportable_model()
after training. The main use case of the function is to remove the changes
applied to the model in `apply`. The hooks will be called in reverse order
as follow:
model.get_exportable_model() == model_hook_N.unapply(model) ->
model_hook_N-1.unapply(model) -> ... -> model_hook_1.unapply(model)
"""
pass
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import unittest
import torch
from d2go.modeling.meta_arch import modeling_hook as mh
class TestArch(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x * 2
# create a wrapper of the model that add 1 to the output
class Wrapper(torch.nn.Module):
def __init__(self, model: TestArch):
super().__init__()
self.model = model
def forward(self, x):
return self.model(x) + 1
class PlusOneHook(mh.ModelingHook):
def __init__(self, cfg):
super().__init__(cfg)
def apply(self, model: torch.nn.Module) -> torch.nn.Module:
return Wrapper(model)
def unapply(self, model: torch.nn.Module) -> torch.nn.Module:
assert isinstance(model, Wrapper)
return model.model
class TestModelingHook(unittest.TestCase):
def test_modeling_hook_simple(self):
model = TestArch()
hook = PlusOneHook(None)
model_with_hook = hook.apply(model)
self.assertEqual(model_with_hook(2), 5)
original_model = hook.unapply(model_with_hook)
self.assertEqual(model, original_model)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment