Commit 6e8e4256 authored by Yanghan Wang, committed by Facebook GitHub Bot

refactor setup for lightning_train_net

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/242

Reviewed By: newstzpz

Differential Revision: D36297282

fbshipit-source-id: 8efb19b3186f6978283f4e17e0628b55c2ec816e
parent e1623106
@@ -338,11 +338,6 @@ class DefaultTask(pl.LightningModule):
     # ---------------------------------------------------------------------------
     # Runner methods
     # ---------------------------------------------------------------------------
-    def setup(self, stage: str):
-        from d2go.setup import setup_after_lightning_launch
-
-        setup_after_lightning_launch(self.cfg, self.cfg.OUTPUT_DIR)
-
     def register(self, cfg: CfgNode):
         inject_coco_datasets(cfg)
         register_dynamic_datasets(cfg)
...
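Context for the hunk above (not part of the commit): with the Lightning `setup()` hook removed, the binary that launches training is expected to run the binary-level setup itself before constructing the task, using the refactored `setup_after_launch` shown in the next file. A minimal hedged sketch of that call order; `build_task` is a hypothetical helper and the config handling is illustrative only:

```python
# Hedged sketch, not code from this commit: `build_task` is a hypothetical
# helper and `cfg` is assumed to be an already-loaded d2go CfgNode.
from d2go.config import CfgNode
from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.setup import setup_after_launch


def build_task(cfg: CfgNode) -> GeneralizedRCNNTask:
    # Directory creation, logger setup, and config dumping now run once here,
    # instead of inside the removed LightningModule.setup() hook.
    setup_after_launch(cfg, cfg.OUTPUT_DIR, runner=None)
    # Positional False is eval_only, matching how main() calls from_config.
    return GeneralizedRCNNTask.from_config(cfg, False)
```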
@@ -6,23 +6,24 @@ import argparse
 import logging
 import os
 import time
+from typing import Optional
 
 import detectron2.utils.comm as comm
 import torch
 from d2go.config import (
-    CfgNode as CN,
+    CfgNode,
     auto_scale_world_size,
     reroute_config_path,
     temp_defrost,
 )
 from d2go.config.utils import get_diff_cfg
 from d2go.distributed import get_local_rank, get_num_processes_per_machine
-from d2go.runner import create_runner, GeneralizedRCNNRunner
+from d2go.runner import create_runner, BaseRunner
 from d2go.utils.helper import run_once
 from d2go.utils.launch_environment import get_launch_environment
 from detectron2.utils.collect_env import collect_env_info
 from detectron2.utils.file_io import PathManager
-from detectron2.utils.logger import setup_logger
+from detectron2.utils.logger import setup_logger as _setup_logger
 from detectron2.utils.serialize import PicklableWrapper
 from mobile_cv.common.misc.py import FolderLock, MultiprocessingPdb, post_mortem_if_fail
@@ -118,7 +119,7 @@ def create_cfg_from_cli_args(args, default_cfg):
     in the args.
     """
-    _C = CN()
+    _C = CfgNode()
     _C.INPUT = default_cfg.INPUT
     _C.DATASETS = default_cfg.DATASETS
     _C.DATALOADER = default_cfg.DATALOADER
@@ -129,7 +130,7 @@ def create_cfg_from_cli_args(args, default_cfg):
     _C.TENSORBOARD = default_cfg.TENSORBOARD
 
     # NOTE configs below might not be necessary, but must add to make code work
-    _C.MODEL = CN()
+    _C.MODEL = CfgNode()
     _C.MODEL.META_ARCHITECTURE = default_cfg.MODEL.META_ARCHITECTURE
     _C.MODEL.MASK_ON = default_cfg.MODEL.MASK_ON
     _C.MODEL.KEYPOINT_ON = default_cfg.MODEL.KEYPOINT_ON
@@ -180,18 +181,7 @@ def prepare_for_launch(args):
     return cfg, output_dir, runner
 
 
-def _setup_after_launch(cfg: CN, output_dir: str, runner):
-    """
-    Set things up after entering DDP, including
-    - creating working directory
-    - setting up logger
-    - logging environment
-    - initializing runner
-    """
-    create_dir_on_global_main_process(output_dir)
-    comm.synchronize()
-    setup_loggers(output_dir)
-    cfg.freeze()
+def maybe_override_output_dir(cfg: CfgNode, output_dir: str):
     if cfg.OUTPUT_DIR != output_dir:
         with temp_defrost(cfg):
             logger.warning(
@@ -200,72 +190,91 @@ def _setup_after_launch(cfg: CN, output_dir: str, runner):
                 )
             )
             cfg.OUTPUT_DIR = output_dir
-    dump_cfg(cfg, os.path.join(output_dir, "config.yaml"))
 
 
-def setup_after_launch(cfg: CN, output_dir: str, runner):
-    _setup_after_launch(cfg, output_dir, runner)
-    logger.info("Initializing runner ...")
-    runner = initialize_runner(runner, cfg)
-    log_info(cfg, runner)
-    dump_cfg(
-        get_diff_cfg(runner.get_default_cfg(), cfg),
-        os.path.join(output_dir, "diff_config.yaml"),
-    )
-    auto_scale_world_size(cfg, new_world_size=comm.get_world_size())
-
-
-def setup_after_lightning_launch(cfg: CN, output_dir: str):
-    _setup_after_launch(cfg, output_dir, runner=None)
-    log_info(cfg, runner=None)
+def setup_after_launch(
+    cfg: CfgNode,
+    output_dir: str,
+    runner: Optional[BaseRunner] = None,
+    _scale_world_size: bool = True,  # HACK: temporarily allow lightning_train_net to by pass this.
+):
+    """
+    Binary-level setup after entering DDP, including
+    - creating working directory
+    - setting up logger
+    - logging environment
+    - printing and dumping config
+    - (optional) initializing runner
+    """
+    create_dir_on_global_main_process(output_dir)
+    setup_loggers(output_dir)
+    log_system_info()
+
+    cfg.freeze()
+    maybe_override_output_dir(cfg, output_dir)
+    logger.info("Running with full config:\n{}".format(cfg))
+    dump_cfg(cfg, os.path.join(output_dir, "config.yaml"))
+
+    if runner:
+        logger.info("Initializing runner ...")
+        runner = initialize_runner(runner, cfg)
+        logger.info("Running with runner: {}".format(runner))
+
+    # save the diff config
+    if runner:
+        default_cfg = runner.get_default_cfg()
+        dump_cfg(
+            get_diff_cfg(default_cfg, cfg),
+            os.path.join(output_dir, "diff_config.yaml"),
+        )
+    else:
+        # TODO: support getting default_cfg without runner.
+        pass
+
+    # scale the config after dumping so that dumped config files keep original world size
+    if _scale_world_size:
+        auto_scale_world_size(cfg, new_world_size=comm.get_world_size())
 
 
-@run_once()
-def setup_loggers(output_dir, color=None):
+def setup_logger(
+    module_name: str,
+    output_dir: str,
+    abbrev_name: Optional[str] = None,
+    color: Optional[bool] = None,
+) -> logging.Logger:
     if not color:
         color = get_launch_environment() == "local"
-    d2_logger = setup_logger(
+    if not abbrev_name:
+        abbrev_name = module_name
+
+    logger = _setup_logger(
         output_dir,
         distributed_rank=comm.get_rank(),
         color=color,
-        name="detectron2",
-        abbrev_name="d2",
+        name=module_name,
+        abbrev_name=abbrev_name,
     )
-    fvcore_logger = setup_logger(
-        output_dir,
-        distributed_rank=comm.get_rank(),
-        color=color,
-        name="fvcore",
-    )
-    d2go_logger = setup_logger(
-        output_dir,
-        distributed_rank=comm.get_rank(),
-        color=color,
-        name="d2go",
-        abbrev_name="d2go",
-    )
-    mobile_cv_logger = setup_logger(
-        output_dir,
-        distributed_rank=comm.get_rank(),
-        color=color,
-        name="mobile_cv",
-        abbrev_name="mobile_cv",
-    )
-    # NOTE: all above loggers have FileHandler pointing to the same file as d2_logger.
-    # Those files are opened upon creation, but it seems fine in 'a' mode.
-
     # NOTE: the root logger might has been configured by other applications,
     # since this already sub-top level, just don't propagate to root.
-    d2_logger.propagate = False
-    fvcore_logger.propagate = False
-    d2go_logger.propagate = False
-    mobile_cv_logger.propagate = False
+    logger.propagate = False
+    return logger
 
 
-def log_info(cfg, runner):
+@run_once()
+def setup_loggers(output_dir):
+    setup_logger("detectron2", output_dir, abbrev_name="d2")
+    setup_logger("fvcore", output_dir)
+    setup_logger("d2go", output_dir)
+    setup_logger("mobile_cv", output_dir)
+    # NOTE: all above loggers have FileHandler pointing to the same file as d2_logger.
+    # Those files are opened upon creation, but it seems fine in 'a' mode.
+
+
+def log_system_info():
     num_processes = get_num_processes_per_machine()
     logger.info(
         "Using {} processes per machine. Rank of current process: {}".format(
@@ -282,29 +291,24 @@ def log_info(cfg, runner):
         print_fbcode_info()
     except ImportError:
         pass
 
-    logger.info("Running with full config:\n{}".format(cfg))
-    logger.info("Running with runner: {}".format(runner))
 
 
-def dump_cfg(cfg, path):
+def dump_cfg(cfg: CfgNode, path: str) -> None:
     if comm.is_main_process():
         with PathManager.open(path, "w") as f:
             f.write(cfg.dump())
         logger.info("Full config saved to {}".format(path))
 
 
-def create_dir_on_local_main_process(dir):
-    if get_local_rank() == 0 and dir:
-        PathManager.mkdirs(dir)
-
-
-def create_dir_on_global_main_process(dir):
-    if comm.get_rank() == 0 and dir:
-        PathManager.mkdirs(dir)
+def create_dir_on_global_main_process(path: str) -> None:
+    if comm.get_rank() == 0 and path:
+        PathManager.mkdirs(path)
+    # Add a barrier to make sure the existance of the dir for non-master process
+    comm.synchronize()
 
 
-def initialize_runner(runner, cfg):
-    runner = runner or GeneralizedRCNNRunner()
+def initialize_runner(runner: BaseRunner, cfg: CfgNode) -> BaseRunner:
+    assert runner is not None, "now always requires a runner instance"
     runner._initialize(cfg)
     return runner
...
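The logger refactor above collapses four near-identical calls to detectron2's `setup_logger` into one parameterized `setup_logger` helper plus a `setup_loggers` wrapper. As a hedged usage sketch (the `my_project` namespace and `./output` path are placeholders, not part of this commit), an application can register an extra logging namespace with the same handlers:

```python
# Hedged sketch; "my_project" and "./output" are placeholders.
from d2go.setup import setup_logger, setup_loggers

output_dir = "./output"

# Configure the default detectron2 / fvcore / d2go / mobile_cv loggers once
# (setup_loggers is wrapped in @run_once, so repeated calls take no effect).
setup_loggers(output_dir)

# Register an additional namespace; the helper returns the configured logger.
log = setup_logger("my_project", output_dir, abbrev_name="myproj")
log.info("logs are written under %s", output_dir)
```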
@@ -33,14 +33,14 @@ class TestLightningTrainNet(unittest.TestCase):
         cfg = self._get_cfg(root_dir)
         # set distributed backend to none to avoid spawning child process,
         # which doesn't inherit the temporary dataset
-        main(cfg)
+        main(cfg, root_dir)
 
     @tempdir
     def test_checkpointing(self, tmp_dir):
         """tests saving and loading from checkpoint."""
         cfg = self._get_cfg(tmp_dir)
-        out = main(cfg)
+        out = main(cfg, tmp_dir)
 
         ckpts = [f for f in os.listdir(tmp_dir) if f.endswith(".ckpt")]
         expected_ckpts = ("last.ckpt", FINAL_MODEL_CKPT)
         for ckpt in expected_ckpts:
@@ -48,11 +48,11 @@ class TestLightningTrainNet(unittest.TestCase):
         cfg2 = cfg.clone()
         cfg2.defrost()
-        cfg2.OUTPUT_DIR = os.path.join(tmp_dir, "output")
         # load the last checkpoint from previous training
         cfg2.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
-        out2 = main(cfg2, eval_only=True)
+        output_dir = os.path.join(tmp_dir, "output")
+        out2 = main(cfg2, output_dir, eval_only=True)
 
         accuracy = flatten_config_dict(out.accuracy)
         accuracy2 = flatten_config_dict(out2.accuracy)
         for k in accuracy:
...
@@ -14,7 +14,7 @@ from d2go.runner.callbacks.quantization import (
     QuantizationAwareTraining,
 )
 from d2go.runner.lightning_task import GeneralizedRCNNTask
-from d2go.setup import basic_argument_parser
+from d2go.setup import basic_argument_parser, setup_after_launch
 from d2go.utils.misc import dump_trained_model_configs
 from detectron2.utils.events import EventStorage
 from detectron2.utils.file_io import PathManager
@@ -39,16 +39,6 @@ class TrainOutput:
     model_configs: Optional[Dict[str, str]] = None
 
 
-def maybe_override_output_dir(cfg: CfgNode, output_dir: Optional[str]) -> None:
-    """Overrides the output directory if `output_dir` is not None."""
-    if output_dir is not None and output_dir != cfg.OUTPUT_DIR:
-        cfg.OUTPUT_DIR = output_dir
-        logger.warning(
-            f"Override cfg.OUTPUT_DIR ({cfg.OUTPUT_DIR}) to be the same as "
-            f"output_dir {output_dir}"
-        )
-
-
 def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
     """Gets the trainer callbacks based on the given D2Go Config.
@@ -147,7 +137,7 @@ def do_test(trainer: pl.Trainer, task: GeneralizedRCNNTask):
 def main(
     cfg: CfgNode,
-    output_dir: Optional[str] = None,
+    output_dir: str,
     task_cls: Type[GeneralizedRCNNTask] = GeneralizedRCNNTask,
     eval_only: bool = False,
     num_machines: int = 1,
@@ -160,8 +150,9 @@ def main(
         num_processes: Number of processes on each node.
         eval_only: True if run evaluation only.
     """
-    auto_scale_world_size(cfg, num_machines * num_processes)
-    maybe_override_output_dir(cfg, output_dir)
+    # FIXME: make comm.get_world_size() work properly.
+    setup_after_launch(cfg, output_dir, _scale_world_size=False)
+    auto_scale_world_size(cfg, new_world_size=num_machines * num_processes)
 
     task = task_cls.from_config(cfg, eval_only)
     trainer_params = get_trainer_params(cfg, num_machines, num_processes)
...
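With `output_dir` now required and the setup centralized in `setup_after_launch`, callers pass an explicit working directory, as the updated tests above do. A hedged sketch of driving the refactored `main`; the import path and the "last.ckpt"/"output" layout mirror the test but are assumptions here, not part of the commit:

```python
# Hedged sketch; the d2go.tools.lightning_train_net import path is assumed.
import os

from d2go.config import CfgNode
from d2go.tools.lightning_train_net import main


def train_then_eval(cfg: CfgNode, work_dir: str):
    # main() now runs setup_after_launch(cfg, output_dir, _scale_world_size=False)
    # internally, then rescales the config for num_machines * num_processes.
    train_out = main(cfg, work_dir)

    # Evaluate from the last checkpoint, following the updated test above.
    eval_cfg = cfg.clone()
    eval_cfg.defrost()
    eval_cfg.MODEL.WEIGHTS = os.path.join(work_dir, "last.ckpt")
    return main(eval_cfg, os.path.join(work_dir, "output"), eval_only=True)
```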