"...text-generation-inference.git" did not exist on "ffe05ccd0566bba3bee6f9fb1678d193c3392ec5"
Commit e87ed5f0 authored by Kai Zhang, committed by Facebook GitHub Bot

Auto scale config for multi-node training

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/62

The Lightning trainer sets max_steps to cfg.SOLVER.MAX_ITER. However, MAX_ITER is the iteration budget for the whole job across all nodes, so in multi-node training it must be scaled down, along with the eval period and other config values.
This diff calls `auto_scale_world_size` before passing the config to the trainer.
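
For intuition, this kind of auto-scaling typically follows detectron2's linear scaling rule: a larger world size means a larger effective batch per step, so step-based budgets shrink proportionally. A minimal sketch of the idea, assuming a detectron2-style `SOLVER.REFERENCE_WORLD_SIZE` key; this is not d2go's actual implementation, and the exact set of scaled fields is an assumption:

# Hedged sketch of world-size scaling -- not d2go's implementation.
def scale_for_world_size(cfg, new_world_size):
    ref = cfg.SOLVER.REFERENCE_WORLD_SIZE  # assumed reference key
    if ref == 0 or ref == new_world_size:
        return cfg  # scaling disabled, or nothing to change
    scale = new_world_size / ref
    # More workers -> larger effective batch per step, so fewer steps
    # are needed to cover the same number of training samples.
    cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale))
    cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
    cfg.SOLVER.REFERENCE_WORLD_SIZE = new_world_size
    return cfg

For example, a config tuned for 8 GPUs with MAX_ITER = 90000, launched on 2 machines x 8 GPUs (world size 16), would be scaled to MAX_ITER = 45000.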

Reviewed By: wat3rBro

Differential Revision: D28140877

fbshipit-source-id: 2639ae58773a4ec2a0cc59dfefb2f5d9b1afe1a8
parent f3d05021
@@ -22,7 +22,7 @@ from d2go.runner.default_runner import (
     Detectron2GoRunner,
     GeneralizedRCNNRunner,
 )
-from d2go.setup import setup_after_launch
+from d2go.setup import setup_after_lightning_launch
 from d2go.utils.ema_state import EMAState
 from d2go.runner.callbacks.quantization import maybe_prepare_for_quantization, PREPARED
 from detectron2.modeling import build_model
@@ -276,7 +276,7 @@ class DefaultTask(pl.LightningModule):
     # Runner methods
     # ---------------------------------------------------------------------------
     def setup(self, stage: str):
-        setup_after_launch(self.cfg, self.cfg.OUTPUT_DIR, runner=None)
+        setup_after_lightning_launch(self.cfg, self.cfg.OUTPUT_DIR)

     def register(self, cfg: CfgNode):
         inject_coco_datasets(cfg)
...
@@ -179,7 +179,7 @@ def prepare_for_launch(args):
     return cfg, output_dir, runner


-def setup_after_launch(cfg, output_dir, runner):
+def _setup_after_launch(cfg: CN, output_dir: str, runner):
     """
     Set things up after entering DDP, including
     - creating working directory
@@ -199,14 +199,20 @@ def setup_after_launch(cfg, output_dir, runner):
         )
     )
     cfg.OUTPUT_DIR = output_dir
+    dump_cfg(cfg, os.path.join(output_dir, "config.yaml"))
+
+
+def setup_after_launch(cfg: CN, output_dir: str, runner):
+    _setup_after_launch(cfg, output_dir, runner)
     logger.info("Initializing runner ...")
     runner = initialize_runner(runner, cfg)
     log_info(cfg, runner)
-    dump_cfg(cfg, os.path.join(output_dir, "config.yaml"))
     auto_scale_world_size(cfg, new_world_size=comm.get_world_size())


+def setup_after_lightning_launch(cfg: CN, output_dir: str):
+    _setup_after_launch(cfg, output_dir, runner=None)
+    log_info(cfg, runner=None)


 @run_once()
 def setup_loggers(output_dir, color=None):
...
@@ -8,7 +8,7 @@ from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Type

 import pytorch_lightning as pl  # type: ignore
-from d2go.config import CfgNode, temp_defrost
+from d2go.config import CfgNode, temp_defrost, auto_scale_world_size
 from d2go.runner import create_runner
 from d2go.runner.callbacks.quantization import (
     QuantizationAwareTraining,
@@ -139,7 +139,7 @@ def main(
     assert (
         num_processes == 1 or num_gpus == 0
     ), "Only set num_processes > 1 when training on CPUs"
+    auto_scale_world_size(cfg, num_machines * num_gpus)
     maybe_override_output_dir(cfg, output_dir)

     task = task_cls.from_config(cfg, eval_only)
...
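
To make the ordering concrete: the Lightning entry point copies cfg.SOLVER.MAX_ITER into the trainer's max_steps, so the config has to be rescaled before that copy happens. A hedged sketch of the resulting flow; `auto_scale_world_size`, `task_cls.from_config`, and the config keys come from the diff above, while the simplified signature and the exact trainer arguments are assumptions for illustration:

import pytorch_lightning as pl
from d2go.config import auto_scale_world_size

def main(cfg, output_dir, task_cls, num_machines=1, num_gpus=1, eval_only=False):
    # Rescale MAX_ITER, EVAL_PERIOD, etc. for the total worker count
    # *before* reading them below; otherwise every node would run the
    # full single-node iteration budget.
    auto_scale_world_size(cfg, num_machines * num_gpus)
    task = task_cls.from_config(cfg, eval_only)
    trainer = pl.Trainer(
        max_steps=cfg.SOLVER.MAX_ITER,  # post-scaling value
        num_nodes=num_machines,
        gpus=num_gpus,
    )
    trainer.fit(task)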