Commit 8cf2b879 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

support diff config for lightning_train_net

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/285

`setup_after_launch` can now take `DefaultTask` as well (the `runner_or_task` can still be `None`, for runner-less train_net).

Reviewed By: tglik

Differential Revision: D37011560

fbshipit-source-id: ce8a88242df0a16de8da97d94e8eb7def524c69c
parent 84dac84f
......@@ -6,7 +6,7 @@ import argparse
import logging
import os
import time
from typing import Optional
from typing import Optional, Type, Union
import detectron2.utils.comm as comm
import torch
......@@ -18,7 +18,7 @@ from d2go.config import (
)
from d2go.config.utils import get_diff_cfg
from d2go.distributed import get_local_rank, get_num_processes_per_machine
from d2go.runner import BaseRunner, create_runner
from d2go.runner import BaseRunner, create_runner, DefaultTask
from d2go.utils.helper import run_once
from d2go.utils.launch_environment import get_launch_environment
from detectron2.utils.collect_env import collect_env_info
......@@ -195,7 +195,7 @@ def maybe_override_output_dir(cfg: CfgNode, output_dir: str):
def setup_after_launch(
cfg: CfgNode,
output_dir: str,
runner: Optional[BaseRunner] = None,
runner: Union[BaseRunner, Type[DefaultTask], None],
):
"""
Binary-level setup after entering DDP, including
......@@ -215,13 +215,13 @@ def setup_after_launch(
logger.info("Running with full config:\n{}".format(cfg))
dump_cfg(cfg, os.path.join(output_dir, "config.yaml"))
if runner:
if isinstance(runner, BaseRunner):
logger.info("Initializing runner ...")
runner = initialize_runner(runner, cfg)
logger.info("Running with runner: {}".format(runner))
# save the diff config
if runner:
if runner is not None:
default_cfg = runner.get_default_cfg()
dump_cfg(
get_diff_cfg(default_cfg, cfg),
......
......@@ -145,7 +145,7 @@ def main(
num_processes: Number of processes on each node.
eval_only: True if run evaluation only.
"""
setup_after_launch(cfg, output_dir)
setup_after_launch(cfg, output_dir, task_cls)
task = task_cls.from_config(cfg, eval_only)
trainer_params = get_trainer_params(cfg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment