Commit 9051f71a authored by Kai Zhang, committed by Facebook GitHub Bot

Simplify Lightning task and model creation

Summary:
Given that the ways to create a D2Go runner (https://github.com/facebookresearch/d2go/commit/465cdb842513eb910aa20fcedea1d2edd15dc7b7) and a Lightning task are different, `get_class` was introduced so that in applications we could do:
```
if is_lightning:
    task_cls = get_class(classname)
    task = task_cls(cfg)
else:
    runner = create_runner(classname)
```
It turns out we would need to do this in many places: workflows, binaries.
This diff reverts `get_class` and instead has `create_runner` return the class itself when the class is a Lightning module.
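Call sites then collapse to a single entry point. A minimal sketch of the resulting usage, where `classname` and `cfg` stand in for the application's actual values:
```
obj = create_runner(classname)
if isinstance(obj, type):  # Lightning module: the class is returned
    task = obj(cfg)        # construct the task with the config
else:                      # d2go runner: an instance is returned
    runner = obj
```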

Reviewed By: newstzpz

Differential Revision: D26676595

fbshipit-source-id: c3ce2016d09fe073af4c2dd9f98eea4e59ca621b
parent 498cd31b
@@ -3,19 +3,24 @@
 import importlib
-from typing import Type
+from typing import Type, Union
+
+from pytorch_lightning import LightningModule
 
 from .default_runner import BaseRunner, Detectron2GoRunner, GeneralizedRCNNRunner
 
 
-def get_class(class_full_name: str) -> Type:
-    """Imports and returns the task class."""
+def create_runner(
+    class_full_name: str, *args, **kwargs
+) -> Union[BaseRunner, Type[LightningModule]]:
+    """Constructs a runner instance if the class is a d2go runner. Returns the
+    class type if the class is a Lightning module.
+    """
     runner_module_name, runner_class_name = class_full_name.rsplit(".", 1)
     runner_module = importlib.import_module(runner_module_name)
     runner_class = getattr(runner_module, runner_class_name)
-    return runner_class
-
-
-def create_runner(class_full_name: str, *args, **kwargs) -> BaseRunner:
-    """Constructs a runner instance of the given class."""
-    runner_class = get_class(class_full_name)
+    if issubclass(runner_class, LightningModule):
+        # Return runner class for Lightning module since it requires config
+        # to construct
+        return runner_class
    return runner_class(*args, **kwargs)
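A usage sketch of the new behavior (the dotted paths are illustrative, though both classes are referenced elsewhere in this diff):
```
from d2go.runner import create_runner

# d2go runner: an instance is constructed and returned
runner = create_runner("d2go.runner.GeneralizedRCNNRunner")

# Lightning task: the class itself is returned; construct it with a config
task_cls = create_runner("d2go.runner.lightning_task.GeneralizedRCNNTask")
task = task_cls(task_cls.get_default_cfg())
```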
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 
+import logging
 import os
 from copy import deepcopy
 from enum import Enum
@@ -9,6 +10,14 @@ from typing import Any, Dict, List, Optional, Tuple
 import pytorch_lightning as pl
 import torch
 from d2go.config import CfgNode
+from d2go.data.datasets import inject_coco_datasets, register_dynamic_datasets
+from d2go.data.utils import (
+    update_cfg_if_using_adhoc_dataset,
+)
+from d2go.export.d2_meta_arch import patch_d2_meta_arch
+from d2go.modeling.model_freezing_utils import (
+    set_requires_grad,
+)
 from d2go.runner.default_runner import (
     Detectron2GoRunner,
     GeneralizedRCNNRunner,
@@ -25,6 +34,9 @@ from pytorch_lightning.utilities import rank_zero_info
 _STATE_DICT_KEY = "state_dict"
 _OLD_STATE_DICT_KEY = "model"
 
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 
 def _is_lightning_checkpoint(checkpoint: Dict[str, Any]) -> bool:
     """ Returns true if we believe this checkpoint to be a Lightning checkpoint. """
@@ -73,8 +85,9 @@ class ModelTag(str, Enum):
 class DefaultTask(pl.LightningModule):
     def __init__(self, cfg: CfgNode):
         super().__init__()
+        self.register(cfg)
         self.cfg = cfg
-        self.model = build_model(cfg)
+        self.model = self._build_model()
         self.storage = None
         # evaluators for validation datasets, split by model tag(default, ema),
         # in the order of DATASETS.TEST
@@ -94,6 +107,64 @@ class DefaultTask(pl.LightningModule):
     def setup(self, stage: str):
         setup_after_launch(self.cfg, self.cfg.OUTPUT_DIR, runner=None)
 
+    def register(self, cfg: CfgNode):
+        inject_coco_datasets(cfg)
+        register_dynamic_datasets(cfg)
+        update_cfg_if_using_adhoc_dataset(cfg)
+        patch_d2_meta_arch()
+
+    def _build_model(self):
+        model = build_model(self.cfg)
+        if self.cfg.MODEL.FROZEN_LAYER_REG_EXP:
+            set_requires_grad(model, self.cfg.MODEL.FROZEN_LAYER_REG_EXP, value=False)
+        return model
+
+    @classmethod
+    def from_config(cls, cfg: CfgNode, eval_only=False):
+        """Builds the Lightning module, including the model, from config.
+
+        To load weights from a pretrained checkpoint, specify the checkpoint
+        path in `MODEL.WEIGHTS`.
+
+        Args:
+            cfg: D2go config node.
+            eval_only: True if module should be in eval mode.
+        """
+        if eval_only and not cfg.MODEL.WEIGHTS:
+            logger.warning("MODEL.WEIGHTS is missing for eval only mode.")
+
+        if cfg.MODEL.WEIGHTS:
+            # only load model weights from checkpoint
+            logger.info(f"Load model weights from checkpoint: {cfg.MODEL.WEIGHTS}.")
+            task = cls.load_from_checkpoint(cfg.MODEL.WEIGHTS, cfg=cfg)
+        else:
+            task = cls(cfg)
+
+        if cfg.MODEL_EMA.ENABLED and cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY:
+            assert task.ema_state, "EMA state is not loaded from checkpoint."
+            task.ema_state.apply_to(task.model)
+
+        if eval_only:
+            task.eval()
+        return task
+
+    @classmethod
+    def build_model(cls, cfg: CfgNode, eval_only=False):
+        """Builds a D2Go model instance from config.
+
+        NOTE: kept for backward compatibility with existing D2Go tools.
+        Prefer `from_config` in other use cases.
+
+        Args:
+            cfg: D2go config node.
+            eval_only: True if model should be in eval mode.
+        """
+        return cls.from_config(cfg, eval_only).model
+
     @classmethod
     def get_default_cfg(cls):
         return Detectron2GoRunner.get_default_cfg()
@@ -224,6 +295,13 @@ class DefaultTask(pl.LightningModule):
     def forward(self, input):
         return self.model(input)
 
+    @staticmethod
+    def _initialize(cfg: CfgNode):
+        pass
+
+    # ---------------------------------------------------------------------------
+    # Hooks
+    # ---------------------------------------------------------------------------
     def on_pretrain_routine_end(self) -> None:
         if self.cfg.MODEL_EMA.ENABLED:
             if self.ema_state and self.ema_state.has_inited():
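To illustrate the intended call pattern of the newly added `from_config`, here is a sketch assuming a concrete task such as `GeneralizedRCNNTask` (the checkpoint path is a hypothetical placeholder):
```
cfg = GeneralizedRCNNTask.get_default_cfg()
cfg.MODEL.WEIGHTS = "/path/to/checkpoint.ckpt"  # hypothetical checkpoint path

# build the task in eval mode, loading weights from the Lightning checkpoint
task = GeneralizedRCNNTask.from_config(cfg, eval_only=True)

# or, for tools that only need the underlying D2Go model:
model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
```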
......
@@ -9,6 +9,7 @@ import unittest
 import d2go.runner.default_runner as default_runner
 import torch
+from d2go.runner import create_runner
 from d2go.utils.testing import helper
 from d2go.utils.testing.data_loader_helper import create_local_dataset
 from detectron2.evaluation import COCOEvaluator, RotatedCOCOEvaluator
@@ -160,6 +161,17 @@ class TestDefaultRunner(unittest.TestCase):
             ds_evaluators = runner.get_evaluator(cfg, ds_name, tmp_dir)
             self.assertTrue(isinstance(ds_evaluators._evaluators[0], evaluator))
 
+    def test_create_runner(self):
+        runner = create_runner(
+            ".".join(
+                [
+                    default_runner.Detectron2GoRunner.__module__,
+                    default_runner.Detectron2GoRunner.__name__,
+                ]
+            )
+        )
+        self.assertTrue(isinstance(runner, default_runner.Detectron2GoRunner))
+
     @helper.enable_ddp_env
     def test_d2go_runner_ema(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
......
@@ -3,16 +3,17 @@
 import os
 import tempfile
 import unittest
 from copy import deepcopy
 from typing import Dict
 
 import pytorch_lightning as pl  # type: ignore
 import torch
-from d2go.config import CfgNode
+from d2go.config import CfgNode, temp_defrost
 from d2go.runner import create_runner
 from d2go.runner.lightning_task import GeneralizedRCNNTask
 from d2go.utils.testing import meta_arch_helper as mah
+from d2go.utils.testing.helper import tempdir
 from detectron2.utils.events import EventStorage
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from torch import Tensor
@@ -35,91 +36,142 @@ class TestLightningTask(unittest.TestCase):
             return False
         return True
 
-    def test_load_from_checkpoint(self) -> None:
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            task = GeneralizedRCNNTask(self._get_cfg(tmp_dir))
-            checkpoint_callback = ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR)
-            params = {
-                "max_steps": 1,
-                "limit_train_batches": 1,
-                "num_sanity_val_steps": 0,
-                "callbacks": [checkpoint_callback],
-            }
-            trainer = pl.Trainer(**params)
-            with EventStorage() as storage:
-                task.storage = storage
-                trainer.fit(task)
-                ckpt_path = os.path.join(tmp_dir, "test.ckpt")
-                trainer.save_checkpoint(ckpt_path)
-                self.assertTrue(os.path.exists(ckpt_path))
-                # load model weights from checkpoint
-                task2 = GeneralizedRCNNTask.load_from_checkpoint(ckpt_path)
-                self.assertTrue(
-                    self._compare_state_dict(
-                        task.model.state_dict(), task2.model.state_dict()
-                    )
-                )
+    @tempdir
+    def test_load_from_checkpoint(self, tmp_dir) -> None:
+        task = GeneralizedRCNNTask(self._get_cfg(tmp_dir))
+        checkpoint_callback = ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR)
+        params = {
+            "max_steps": 1,
+            "limit_train_batches": 1,
+            "num_sanity_val_steps": 0,
+            "checkpoint_callback": checkpoint_callback,
+        }
+        trainer = pl.Trainer(**params)
+        with EventStorage() as storage:
+            task.storage = storage
+            trainer.fit(task)
+            ckpt_path = os.path.join(tmp_dir, "test.ckpt")
+            trainer.save_checkpoint(ckpt_path)
+            self.assertTrue(os.path.exists(ckpt_path))
+            # load model weights from checkpoint
+            task2 = GeneralizedRCNNTask.load_from_checkpoint(ckpt_path)
+            self.assertTrue(
+                self._compare_state_dict(
+                    task.model.state_dict(), task2.model.state_dict()
+                )
+            )
 
-    def test_train_ema(self):
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            cfg = self._get_cfg(tmp_dir)
-            cfg.MODEL_EMA.ENABLED = True
-            cfg.MODEL_EMA.DECAY = 0.7
-            task = GeneralizedRCNNTask(cfg)
-            init_state = deepcopy(task.model.state_dict())
-            trainer = pl.Trainer(
-                max_steps=1,
-                limit_train_batches=1,
-                num_sanity_val_steps=0,
-            )
-            with EventStorage() as storage:
-                task.storage = storage
-                trainer.fit(task)
-            for k, v in task.model.state_dict().items():
-                init_state[k].copy_(init_state[k] * 0.7 + 0.3 * v)
-            self.assertTrue(
-                self._compare_state_dict(init_state, task.ema_state.state_dict())
-            )
+    @tempdir
+    def test_train_ema(self, tmp_dir):
+        cfg = self._get_cfg(tmp_dir)
+        cfg.MODEL_EMA.ENABLED = True
+        cfg.MODEL_EMA.DECAY = 0.7
+        task = GeneralizedRCNNTask(cfg)
+        init_state = deepcopy(task.model.state_dict())
+        trainer = pl.Trainer(
+            max_steps=1,
+            limit_train_batches=1,
+            num_sanity_val_steps=0,
+        )
+        with EventStorage() as storage:
+            task.storage = storage
+            trainer.fit(task)
+        for k, v in task.model.state_dict().items():
+            init_state[k].copy_(init_state[k] * 0.7 + 0.3 * v)
+        self.assertTrue(
+            self._compare_state_dict(init_state, task.ema_state.state_dict())
+        )
 
-    def test_load_ema_weights(self):
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            cfg = self._get_cfg(tmp_dir)
-            cfg.MODEL_EMA.ENABLED = True
-            task = GeneralizedRCNNTask(cfg)
-            checkpoint_callback = ModelCheckpoint(
-                dirpath=task.cfg.OUTPUT_DIR, save_last=True
-            )
-            trainer = pl.Trainer(
-                max_steps=1,
-                limit_train_batches=1,
-                num_sanity_val_steps=0,
-                callbacks=[checkpoint_callback],
-            )
-            with EventStorage() as storage:
-                task.storage = storage
-                trainer.fit(task)
-            # load EMA weights from checkpoint
-            task2 = GeneralizedRCNNTask.load_from_checkpoint(
-                os.path.join(tmp_dir, "last.ckpt")
-            )
-            self.assertTrue(
-                self._compare_state_dict(
-                    task.ema_state.state_dict(), task2.ema_state.state_dict()
-                )
-            )
-            # apply EMA weights to model
-            task2.ema_state.apply_to(task2.model)
-            self.assertTrue(
-                self._compare_state_dict(
-                    task.ema_state.state_dict(), task2.model.state_dict()
-                )
-            )
+    @tempdir
+    def test_load_ema_weights(self, tmp_dir):
+        cfg = self._get_cfg(tmp_dir)
+        cfg.MODEL_EMA.ENABLED = True
+        task = GeneralizedRCNNTask(cfg)
+        checkpoint_callback = ModelCheckpoint(
+            dirpath=task.cfg.OUTPUT_DIR, save_last=True
+        )
+        trainer = pl.Trainer(
+            max_steps=1,
+            limit_train_batches=1,
+            num_sanity_val_steps=0,
+            callbacks=[checkpoint_callback],
+        )
+        with EventStorage() as storage:
+            task.storage = storage
+            trainer.fit(task)
+        # load EMA weights from checkpoint
+        task2 = GeneralizedRCNNTask.load_from_checkpoint(
+            os.path.join(tmp_dir, "last.ckpt")
+        )
+        self.assertTrue(
+            self._compare_state_dict(
+                task.ema_state.state_dict(), task2.ema_state.state_dict()
+            )
+        )
+        # apply EMA weights to model
+        task2.ema_state.apply_to(task2.model)
+        self.assertTrue(
+            self._compare_state_dict(
+                task.ema_state.state_dict(), task2.model.state_dict()
+            )
+        )
+
+    def test_create_runner(self):
+        task_cls = create_runner(
+            f"{GeneralizedRCNNTask.__module__}.{GeneralizedRCNNTask.__name__}"
+        )
+        self.assertTrue(task_cls == GeneralizedRCNNTask)
+
+    @tempdir
+    def test_build_model(self, tmp_dir):
+        cfg = self._get_cfg(tmp_dir)
+        cfg.MODEL_EMA.ENABLED = True
+        task = GeneralizedRCNNTask(cfg)
+        checkpoint_callback = ModelCheckpoint(
+            dirpath=task.cfg.OUTPUT_DIR, save_last=True
+        )
+        trainer = pl.Trainer(
+            max_steps=1,
+            limit_train_batches=1,
+            num_sanity_val_steps=0,
+            callbacks=[checkpoint_callback],
+        )
+        with EventStorage() as storage:
+            task.storage = storage
+            trainer.fit(task)
+
+        # test building untrained model
+        model = GeneralizedRCNNTask.build_model(cfg)
+        self.assertTrue(model.training)
+
+        # test loading regular weights
+        with temp_defrost(cfg):
+            cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
+        model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
+        self.assertFalse(model.training)
+        self.assertTrue(
+            self._compare_state_dict(model.state_dict(), task.model.state_dict())
+        )
+
+        # test loading EMA weights
+        with temp_defrost(cfg):
+            cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
+            cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = True
+        model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
+        self.assertFalse(model.training)
+        self.assertTrue(
+            self._compare_state_dict(
+                model.state_dict(), task.ema_state.state_dict()
+            )
+        )
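As a quick check of the arithmetic that `test_train_ema` asserts: with decay d = 0.7 and a single optimizer step, the EMA state should equal d * init + (1 - d) * param, which is exactly what the loop above computes. A standalone, illustrative-only sketch:
```
# one EMA update with decay d: ema_new = d * ema_old + (1 - d) * param
d = 0.7
ema_old, param = 2.0, 1.0
ema_new = d * ema_old + (1 - d) * param  # 0.7 * 2.0 + 0.3 * 1.0 = 1.7
assert abs(ema_new - 1.7) < 1e-9
```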
@@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Type
 import pytorch_lightning as pl  # type: ignore
 from d2go.config import CfgNode, temp_defrost
-from d2go.runner import get_class
+from d2go.runner import create_runner
 from d2go.runner.callbacks.quantization import QuantizationAwareTraining
 from d2go.runner.lightning_task import GeneralizedRCNNTask
 from d2go.setup import basic_argument_parser
@@ -80,26 +80,6 @@ def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
     return callbacks
 
 
-def build_task(
-    cfg: CfgNode, task_cls: Type[GeneralizedRCNNTask]
-) -> GeneralizedRCNNTask:
-    """Builds instance of Lightning module based on the config and task class
-    name. To build a pre-trained model, specify the `MODEL.WEIGHTS` in the
-    config.
-
-    Args:
-        cfg: The normalized ConfigNode for this D2Go Task.
-        task_cls: Lightning module class name.
-
-    Returns:
-        A instance of the given Lightning module.
-    """
-    if cfg.MODEL.WEIGHTS:
-        # only load model weights from checkpoint
-        logger.info(f"Load model weights from checkpoint: {cfg.MODEL.WEIGHTS}.")
-        return task_cls.load_from_checkpoint(cfg.MODEL.WEIGHTS, cfg=cfg)
-    return task_cls(cfg)
-
-
 def do_train(cfg: CfgNode, trainer: pl.Trainer, task: GeneralizedRCNNTask) -> Dict[str, str]:
     """Runs the training loop with given trainer and task.
@@ -169,7 +149,7 @@ def main(
     maybe_override_output_dir(cfg, output_dir)
 
-    task = build_task(cfg, task_cls)
+    task = task_cls.from_config(cfg, eval_only)
 
     tb_logger = TensorBoardLogger(save_dir=cfg.OUTPUT_DIR)
     trainer_params = {
         # training loop is bounded by max steps, use a large max_epochs to make
@@ -239,7 +219,7 @@ def argument_parser():
 if __name__ == "__main__":
     args = argument_parser().parse_args()
-    task_cls = get_class(args.runner) if args.runner else GeneralizedRCNNTask
+    task_cls = create_runner(args.runner) if args.runner else GeneralizedRCNNTask
     cfg = build_config(args.config_file, task_cls, args.opts)
     ret = main(
         cfg,
......