"git@developer.sourcefind.cn:change/sglang.git" did not exist on "d4bf5a8524820d3a232f7fc8e349d5e7d0d2880d"
Commit 857195d8 authored by Kai Zhang, committed by Facebook GitHub Bot

Split lightning_train_net into OSS and internal

Summary:
As titled. The OSS version only uses PyTorch Lightning, while the internal version additionally leverages internal features (e.g. Manifold integration, every_n_step checkpointing).
This diff splits train_net.main into smaller functions so that they can be shared across the OSS and internal versions.

Reviewed By: zhanghang1989

Differential Revision: D26752701

fbshipit-source-id: 7f68e2a81e78193e117517a0ff668ab14b76ea65
parent 5d8068d8
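To make the intent of the split concrete, here is a minimal sketch of how an internal entry point could compose the shared helpers introduced in this diff (build_task, do_train, do_test) while layering on internal-only extras. The wrapper name internal_main and the trainer arguments are hypothetical, not part of this commit:

import pytorch_lightning as pl

from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.tools.lightning_train_net import build_task, do_test, do_train


def internal_main(cfg, eval_only: bool = False):
    # hypothetical internal wrapper; only build_task/do_train/do_test
    # come from this diff
    task = build_task(cfg, GeneralizedRCNNTask)
    # the internal version would register Manifold-aware loggers and
    # every_n_step checkpoint callbacks here; OSS keeps the defaults
    trainer = pl.Trainer(max_epochs=1, callbacks=[])
    if eval_only:
        do_test(trainer, task)
        return None
    return do_train(cfg, trainer, task)  # {model name: trained config path}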
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import tempfile
import unittest

import numpy as np
from d2go.config import CfgNode
from d2go.config.utils import flatten_config_dict
from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.tests import meta_arch_helper as mah
from d2go.tests.helper import tempdir
from d2go.tools.lightning_train_net import main, FINAL_MODEL_CKPT


class TestLightningTrainNet(unittest.TestCase):
    def _get_cfg(self, tmp_dir) -> CfgNode:
        return mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir)

    @tempdir
    def test_train_net_main(self, root_dir):
        """Tests the main training entry point."""
        cfg = self._get_cfg(root_dir)
        # set the distributed backend to None to avoid spawning child
        # processes, which don't inherit the temporary dataset
        main(cfg, accelerator=None)

    @tempdir
    def test_checkpointing(self, tmp_dir):
        """Tests saving and loading from a checkpoint."""
        cfg = self._get_cfg(tmp_dir)
        out = main(cfg, accelerator=None)

        ckpts = [f for f in os.listdir(tmp_dir) if f.endswith(".ckpt")]
        self.assertCountEqual(["last.ckpt", FINAL_MODEL_CKPT], ckpts)

        with tempfile.TemporaryDirectory() as tmp_dir2:
            cfg2 = cfg.clone()
            cfg2.defrost()
            cfg2.OUTPUT_DIR = tmp_dir2
            # load the last checkpoint from the previous training run
            cfg2.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
            out2 = main(cfg2, accelerator=None, eval_only=True)

            accuracy = flatten_config_dict(out.accuracy)
            accuracy2 = flatten_config_dict(out2.accuracy)
            for k in accuracy:
                np.testing.assert_equal(accuracy[k], accuracy2[k])
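The tests lean on the @tempdir helper from d2go.tests.helper to hand each test a scratch directory. Its actual implementation isn't shown in this commit; below is a minimal sketch of what such a decorator typically looks like (the signature and cleanup behavior here are assumptions, not the real helper):

import functools
import shutil
import tempfile


def tempdir(func):
    # hypothetical stand-in for d2go.tests.helper.tempdir
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        tmp = tempfile.mkdtemp()  # scratch directory passed to the test
        try:
            return func(self, tmp, *args, **kwargs)
        finally:
            shutil.rmtree(tmp, ignore_errors=True)  # always clean up
    return wrapper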
@@ -10,19 +10,16 @@ from typing import Any, Dict, List, Optional, Type
 import pytorch_lightning as pl  # type: ignore
 from d2go.config import CfgNode, temp_defrost
 from d2go.runner import get_class
+from d2go.runner.callbacks.quantization import QuantizationAwareTraining
 from d2go.runner.lightning_task import GeneralizedRCNNTask
 from d2go.setup import basic_argument_parser
 from d2go.utils.misc import dump_trained_model_configs
 from detectron2.utils.events import EventStorage
+from detectron2.utils.file_io import PathManager
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.callbacks import LearningRateMonitor
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
-from stl.lightning.callbacks.model_checkpoint import ModelCheckpoint
-from stl.lightning.callbacks.quantization import QuantizationAwareTraining
-from stl.lightning.io.filesystem import get_filesystem
-from stl.lightning.loggers import ManifoldTensorBoardLogger
-from stl.lightning.utilities.manifold import manifold_uri_to_bucket_and_path
 from torch.distributed import get_rank

@@ -40,15 +37,8 @@ class TrainOutput:
     model_configs: Optional[Dict[str, str]] = None


-def get_tb_logger(output_dir: str) -> TensorBoardLogger:
-    """Stores tensorboard outputs in output_dir."""
-    if output_dir.startswith("manifold://"):
-        bucket, path = manifold_uri_to_bucket_and_path(output_dir)
-        return ManifoldTensorBoardLogger(manifold_bucket=bucket, manifold_path=path)
-    return TensorBoardLogger(save_dir=output_dir)


 def maybe_override_output_dir(cfg: CfgNode, output_dir: Optional[str]) -> None:
+    """Overrides the output directory if `output_dir` is not None."""
     if output_dir is not None and output_dir != cfg.OUTPUT_DIR:
         cfg.OUTPUT_DIR = output_dir
         logger.warning(

@@ -69,12 +59,7 @@ def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
     callbacks: List[Callback] = [
         LearningRateMonitor(logging_interval="step"),
         ModelCheckpoint(
-            directory=cfg.OUTPUT_DIR,
-            has_user_data=False,
-            save_top_k=-1,
-            every_n_epochs=-1,
-            every_n_steps=cfg.SOLVER.CHECKPOINT_PERIOD,
-            file_name_template="{step}",
+            dirpath=cfg.OUTPUT_DIR,
             save_last=True,
         ),
     ]

@@ -95,6 +80,67 @@ def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
     return callbacks


+def build_task(
+    cfg: CfgNode, task_cls: Type[GeneralizedRCNNTask]
+) -> GeneralizedRCNNTask:
+    """Builds an instance of the Lightning module based on the config and
+    task class. To build a pre-trained model, specify `MODEL.WEIGHTS` in
+    the config.
+
+    Args:
+        cfg: The normalized ConfigNode for this D2Go Task.
+        task_cls: Lightning module class.
+
+    Returns:
+        An instance of the given Lightning module.
+    """
+    if cfg.MODEL.WEIGHTS:
+        # only load model weights from the checkpoint
+        logger.info(f"Load model weights from checkpoint: {cfg.MODEL.WEIGHTS}.")
+        return task_cls.load_from_checkpoint(cfg.MODEL.WEIGHTS, cfg=cfg)
+    return task_cls(cfg)
+
+
+def do_train(cfg: CfgNode, trainer: pl.Trainer, task: GeneralizedRCNNTask) -> Dict[str, str]:
+    """Runs the training loop with the given trainer and task.
+
+    Args:
+        cfg: The normalized ConfigNode for this D2Go Task.
+        trainer: PyTorch Lightning trainer.
+        task: Lightning module instance.
+
+    Returns:
+        A map from model name to trained model config path.
+    """
+    with EventStorage() as storage:
+        task.storage = storage
+        trainer.fit(task)
+        final_ckpt = os.path.join(cfg.OUTPUT_DIR, FINAL_MODEL_CKPT)
+        trainer.save_checkpoint(final_ckpt)  # for validation monitor
+
+        trained_cfg = cfg.clone()
+        with temp_defrost(trained_cfg):
+            trained_cfg.MODEL.WEIGHTS = final_ckpt
+        model_configs = dump_trained_model_configs(
+            cfg.OUTPUT_DIR, {"model_final": trained_cfg}
+        )
+    return model_configs
+
+
+def do_test(trainer: pl.Trainer, task: GeneralizedRCNNTask):
+    """Runs evaluation with a pre-trained model.
+
+    Args:
+        trainer: PyTorch Lightning trainer.
+        task: Lightning module instance.
+    """
+    with EventStorage() as storage:
+        task.storage = storage
+        trainer.test(task)
+
+
 def main(
     cfg: CfgNode,
     output_dir: Optional[str] = None,

@@ -123,14 +169,8 @@ def main(
     maybe_override_output_dir(cfg, output_dir)

-    if cfg.MODEL.WEIGHTS:
-        # only load model weights from checkpoint
-        task = task_cls.load_from_checkpoint(cfg.MODEL.WEIGHTS, cfg=cfg)
-        logger.info(f"Load model weights from checkpoint: {cfg.MODEL.WEIGHTS}.")
-    else:
-        task = task_cls(cfg)
-
-    tb_logger = get_tb_logger(cfg.OUTPUT_DIR)
+    task = build_task(cfg, task_cls)
+    tb_logger = TensorBoardLogger(save_dir=cfg.OUTPUT_DIR)

     trainer_params = {
         # training loop is bounded by max steps, use a large max_epochs to make
         # sure max_steps is met first

@@ -150,43 +190,23 @@ def main(
     }

     last_checkpoint = os.path.join(cfg.OUTPUT_DIR, "last.ckpt")
-    if get_filesystem(cfg.OUTPUT_DIR).exists(last_checkpoint):
+    if PathManager.exists(last_checkpoint):
         # resume training from checkpoint
         trainer_params["resume_from_checkpoint"] = last_checkpoint
         logger.info(f"Resuming training from checkpoint: {last_checkpoint}.")

-    # pyre-fixme[16]: Module `pl` has no attribute `Trainer`.
     trainer = pl.Trainer(**trainer_params)

-    # TODO: find a better place for event storage
-    with EventStorage() as storage:
-        task.storage = storage
-        model_configs = None
-        if eval_only:
-            logger.info(
-                f"start to evaluate with {num_machines} nodes and {num_gpus} GPUs"
-            )
-            trainer.test(task)
-        else:
-            logger.info(f"start to train with {num_machines} nodes and {num_gpus} GPUs")
-            trainer.fit(task)
-            final_ckpt = os.path.join(cfg.OUTPUT_DIR, FINAL_MODEL_CKPT)
-            trainer.save_checkpoint(final_ckpt)  # for validation monitor
-            trained_cfg = cfg.clone()
-            with temp_defrost(trained_cfg):
-                trained_cfg.MODEL.WEIGHTS = final_ckpt
-            model_configs = dump_trained_model_configs(cfg.OUTPUT_DIR, {"model_final": trained_cfg})
-    tb_log_dir = (
-        tb_logger.output_dir
-        if isinstance(tb_logger, ManifoldTensorBoardLogger)
-        else tb_logger.log_dir
-    )
+    model_configs = None
+    if eval_only:
+        do_test(trainer, task)
+    else:
+        model_configs = do_train(cfg, trainer, task)

     return TrainOutput(
         output_dir=cfg.OUTPUT_DIR,
-        tensorboard_log_dir=tb_log_dir,
+        tensorboard_log_dir=tb_logger.log_dir,
         accuracy=task.eval_res,
-        model_configs=model_configs
+        model_configs=model_configs,
     )
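Note that the refactor leaves TrainOutput's shape unchanged, so callers of main can keep consuming the result as before. A minimal sketch (the cfg argument is assumed to be a D2Go CfgNode prepared elsewhere; the "model_final" key follows the dump_trained_model_configs call in do_train):

from d2go.tools.lightning_train_net import main


def report(cfg) -> None:
    # train (or resume) and print where the artifacts were written
    out = main(cfg)
    print(out.output_dir)           # checkpoints and dumped configs
    print(out.tensorboard_log_dir)  # now always TensorBoardLogger.log_dir
    if out.model_configs is not None:
        print(out.model_configs["model_final"])  # final trained model config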