Commit 8cf2b879 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

support diff config for lightning_train_net

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/285

`setup_after_launch` can now take `DefaultTask` as well (the `runner_or_task` can still be `None`, for runner-less train_net).

Reviewed By: tglik

Differential Revision: D37011560

fbshipit-source-id: ce8a88242df0a16de8da97d94e8eb7def524c69c
parent 84dac84f
...@@ -6,7 +6,7 @@ import argparse ...@@ -6,7 +6,7 @@ import argparse
import logging import logging
import os import os
import time import time
from typing import Optional from typing import Optional, Type, Union
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
import torch import torch
...@@ -18,7 +18,7 @@ from d2go.config import ( ...@@ -18,7 +18,7 @@ from d2go.config import (
) )
from d2go.config.utils import get_diff_cfg from d2go.config.utils import get_diff_cfg
from d2go.distributed import get_local_rank, get_num_processes_per_machine from d2go.distributed import get_local_rank, get_num_processes_per_machine
from d2go.runner import BaseRunner, create_runner from d2go.runner import BaseRunner, create_runner, DefaultTask
from d2go.utils.helper import run_once from d2go.utils.helper import run_once
from d2go.utils.launch_environment import get_launch_environment from d2go.utils.launch_environment import get_launch_environment
from detectron2.utils.collect_env import collect_env_info from detectron2.utils.collect_env import collect_env_info
...@@ -195,7 +195,7 @@ def maybe_override_output_dir(cfg: CfgNode, output_dir: str): ...@@ -195,7 +195,7 @@ def maybe_override_output_dir(cfg: CfgNode, output_dir: str):
def setup_after_launch( def setup_after_launch(
cfg: CfgNode, cfg: CfgNode,
output_dir: str, output_dir: str,
runner: Optional[BaseRunner] = None, runner: Union[BaseRunner, Type[DefaultTask], None],
): ):
""" """
Binary-level setup after entering DDP, including Binary-level setup after entering DDP, including
...@@ -215,13 +215,13 @@ def setup_after_launch( ...@@ -215,13 +215,13 @@ def setup_after_launch(
logger.info("Running with full config:\n{}".format(cfg)) logger.info("Running with full config:\n{}".format(cfg))
dump_cfg(cfg, os.path.join(output_dir, "config.yaml")) dump_cfg(cfg, os.path.join(output_dir, "config.yaml"))
if runner: if isinstance(runner, BaseRunner):
logger.info("Initializing runner ...") logger.info("Initializing runner ...")
runner = initialize_runner(runner, cfg) runner = initialize_runner(runner, cfg)
logger.info("Running with runner: {}".format(runner)) logger.info("Running with runner: {}".format(runner))
# save the diff config # save the diff config
if runner: if runner is not None:
default_cfg = runner.get_default_cfg() default_cfg = runner.get_default_cfg()
dump_cfg( dump_cfg(
get_diff_cfg(default_cfg, cfg), get_diff_cfg(default_cfg, cfg),
......
...@@ -145,7 +145,7 @@ def main( ...@@ -145,7 +145,7 @@ def main(
num_processes: Number of processes on each node. num_processes: Number of processes on each node.
eval_only: True if run evaluation only. eval_only: True if run evaluation only.
""" """
setup_after_launch(cfg, output_dir) setup_after_launch(cfg, output_dir, task_cls)
task = task_cls.from_config(cfg, eval_only) task = task_cls.from_config(cfg, eval_only)
trainer_params = get_trainer_params(cfg) trainer_params = get_trainer_params(cfg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment