Commit 9051f71a authored by Kai Zhang, committed by Facebook GitHub Bot

Simplify Lightning task and model creation

Summary:
Given that the ways to create a D2Go runner (https://github.com/facebookresearch/d2go/commit/465cdb842513eb910aa20fcedea1d2edd15dc7b7) and a Lightning task are different, `get_class` was introduced so that application code could do:
```
if is_lightning:  # pseudocode: the class is a Lightning module
    task_cls = get_class(classname)
    task = task_cls(cfg)
else:
    runner = create_runner(classname)
```
It turns out we would need to repeat this in many places: workflows, binaries.
This diff reverts `get_class` and instead makes `create_runner` return the class itself when it is a Lightning module.
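
After this diff, call sites collapse to a single entry point. A minimal sketch of the new pattern (`classname` and `cfg` are placeholders, as above):

```
from d2go.runner import create_runner

out = create_runner(classname)
if isinstance(out, type):
    # Lightning module: create_runner returns the class, since
    # constructing it requires a config.
    task = out(cfg)
else:
    # D2Go runner: create_runner already built an instance.
    runner = out
```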

Reviewed By: newstzpz

Differential Revision: D26676595

fbshipit-source-id: c3ce2016d09fe073af4c2dd9f98eea4e59ca621b
parent 498cd31b
@@ -3,19 +3,24 @@
 import importlib
-from typing import Type
+from typing import Type, Union
 
+from pytorch_lightning import LightningModule
 from .default_runner import BaseRunner, Detectron2GoRunner, GeneralizedRCNNRunner
 
 
-def get_class(class_full_name: str) -> Type:
-    """Imports and returns the task class."""
+def create_runner(
+    class_full_name: str, *args, **kwargs
+) -> Union[BaseRunner, Type[LightningModule]]:
+    """Constructs a runner instance if class is a d2go runner. Returns class
+    type if class is a Lightning module.
+    """
     runner_module_name, runner_class_name = class_full_name.rsplit(".", 1)
     runner_module = importlib.import_module(runner_module_name)
     runner_class = getattr(runner_module, runner_class_name)
-    return runner_class
-
-
-def create_runner(class_full_name: str, *args, **kwargs) -> BaseRunner:
-    """Constructs a runner instance of the given class."""
-    runner_class = get_class(class_full_name)
+    if issubclass(runner_class, LightningModule):
+        # Return runner class for Lightning module since it requires config
+        # to construct
+        return runner_class
     return runner_class(*args, **kwargs)
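
As a quick reference, the two return behaviors look like this in use; a sketch built from classes imported elsewhere in this diff:

```
from d2go.runner import create_runner
from d2go.runner.default_runner import Detectron2GoRunner
from d2go.runner.lightning_task import GeneralizedRCNNTask

# A D2Go runner class is constructed immediately; an instance comes back.
runner = create_runner("d2go.runner.default_runner.Detectron2GoRunner")
assert isinstance(runner, Detectron2GoRunner)

# A Lightning task class comes back as-is, to be instantiated with a config.
task_cls = create_runner("d2go.runner.lightning_task.GeneralizedRCNNTask")
assert task_cls is GeneralizedRCNNTask
```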
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 
+import logging
 import os
 from copy import deepcopy
 from enum import Enum
@@ -9,6 +10,14 @@ from typing import Any, Dict, List, Optional, Tuple
 import pytorch_lightning as pl
 import torch
 from d2go.config import CfgNode
+from d2go.data.datasets import inject_coco_datasets, register_dynamic_datasets
+from d2go.data.utils import (
+    update_cfg_if_using_adhoc_dataset,
+)
+from d2go.export.d2_meta_arch import patch_d2_meta_arch
+from d2go.modeling.model_freezing_utils import (
+    set_requires_grad,
+)
 from d2go.runner.default_runner import (
     Detectron2GoRunner,
     GeneralizedRCNNRunner,
@@ -25,6 +34,9 @@ from pytorch_lightning.utilities import rank_zero_info
 _STATE_DICT_KEY = "state_dict"
 _OLD_STATE_DICT_KEY = "model"
 
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 
 def _is_lightning_checkpoint(checkpoint: Dict[str, Any]) -> bool:
     """ Returns true if we believe this checkpoint to be a Lightning checkpoint. """
@@ -73,8 +85,9 @@ class ModelTag(str, Enum):
 class DefaultTask(pl.LightningModule):
     def __init__(self, cfg: CfgNode):
         super().__init__()
+        self.register(cfg)
         self.cfg = cfg
-        self.model = build_model(cfg)
+        self.model = self._build_model()
         self.storage = None
         # evaluators for validation datasets, split by model tag(default, ema),
         # in the order of DATASETS.TEST
@@ -94,6 +107,64 @@ class DefaultTask(pl.LightningModule):
     def setup(self, stage: str):
         setup_after_launch(self.cfg, self.cfg.OUTPUT_DIR, runner=None)
 
+    def register(self, cfg: CfgNode):
+        inject_coco_datasets(cfg)
+        register_dynamic_datasets(cfg)
+        update_cfg_if_using_adhoc_dataset(cfg)
+        patch_d2_meta_arch()
+
+    def _build_model(self):
+        model = build_model(self.cfg)
+        if self.cfg.MODEL.FROZEN_LAYER_REG_EXP:
+            set_requires_grad(model, self.cfg.MODEL.FROZEN_LAYER_REG_EXP, value=False)
+        return model
+
+    @classmethod
+    def from_config(cls, cfg: CfgNode, eval_only=False):
+        """Builds Lightning module including model from config.
+
+        To load weights from a pretrained checkpoint, please specify checkpoint
+        path in `MODEL.WEIGHTS`.
+
+        Args:
+            cfg: D2go config node.
+            eval_only: True if module should be in eval mode.
+        """
+        if eval_only and not cfg.MODEL.WEIGHTS:
+            logger.warning("MODEL.WEIGHTS is missing for eval only mode.")
+
+        if cfg.MODEL.WEIGHTS:
+            # only load model weights from checkpoint
+            logger.info(f"Load model weights from checkpoint: {cfg.MODEL.WEIGHTS}.")
+            task = cls.load_from_checkpoint(cfg.MODEL.WEIGHTS, cfg=cfg)
+        else:
+            task = cls(cfg)
+
+        if cfg.MODEL_EMA.ENABLED and cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY:
+            assert task.ema_state, "EMA state is not loaded from checkpoint."
+            task.ema_state.apply_to(task.model)
+
+        if eval_only:
+            task.eval()
+        return task
+
+    @classmethod
+    def build_model(cls, cfg: CfgNode, eval_only=False):
+        """Builds D2go model instance from config.
+
+        NOTE: For backward compatibility with existing D2Go tools. Prefer
+        `from_config` in other use cases.
+
+        Args:
+            cfg: D2go config node.
+            eval_only: True if model should be in eval mode.
+        """
+        return cls.from_config(cfg, eval_only).model
+
     @classmethod
     def get_default_cfg(cls):
         return Detectron2GoRunner.get_default_cfg()
@@ -224,6 +295,13 @@ class DefaultTask(pl.LightningModule):
     def forward(self, input):
         return self.model(input)
 
+    @staticmethod
+    def _initialize(cfg: CfgNode):
+        pass
+
+    # ---------------------------------------------------------------------------
+    # Hooks
+    # ---------------------------------------------------------------------------
+
     def on_pretrain_routine_end(self) -> None:
         if self.cfg.MODEL_EMA.ENABLED:
             if self.ema_state and self.ema_state.has_inited():
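
With these additions, `from_config` and `build_model` become the intended construction paths for a task. A hedged sketch of typical usage (the config file path and the `merge_from_file` call are illustrative assumptions, not part of this diff):

```
from d2go.runner.lightning_task import GeneralizedRCNNTask

cfg = GeneralizedRCNNTask.get_default_cfg()
cfg.merge_from_file("path/to/config.yaml")  # hypothetical config file

# Builds the task. If MODEL.WEIGHTS is set, weights come from that
# checkpoint; EMA weights are applied when
# MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY is enabled.
task = GeneralizedRCNNTask.from_config(cfg, eval_only=True)

# Backward-compatible path for tools that only need the inner D2Go model.
model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
```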
@@ -9,6 +9,7 @@ import unittest
 import d2go.runner.default_runner as default_runner
 import torch
+from d2go.runner import create_runner
 from d2go.utils.testing import helper
 from d2go.utils.testing.data_loader_helper import create_local_dataset
 from detectron2.evaluation import COCOEvaluator, RotatedCOCOEvaluator
@@ -160,6 +161,17 @@ class TestDefaultRunner(unittest.TestCase):
         ds_evaluators = runner.get_evaluator(cfg, ds_name, tmp_dir)
         self.assertTrue(isinstance(ds_evaluators._evaluators[0], evaluator))
 
+    def test_create_runner(self):
+        runner = create_runner(
+            ".".join(
+                [
+                    default_runner.Detectron2GoRunner.__module__,
+                    default_runner.Detectron2GoRunner.__name__,
+                ]
+            )
+        )
+        self.assertTrue(isinstance(runner, default_runner.Detectron2GoRunner))
+
     @helper.enable_ddp_env
     def test_d2go_runner_ema(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -3,16 +3,17 @@
 import os
-import tempfile
 import unittest
 from copy import deepcopy
 from typing import Dict
 
 import pytorch_lightning as pl  # type: ignore
 import torch
-from d2go.config import CfgNode
+from d2go.config import CfgNode, temp_defrost
+from d2go.runner import create_runner
 from d2go.runner.lightning_task import GeneralizedRCNNTask
 from d2go.utils.testing import meta_arch_helper as mah
+from d2go.utils.testing.helper import tempdir
 from detectron2.utils.events import EventStorage
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from torch import Tensor
@@ -35,8 +36,8 @@ class TestLightningTask(unittest.TestCase):
             return False
         return True
 
-    def test_load_from_checkpoint(self) -> None:
-        with tempfile.TemporaryDirectory() as tmp_dir:
+    @tempdir
+    def test_load_from_checkpoint(self, tmp_dir) -> None:
         task = GeneralizedRCNNTask(self._get_cfg(tmp_dir))
 
         checkpoint_callback = ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR)
@@ -44,7 +45,7 @@ class TestLightningTask(unittest.TestCase):
             "max_steps": 1,
             "limit_train_batches": 1,
             "num_sanity_val_steps": 0,
-            "callbacks": [checkpoint_callback],
+            "checkpoint_callback": checkpoint_callback,
         }
         trainer = pl.Trainer(**params)
         with EventStorage() as storage:
@@ -62,8 +63,8 @@ class TestLightningTask(unittest.TestCase):
             )
         )
 
-    def test_train_ema(self):
-        with tempfile.TemporaryDirectory() as tmp_dir:
+    @tempdir
+    def test_train_ema(self, tmp_dir):
         cfg = self._get_cfg(tmp_dir)
         cfg.MODEL_EMA.ENABLED = True
         cfg.MODEL_EMA.DECAY = 0.7
@@ -86,8 +87,8 @@ class TestLightningTask(unittest.TestCase):
             self._compare_state_dict(init_state, task.ema_state.state_dict())
         )
 
-    def test_load_ema_weights(self):
-        with tempfile.TemporaryDirectory() as tmp_dir:
+    @tempdir
+    def test_load_ema_weights(self, tmp_dir):
         cfg = self._get_cfg(tmp_dir)
         cfg.MODEL_EMA.ENABLED = True
         task = GeneralizedRCNNTask(cfg)
@@ -123,3 +124,54 @@ class TestLightningTask(unittest.TestCase):
                 task.ema_state.state_dict(), task2.model.state_dict()
             )
         )
+
+    def test_create_runner(self):
+        task_cls = create_runner(
+            f"{GeneralizedRCNNTask.__module__}.{GeneralizedRCNNTask.__name__}"
+        )
+        self.assertTrue(task_cls == GeneralizedRCNNTask)
+
+    @tempdir
+    def test_build_model(self, tmp_dir):
+        cfg = self._get_cfg(tmp_dir)
+        cfg.MODEL_EMA.ENABLED = True
+
+        task = GeneralizedRCNNTask(cfg)
+
+        checkpoint_callback = ModelCheckpoint(
+            dirpath=task.cfg.OUTPUT_DIR, save_last=True
+        )
+        trainer = pl.Trainer(
+            max_steps=1,
+            limit_train_batches=1,
+            num_sanity_val_steps=0,
+            callbacks=[checkpoint_callback],
+        )
+        with EventStorage() as storage:
+            task.storage = storage
+            trainer.fit(task)
+
+        # test building untrained model
+        model = GeneralizedRCNNTask.build_model(cfg)
+        self.assertTrue(model.training)
+
+        # test loading regular weights
+        with temp_defrost(cfg):
+            cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
+            model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
+            self.assertFalse(model.training)
+            self.assertTrue(
+                self._compare_state_dict(model.state_dict(), task.model.state_dict())
+            )
+
+        # test loading EMA weights
+        with temp_defrost(cfg):
+            cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
+            cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = True
+            model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
+            self.assertFalse(model.training)
+            self.assertTrue(
+                self._compare_state_dict(
                    model.state_dict(), task.ema_state.state_dict()
+                )
+            )
@@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Type
 import pytorch_lightning as pl  # type: ignore
 from d2go.config import CfgNode, temp_defrost
-from d2go.runner import get_class
+from d2go.runner import create_runner
 from d2go.runner.callbacks.quantization import QuantizationAwareTraining
 from d2go.runner.lightning_task import GeneralizedRCNNTask
 from d2go.setup import basic_argument_parser
@@ -80,26 +80,6 @@ def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
     return callbacks
 
 
-def build_task(
-    cfg: CfgNode, task_cls: Type[GeneralizedRCNNTask]
-) -> GeneralizedRCNNTask:
-    """Builds instance of Lightning module based on the config and task class
-    name. To build a pre-trained model, specify the `MODEL.WEIGHTS` in the
-    config.
-
-    Args:
-        cfg: The normalized ConfigNode for this D2Go Task.
-        task_cls: Lightning module class name.
-
-    Returns:
-        A instance of the given Lightning module.
-    """
-    if cfg.MODEL.WEIGHTS:
-        # only load model weights from checkpoint
-        logger.info(f"Load model weights from checkpoint: {cfg.MODEL.WEIGHTS}.")
-        return task_cls.load_from_checkpoint(cfg.MODEL.WEIGHTS, cfg=cfg)
-    return task_cls(cfg)
-
-
 def do_train(cfg: CfgNode, trainer: pl.Trainer, task: GeneralizedRCNNTask) -> Dict[str, str]:
     """Runs the training loop with given trainer and task.
@@ -169,7 +149,7 @@ def main(
     maybe_override_output_dir(cfg, output_dir)
 
-    task = build_task(cfg, task_cls)
+    task = task_cls.from_config(cfg, eval_only)
     tb_logger = TensorBoardLogger(save_dir=cfg.OUTPUT_DIR)
 
     trainer_params = {
         # training loop is bounded by max steps, use a large max_epochs to make
@@ -239,7 +219,7 @@ def argument_parser():
 if __name__ == "__main__":
     args = argument_parser().parse_args()
-    task_cls = get_class(args.runner) if args.runner else GeneralizedRCNNTask
+    task_cls = create_runner(args.runner) if args.runner else GeneralizedRCNNTask
     cfg = build_config(args.config_file, task_cls, args.opts)
     ret = main(
         cfg,