Commit 29b57165 authored by Kai Zhang, committed by Facebook GitHub Bot

Read number of processes from dist_config

Summary: Currently, when launching a training flow, we read the number of processes from resources.num_gpus. To stay backward compatible with the existing D2Go training config (https://github.com/facebookresearch/d2go/commit/f82d44d3c33e6c781a3c6f2b27b376fdfbaeda53), this diff changes it to read dist_config.num_processes_per_machine instead.
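For illustration only, a minimal sketch of the intended lookup change; LaunchResources, DistConfig, and resolve_num_processes are hypothetical stand-ins for the internal launch-flow objects named in the summary, not the real API:

from dataclasses import dataclass


@dataclass
class LaunchResources:
    # hypothetical stand-in for the "resources" object in the summary
    num_gpus: int = 0


@dataclass
class DistConfig:
    # hypothetical stand-in for "dist_config" in the D2Go training config
    num_machines: int = 1
    num_processes_per_machine: int = 1


def resolve_num_processes(resources: LaunchResources, dist_config: DistConfig) -> int:
    # Before this diff the launcher used resources.num_gpus; after it, the
    # number of processes per machine comes from dist_config, matching the
    # existing D2Go training config.
    return dist_config.num_processes_per_machine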

Reviewed By: wat3rBro

Differential Revision: D28630334

fbshipit-source-id: 3c684cd56e5d2e247c7b82e1d1eeff0f39e59ee4
parent f82d44d3
@@ -18,7 +18,7 @@ class TestLightningTrainNet(unittest.TestCase):
        # set distributed backend to none to avoid spawning child process,
        # which doesn't inherit the temporary dataset
        patcher = unittest.mock.patch(
-            "d2go.tools.lightning_train_net.get_accelerator", return_value=None
+            "d2go.tools.lightning_train_net._get_accelerator", return_value=None
        )
        self.addCleanup(patcher.stop)
        patcher.start()
@@ -28,7 +28,7 @@ class TestLightningTrainNet(unittest.TestCase):
    @tempdir
    def test_train_net_main(self, root_dir):
-        """ tests the main training entry point. """
+        """tests the main training entry point."""
        cfg = self._get_cfg(root_dir)
        # set distributed backend to none to avoid spawning child process,
        # which doesn't inherit the temporary dataset
@@ -36,7 +36,7 @@ class TestLightningTrainNet(unittest.TestCase):
    @tempdir
    def test_checkpointing(self, tmp_dir):
-        """ tests saving and loading from checkpoint. """
+        """tests saving and loading from checkpoint."""
        cfg = self._get_cfg(tmp_dir)
        out = main(cfg)
...
@@ -12,7 +12,6 @@ from d2go.config import CfgNode, temp_defrost, auto_scale_world_size
from d2go.runner import create_runner
from d2go.runner.callbacks.quantization import (
    QuantizationAwareTraining,
-    ModelTransform,
)
from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.setup import basic_argument_parser
@@ -70,10 +69,30 @@ def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
        callbacks.append(QuantizationAwareTraining.from_config(cfg))
    return callbacks


-def get_accelerator(device: str) -> str:
-    return "ddp_cpu" if device.lower() == "cpu" else "ddp"
+def _get_accelerator(use_gpu: bool) -> str:
+    return "ddp" if use_gpu else "ddp_cpu"
+
+
+def get_trainer_params(cfg: CfgNode, num_machines: int, num_processes: int) -> Dict[str, Any]:
+    use_gpu = cfg.MODEL.DEVICE.lower() == "gpu"
+    return {
+        # training loop is bounded by max steps, use a large max_epochs to make
+        # sure max_steps is met first
+        "max_epochs": 10 ** 8,
+        "max_steps": cfg.SOLVER.MAX_ITER,
+        "val_check_interval": cfg.TEST.EVAL_PERIOD
+        if cfg.TEST.EVAL_PERIOD > 0
+        else cfg.SOLVER.MAX_ITER,
+        "num_nodes": num_machines,
+        "gpus": num_processes if use_gpu else None,
+        "num_processes": num_processes,
+        "accelerator": _get_accelerator(use_gpu),
+        "callbacks": _get_trainer_callbacks(cfg),
+        "logger": TensorBoardLogger(save_dir=cfg.OUTPUT_DIR),
+        "num_sanity_val_steps": 0,
+        "progress_bar_refresh_rate": 10,
+        "terminate_on_nan": True,
+    }


def do_train(
    cfg: CfgNode, trainer: pl.Trainer, task: GeneralizedRCNNTask
@@ -123,45 +142,20 @@ def main(
    task_cls: Type[GeneralizedRCNNTask] = GeneralizedRCNNTask,
    eval_only: bool = False,
    num_machines: int = 1,
-    num_gpus: int = 0,
    num_processes: int = 1,
) -> TrainOutput:
    """Main function for launching a training with lightning trainer
    Args:
        cfg: D2go config node
        num_machines: Number of nodes used for distributed training
-        num_gpus: Number of GPUs to train on each node
        num_processes: Number of processes on each node.
-            NOTE: Automatically set to the number of GPUs when using DDP.
-            Set a value greater than 1 to mimic distributed training on CPUs.
        eval_only: True if run evaluation only.
    """
-    assert (
-        num_processes == 1 or num_gpus == 0
-    ), "Only set num_processes > 1 when training on CPUs"
-    auto_scale_world_size(cfg, num_machines * num_gpus)
+    auto_scale_world_size(cfg, num_machines * num_processes)
    maybe_override_output_dir(cfg, output_dir)
    task = task_cls.from_config(cfg, eval_only)
-    tb_logger = TensorBoardLogger(save_dir=cfg.OUTPUT_DIR)
-    trainer_params = {
-        # training loop is bounded by max steps, use a large max_epochs to make
-        # sure max_steps is met first
-        "max_epochs": 10 ** 8,
-        "max_steps": cfg.SOLVER.MAX_ITER,
-        "val_check_interval": cfg.TEST.EVAL_PERIOD
-        if cfg.TEST.EVAL_PERIOD > 0
-        else cfg.SOLVER.MAX_ITER,
-        "num_nodes": num_machines,
-        "gpus": num_gpus,
-        "num_processes": num_processes,
-        "accelerator": get_accelerator(cfg.MODEL.DEVICE),
-        "callbacks": _get_trainer_callbacks(cfg),
-        "logger": tb_logger,
-        "num_sanity_val_steps": 0,
-        "progress_bar_refresh_rate": 10,
-    }
+    trainer_params = get_trainer_params(cfg, num_machines, num_processes)
    last_checkpoint = os.path.join(cfg.OUTPUT_DIR, "last.ckpt")
    if PathManager.exists(last_checkpoint):
@@ -178,7 +172,7 @@ def main(
    return TrainOutput(
        output_dir=cfg.OUTPUT_DIR,
-        tensorboard_log_dir=tb_logger.log_dir,
+        tensorboard_log_dir=trainer_params["logger"].log_dir,
        accuracy=task.eval_res,
        model_configs=model_configs,
    )
@@ -221,7 +215,6 @@ if __name__ == "__main__":
        task_cls,
        eval_only=False, # eval_only
        num_machines=args.num_machines,
-        num_gpus=args.num_gpus,
        num_processes=args.num_processes,
    )
    if get_rank() == 0:
...
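As a usage note (not part of the diff), a sketch of how the refactored helpers compose after this change, assuming the module paths and signatures shown above are exactly as committed; the runner name and process counts are illustrative only:

import pytorch_lightning as pl

from d2go.runner import create_runner
from d2go.tools.lightning_train_net import get_trainer_params

# Illustrative: build a default D2Go config via a runner.
runner = create_runner("d2go.runner.GeneralizedRCNNRunner")
cfg = runner.get_default_cfg()

# After this change a single num_processes value drives both the "gpus" and
# "num_processes" Trainer settings returned by get_trainer_params, so the
# caller no longer passes num_gpus separately.
params = get_trainer_params(cfg, num_machines=1, num_processes=2)
trainer = pl.Trainer(**params)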