Commit 6e8e4256 authored by Yanghan Wang, committed by Facebook GitHub Bot

refactor setup for lightning_train_net

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/242

Reviewed By: newstzpz

Differential Revision: D36297282

fbshipit-source-id: 8efb19b3186f6978283f4e17e0628b55c2ec816e
parent e1623106
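At a high level, this diff folds the Lightning-only setup_after_lightning_launch into a single setup_after_launch entry point, making the runner optional and letting callers defer world-size scaling. A hedged before/after sketch using names from this diff (the call sites are illustrative, not the exact source):

# Before: the Lightning task set itself up from its own setup() hook.
setup_after_lightning_launch(cfg, cfg.OUTPUT_DIR)

# After: one binary-level entry point; runner defaults to None, and
# lightning_train_net opts out of the built-in world-size scaling.
setup_after_launch(cfg, output_dir, runner=None, _scale_world_size=False)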
@@ -338,11 +338,6 @@ class DefaultTask(pl.LightningModule):
# ---------------------------------------------------------------------------
# Runner methods
# ---------------------------------------------------------------------------
def setup(self, stage: str):
from d2go.setup import setup_after_lightning_launch
setup_after_lightning_launch(self.cfg, self.cfg.OUTPUT_DIR)
def register(self, cfg: CfgNode):
inject_coco_datasets(cfg)
register_dynamic_datasets(cfg)
@@ -6,23 +6,24 @@ import argparse
import logging
import os
import time
from typing import Optional
import detectron2.utils.comm as comm
import torch
from d2go.config import (
CfgNode as CN,
CfgNode,
auto_scale_world_size,
reroute_config_path,
temp_defrost,
)
from d2go.config.utils import get_diff_cfg
from d2go.distributed import get_local_rank, get_num_processes_per_machine
from d2go.runner import create_runner, GeneralizedRCNNRunner
from d2go.runner import create_runner, BaseRunner
from d2go.utils.helper import run_once
from d2go.utils.launch_environment import get_launch_environment
from detectron2.utils.collect_env import collect_env_info
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import setup_logger
from detectron2.utils.logger import setup_logger as _setup_logger
from detectron2.utils.serialize import PicklableWrapper
from mobile_cv.common.misc.py import FolderLock, MultiprocessingPdb, post_mortem_if_fail
@@ -118,7 +119,7 @@ def create_cfg_from_cli_args(args, default_cfg):
in the args.
"""
_C = CN()
_C = CfgNode()
_C.INPUT = default_cfg.INPUT
_C.DATASETS = default_cfg.DATASETS
_C.DATALOADER = default_cfg.DATALOADER
@@ -129,7 +130,7 @@ def create_cfg_from_cli_args(args, default_cfg):
_C.TENSORBOARD = default_cfg.TENSORBOARD
# NOTE: configs below might not be necessary, but must be added to make the code work
_C.MODEL = CN()
_C.MODEL = CfgNode()
_C.MODEL.META_ARCHITECTURE = default_cfg.MODEL.META_ARCHITECTURE
_C.MODEL.MASK_ON = default_cfg.MODEL.MASK_ON
_C.MODEL.KEYPOINT_ON = default_cfg.MODEL.KEYPOINT_ON
@@ -180,18 +181,7 @@ def prepare_for_launch(args):
return cfg, output_dir, runner
def _setup_after_launch(cfg: CN, output_dir: str, runner):
"""
Set things up after entering DDP, including
- creating working directory
- setting up logger
- logging environment
- initializing runner
"""
create_dir_on_global_main_process(output_dir)
comm.synchronize()
setup_loggers(output_dir)
cfg.freeze()
def maybe_override_output_dir(cfg: CfgNode, output_dir: str):
if cfg.OUTPUT_DIR != output_dir:
with temp_defrost(cfg):
logger.warning(
@@ -200,72 +190,91 @@ def _setup_after_launch(cfg: CN, output_dir: str, runner):
)
)
cfg.OUTPUT_DIR = output_dir
dump_cfg(cfg, os.path.join(output_dir, "config.yaml"))
def setup_after_launch(cfg: CN, output_dir: str, runner):
_setup_after_launch(cfg, output_dir, runner)
def setup_after_launch(
cfg: CfgNode,
output_dir: str,
runner: Optional[BaseRunner] = None,
_scale_world_size: bool = True,  # HACK: temporarily allow lightning_train_net to bypass this.
):
"""
Binary-level setup after entering DDP, including
- creating working directory
- setting up logger
- logging environment
- printing and dumping config
- (optional) initializing runner
"""
create_dir_on_global_main_process(output_dir)
setup_loggers(output_dir)
log_system_info()
cfg.freeze()
maybe_override_output_dir(cfg, output_dir)
logger.info("Running with full config:\n{}".format(cfg))
dump_cfg(cfg, os.path.join(output_dir, "config.yaml"))
if runner:
logger.info("Initializing runner ...")
runner = initialize_runner(runner, cfg)
logger.info("Running with runner: {}".format(runner))
log_info(cfg, runner)
# save the diff config
if runner:
default_cfg = runner.get_default_cfg()
dump_cfg(
get_diff_cfg(runner.get_default_cfg(), cfg),
get_diff_cfg(default_cfg, cfg),
os.path.join(output_dir, "diff_config.yaml"),
)
auto_scale_world_size(cfg, new_world_size=comm.get_world_size())
else:
# TODO: support getting default_cfg without runner.
pass
def setup_after_lightning_launch(cfg: CN, output_dir: str):
_setup_after_launch(cfg, output_dir, runner=None)
log_info(cfg, runner=None)
# scale the config after dumping so that the dumped config files keep the original world size
if _scale_world_size:
auto_scale_world_size(cfg, new_world_size=comm.get_world_size())
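For callers that compute the world size themselves, a hedged sketch of the deferred-scaling path, mirroring the lightning_train_net.main change later in this diff (num_machines and num_processes stand for the caller's own values):

# Skip the built-in scaling based on comm.get_world_size() ...
setup_after_launch(cfg, output_dir, _scale_world_size=False)
# ... and scale explicitly from the trainer's process counts instead.
auto_scale_world_size(cfg, new_world_size=num_machines * num_processes)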
@run_once()
def setup_loggers(output_dir, color=None):
def setup_logger(
module_name: str,
output_dir: str,
abbrev_name: Optional[str] = None,
color: Optional[bool] = None,
) -> logging.Logger:
if color is None:  # only default when unspecified, so an explicit color=False is respected
color = get_launch_environment() == "local"
if not abbrev_name:
abbrev_name = module_name
d2_logger = setup_logger(
output_dir,
distributed_rank=comm.get_rank(),
color=color,
name="detectron2",
abbrev_name="d2",
)
fvcore_logger = setup_logger(
logger = _setup_logger(
output_dir,
distributed_rank=comm.get_rank(),
color=color,
name="fvcore",
)
d2go_logger = setup_logger(
output_dir,
distributed_rank=comm.get_rank(),
color=color,
name="d2go",
abbrev_name="d2go",
)
mobile_cv_logger = setup_logger(
output_dir,
distributed_rank=comm.get_rank(),
color=color,
name="mobile_cv",
abbrev_name="mobile_cv",
name=module_name,
abbrev_name=abbrev_name,
)
# NOTE: all above loggers have FileHandler pointing to the same file as d2_logger.
# Those files are opened upon creation, but it seems fine in 'a' mode.
# NOTE: the root logger might have been configured by other applications;
# since this is already a sub-top-level logger, just don't propagate to root.
d2_logger.propagate = False
fvcore_logger.propagate = False
d2go_logger.propagate = False
mobile_cv_logger.propagate = False
logger.propagate = False
return logger
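The four near-identical logger blocks collapse into this single parameterized helper. A hedged usage sketch reproducing two of the old call sites (the output directory is illustrative):

# detectron2 keeps its short "d2" alias; abbrev_name falls back to module_name otherwise.
d2_logger = setup_logger("detectron2", "/tmp/output", abbrev_name="d2")
fvcore_logger = setup_logger("fvcore", "/tmp/output")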
@run_once()
def setup_loggers(output_dir):
setup_logger("detectron2", output_dir, abbrev_name="d2")
setup_logger("fvcore", output_dir)
setup_logger("d2go", output_dir)
setup_logger("mobile_cv", output_dir)
# NOTE: all of the above loggers have a FileHandler pointing to the same log file.
# Those handlers open the file upon creation, but that seems fine in 'a' mode.
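Because setup_loggers is wrapped in run_once(), repeated calls are harmless; a hedged sketch, assuming run_once makes subsequent invocations no-ops as its name suggests:

setup_loggers("/tmp/output")  # configures the detectron2, fvcore, d2go, and mobile_cv loggers
setup_loggers("/tmp/output")  # no-op: run_once short-circuits repeat calls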
def log_info(cfg, runner):
def log_system_info():
num_processes = get_num_processes_per_machine()
logger.info(
"Using {} processes per machine. Rank of current process: {}".format(
@@ -282,29 +291,24 @@ def log_info(cfg, runner):
print_fbcode_info()
except ImportError:
pass
logger.info("Running with full config:\n{}".format(cfg))
logger.info("Running with runner: {}".format(runner))
def dump_cfg(cfg, path):
def dump_cfg(cfg: CfgNode, path: str) -> None:
if comm.is_main_process():
with PathManager.open(path, "w") as f:
f.write(cfg.dump())
logger.info("Full config saved to {}".format(path))
def create_dir_on_local_main_process(dir):
if get_local_rank() == 0 and dir:
PathManager.mkdirs(dir)
def create_dir_on_global_main_process(dir):
if comm.get_rank() == 0 and dir:
PathManager.mkdirs(dir)
def create_dir_on_global_main_process(path: str) -> None:
if comm.get_rank() == 0 and path:
PathManager.mkdirs(path)
# Add a barrier to make sure the dir exists for non-master processes
comm.synchronize()
def initialize_runner(runner, cfg):
runner = runner or GeneralizedRCNNRunner()
def initialize_runner(runner: BaseRunner, cfg: CfgNode) -> BaseRunner:
assert runner is not None, "a runner instance is now always required"
runner._initialize(cfg)
return runner
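With the GeneralizedRCNNRunner fallback removed, callers must pass a runner explicitly. A hedged sketch using create_runner from the updated imports (the runner class path is illustrative):

# Build a runner from its full class name, then let setup initialize it.
runner = create_runner("d2go.runner.GeneralizedRCNNRunner")
runner = initialize_runner(runner, cfg)  # asserts the runner is not None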
@@ -33,14 +33,14 @@ class TestLightningTrainNet(unittest.TestCase):
cfg = self._get_cfg(root_dir)
# set the distributed backend to none to avoid spawning a child process,
# which doesn't inherit the temporary dataset
main(cfg)
main(cfg, root_dir)
@tempdir
def test_checkpointing(self, tmp_dir):
"""tests saving and loading from checkpoint."""
cfg = self._get_cfg(tmp_dir)
out = main(cfg)
out = main(cfg, tmp_dir)
ckpts = [f for f in os.listdir(tmp_dir) if f.endswith(".ckpt")]
expected_ckpts = ("last.ckpt", FINAL_MODEL_CKPT)
for ckpt in expected_ckpts:
@@ -48,11 +48,11 @@ class TestLightningTrainNet(unittest.TestCase):
cfg2 = cfg.clone()
cfg2.defrost()
cfg2.OUTPUT_DIR = os.path.join(tmp_dir, "output")
# load the last checkpoint from previous training
cfg2.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
out2 = main(cfg2, eval_only=True)
output_dir = os.path.join(tmp_dir, "output")
out2 = main(cfg2, output_dir, eval_only=True)
accuracy = flatten_config_dict(out.accuracy)
accuracy2 = flatten_config_dict(out2.accuracy)
for k in accuracy:
@@ -14,7 +14,7 @@ from d2go.runner.callbacks.quantization import (
QuantizationAwareTraining,
)
from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.setup import basic_argument_parser
from d2go.setup import basic_argument_parser, setup_after_launch
from d2go.utils.misc import dump_trained_model_configs
from detectron2.utils.events import EventStorage
from detectron2.utils.file_io import PathManager
@@ -39,16 +39,6 @@ class TrainOutput:
model_configs: Optional[Dict[str, str]] = None
def maybe_override_output_dir(cfg: CfgNode, output_dir: Optional[str]) -> None:
"""Overrides the output directory if `output_dir` is not None."""
if output_dir is not None and output_dir != cfg.OUTPUT_DIR:
cfg.OUTPUT_DIR = output_dir
logger.warning(
f"Override cfg.OUTPUT_DIR ({cfg.OUTPUT_DIR}) to be the same as "
f"output_dir {output_dir}"
)
def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
"""Gets the trainer callbacks based on the given D2Go Config.
@@ -147,7 +137,7 @@ def do_test(trainer: pl.Trainer, task: GeneralizedRCNNTask):
def main(
cfg: CfgNode,
output_dir: Optional[str] = None,
output_dir: str,
task_cls: Type[GeneralizedRCNNTask] = GeneralizedRCNNTask,
eval_only: bool = False,
num_machines: int = 1,
@@ -160,8 +150,9 @@ def main(
num_processes: Number of processes on each node.
eval_only: True if run evaluation only.
"""
auto_scale_world_size(cfg, num_machines * num_processes)
maybe_override_output_dir(cfg, output_dir)
# FIXME: make comm.get_world_size() work properly.
setup_after_launch(cfg, output_dir, _scale_world_size=False)
auto_scale_world_size(cfg, new_world_size=num_machines * num_processes)
task = task_cls.from_config(cfg, eval_only)
trainer_params = get_trainer_params(cfg, num_machines, num_processes)
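Since output_dir is now a required positional argument of main, a hedged sketch of the updated call pattern, modeled on the test changes above (paths are illustrative):

# Training run: setup_after_launch now runs inside main() before the task is built.
out = main(cfg, "/tmp/exp")

# Eval-only run against the previous checkpoint, as in test_checkpointing.
cfg2.MODEL.WEIGHTS = os.path.join("/tmp/exp", "last.ckpt")
out2 = main(cfg2, "/tmp/exp/output", eval_only=True)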