Commit c74e23b0 authored by Miquel Jubert Hermoso's avatar Miquel Jubert Hermoso Committed by Facebook GitHub Bot
Browse files

Fix type signature of create_runner

Summary: The type signature of create_runner is not accurate. We expect lightning runners to follow DefaultTask. Also change setup.py to not import directly, which was causing circular dependencies together with the change.

Reviewed By: wat3rBro

Differential Revision: D32792069

fbshipit-source-id: 0fbb55eb269dd681dbc8df49d71c9635f56293b8
parent 9c877fd4
......@@ -5,14 +5,13 @@
import importlib
from typing import Type, Union, Optional
from pytorch_lightning import LightningModule
from .default_runner import (
BaseRunner,
Detectron2GoRunner,
GeneralizedRCNNRunner,
TRAINER_HOOKS_REGISTRY,
)
from .lightning_task import DefaultTask
__all__ = [
......@@ -26,7 +25,7 @@ __all__ = [
def create_runner(
class_full_name: Optional[str], *args, **kwargs
) -> Union[BaseRunner, Type[LightningModule]]:
) -> Union[BaseRunner, Type[DefaultTask]]:
"""Constructs a runner instance if class is a d2go runner. Returns class
type if class is a Lightning module.
"""
......@@ -36,7 +35,7 @@ def create_runner(
runner_module_name, runner_class_name = class_full_name.rsplit(".", 1)
runner_module = importlib.import_module(runner_module_name)
runner_class = getattr(runner_module, runner_class_name)
if issubclass(runner_class, LightningModule):
if issubclass(runner_class, DefaultTask):
# Return runner class for Lightning module since it requires config
# to construct
return runner_class
......
......@@ -32,7 +32,6 @@ from d2go.runner.default_runner import (
GeneralizedRCNNRunner,
_get_tbx_writer,
)
from d2go.setup import setup_after_lightning_launch
from d2go.utils.ema_state import EMAState
from d2go.utils.misc import get_tensorboard_log_dir
from d2go.utils.visualization import VisualizationEvaluator
......@@ -340,6 +339,8 @@ class DefaultTask(pl.LightningModule):
# Runner methods
# ---------------------------------------------------------------------------
def setup(self, stage: str):
from d2go.setup import setup_after_lightning_launch
setup_after_lightning_launch(self.cfg, self.cfg.OUTPUT_DIR)
def register(self, cfg: CfgNode):
......
......@@ -16,7 +16,7 @@ from d2go.config import (
temp_defrost,
)
from d2go.distributed import get_local_rank, get_num_processes_per_machine
from d2go.runner import GeneralizedRCNNRunner, create_runner
from d2go.runner import create_runner, GeneralizedRCNNRunner
from d2go.utils.helper import run_once
from d2go.utils.launch_environment import get_launch_environment
from detectron2.utils.collect_env import collect_env_info
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment