Commit 318a3d79 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

restructure lightning related code

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/237

Reviewed By: tglik

Differential Revision: D35954531

fbshipit-source-id: b69c8065928fe385d29f20f2c2460d60d63fca00
parent f3fc01aa
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Trainer APIs on which D2Go's binary can build on top.
"""
# TODO: placeholder for now
def do_train():
pass
def do_test():
pass
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
from typing import Dict
import pytorch_lightning as pl
from d2go.config import CfgNode, temp_defrost
from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.utils.misc import dump_trained_model_configs
from detectron2.utils.events import EventStorage
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
FINAL_MODEL_CKPT = f"model_final{ModelCheckpoint.FILE_EXTENSION}"
def _do_train(
cfg: CfgNode, trainer: pl.Trainer, task: GeneralizedRCNNTask
) -> Dict[str, str]:
"""Runs the training loop with given trainer and task.
Args:
cfg: The normalized ConfigNode for this D2Go Task.
trainer: PyTorch Lightning trainer.
task: Lightning module instance.
Returns:
A map of model name to trained model config path.
"""
with EventStorage() as storage:
task.storage = storage
trainer.fit(task)
final_ckpt = os.path.join(cfg.OUTPUT_DIR, FINAL_MODEL_CKPT)
trainer.save_checkpoint(final_ckpt) # for validation monitor
trained_cfg = cfg.clone()
with temp_defrost(trained_cfg):
trained_cfg.MODEL.WEIGHTS = final_ckpt
model_configs = dump_trained_model_configs(
cfg.OUTPUT_DIR, {"model_final": trained_cfg}
)
return model_configs
def _do_test(trainer: pl.Trainer, task: GeneralizedRCNNTask):
"""Runs the evaluation with a pre-trained model.
Args:
cfg: The normalized ConfigNode for this D2Go Task.
trainer: PyTorch Lightning trainer.
task: Lightning module instance.
"""
with EventStorage() as storage:
task.storage = storage
trainer.test(task)
......@@ -9,13 +9,12 @@ from typing import Any, Dict, List, Optional, Type
import mobile_cv.torch.utils_pytorch.comm as comm
import pytorch_lightning as pl # type: ignore
from d2go.config import CfgNode, temp_defrost
from d2go.config import CfgNode
from d2go.runner import create_runner
from d2go.runner.callbacks.quantization import QuantizationAwareTraining
from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.setup import basic_argument_parser, setup_after_launch
from d2go.utils.misc import dump_trained_model_configs
from detectron2.utils.events import EventStorage
from d2go.trainer.lightning.training_loop import _do_test, _do_train
from detectron2.utils.file_io import PathManager
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, TQDMProgressBar
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
......@@ -90,48 +89,6 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
}
def do_train(
cfg: CfgNode, trainer: pl.Trainer, task: GeneralizedRCNNTask
) -> Dict[str, str]:
"""Runs the training loop with given trainer and task.
Args:
cfg: The normalized ConfigNode for this D2Go Task.
trainer: PyTorch Lightning trainer.
task: Lightning module instance.
Returns:
A map of model name to trained model config path.
"""
with EventStorage() as storage:
task.storage = storage
trainer.fit(task)
final_ckpt = os.path.join(cfg.OUTPUT_DIR, FINAL_MODEL_CKPT)
trainer.save_checkpoint(final_ckpt) # for validation monitor
trained_cfg = cfg.clone()
with temp_defrost(trained_cfg):
trained_cfg.MODEL.WEIGHTS = final_ckpt
model_configs = dump_trained_model_configs(
cfg.OUTPUT_DIR, {"model_final": trained_cfg}
)
return model_configs
def do_test(trainer: pl.Trainer, task: GeneralizedRCNNTask):
"""Runs the evaluation with a pre-trained model.
Args:
cfg: The normalized ConfigNode for this D2Go Task.
trainer: PyTorch Lightning trainer.
task: Lightning module instance.
"""
with EventStorage() as storage:
task.storage = storage
trainer.test(task)
def main(
cfg: CfgNode,
output_dir: str,
......@@ -159,9 +116,9 @@ def main(
trainer = pl.Trainer(**trainer_params)
model_configs = None
if eval_only:
do_test(trainer, task)
_do_test(trainer, task)
else:
model_configs = do_train(cfg, trainer, task)
model_configs = _do_train(cfg, trainer, task)
return TrainOutput(
output_dir=cfg.OUTPUT_DIR,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment