Commit 70f157c5 authored by Kai Zhang's avatar Kai Zhang Committed by Facebook GitHub Bot
Browse files

organize Lightning task functions

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/51

As titled. Group the D2 (https://github.com/facebookresearch/d2go/commit/788cf41206fcd761da1974747f0c2c6c671ce871)go runner methods together.

Reviewed By: zhanghang1989, wat3rBro

Differential Revision: D27777726

fbshipit-source-id: f300bce444a401b61ff2adfb45b0c640b1f14855
parent 5600f01e
......@@ -112,15 +112,6 @@ class DefaultTask(pl.LightningModule):
self.model_ema = deepcopy(self.model)
self.dataset_evaluators[ModelTag.EMA] = []
def setup(self, stage: str):
setup_after_launch(self.cfg, self.cfg.OUTPUT_DIR, runner=None)
def register(self, cfg: CfgNode):
inject_coco_datasets(cfg)
register_dynamic_datasets(cfg)
update_cfg_if_using_adhoc_dataset(cfg)
patch_d2_meta_arch()
def _build_model(self):
model = build_model(self.cfg)
......@@ -157,30 +148,6 @@ class DefaultTask(pl.LightningModule):
task.eval()
return task
@classmethod
def build_model(cls, cfg: CfgNode, eval_only=False):
"""Builds D2go model instance from config. If model has been prepared
for quantization, the function returns the prepared model.
NOTE: For backward compatible with existing D2Go tools. Prefer
`from_config` in other use cases.
Args:
cfg: D2go config node.
eval_only: True if model should be in eval mode.
"""
task = cls.from_config(cfg, eval_only)
if hasattr(task, PREPARED):
task = getattr(task, PREPARED)
return task.model
@classmethod
def get_default_cfg(cls):
return Detectron2GoRunner.get_default_cfg()
@classmethod
def get_default_cfg(cls):
return Detectron2GoRunner.get_default_cfg()
def training_step(self, batch, batch_idx):
loss_dict = self.forward(batch)
losses = sum(loss_dict.values())
......@@ -240,10 +207,6 @@ class DefaultTask(pl.LightningModule):
return [optim], [{"scheduler": lr_scheduler, "interval": "step"}]
@staticmethod
def build_detection_train_loader(cfg, *args, mapper=None, **kwargs):
return Detectron2GoRunner.build_detection_train_loader(cfg, *args, **kwargs)
def train_dataloader(self):
return self.build_detection_train_loader(self.cfg)
......@@ -289,16 +252,6 @@ class DefaultTask(pl.LightningModule):
dataset_evaluators.append(evaluator)
# TODO: add visualization evaluator
@staticmethod
def get_evaluator(cfg: CfgNode, dataset_name: str, output_folder: str):
return Detectron2GoRunner.get_evaluator(
cfg=cfg, dataset_name=dataset_name, output_folder=output_folder
)
@staticmethod
def build_detection_test_loader(cfg, dataset_name, mapper=None):
return Detectron2GoRunner.build_detection_test_loader(cfg, dataset_name, mapper)
def _evaluation_dataloader(self):
# TODO: Support subsample n images
assert len(self.cfg.DATASETS.TEST)
......@@ -319,10 +272,55 @@ class DefaultTask(pl.LightningModule):
def forward(self, input):
return self.model(input)
# ---------------------------------------------------------------------------
# Runner methods
# ---------------------------------------------------------------------------
def setup(self, stage: str):
setup_after_launch(self.cfg, self.cfg.OUTPUT_DIR, runner=None)
def register(self, cfg: CfgNode):
inject_coco_datasets(cfg)
register_dynamic_datasets(cfg)
update_cfg_if_using_adhoc_dataset(cfg)
patch_d2_meta_arch()
@classmethod
def build_model(cls, cfg: CfgNode, eval_only=False):
"""Builds D2go model instance from config.
NOTE: For backward compatible with existing D2Go tools. Prefer
`from_config` in other use cases.
Args:
cfg: D2go config node.
eval_only: True if model should be in eval mode.
"""
task = cls.from_config(cfg, eval_only)
if hasattr(task, PREPARED):
task = getattr(task, PREPARED)
return task.model
@classmethod
def get_default_cfg(cls):
return Detectron2GoRunner.get_default_cfg()
@staticmethod
def _initialize(cfg: CfgNode):
pass
@staticmethod
def get_evaluator(cfg: CfgNode, dataset_name: str, output_folder: str):
return Detectron2GoRunner.get_evaluator(
cfg=cfg, dataset_name=dataset_name, output_folder=output_folder
)
@staticmethod
def build_detection_train_loader(cfg, *args, mapper=None, **kwargs):
return Detectron2GoRunner.build_detection_train_loader(cfg, *args, **kwargs)
@staticmethod
def build_detection_test_loader(cfg, dataset_name, mapper=None):
return Detectron2GoRunner.build_detection_test_loader(cfg, dataset_name, mapper)
# ---------------------------------------------------------------------------
# Hooks
# ---------------------------------------------------------------------------
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment