"git@developer.sourcefind.cn:OpenDAS/nerfacc.git" did not exist on "ff7cf01a81fbd2d343b4b17b7ada95f6cb2ee940"
Commit 236b15cd authored by Daniel Li (AI)'s avatar Daniel Li (AI) Committed by Facebook GitHub Bot
Browse files

Set find_unused_parameters according to DDP_FIND_UNUSED_PARAMETERS

Summary: Set find_unused_parameters according to DDP_FIND_UNUSED_PARAMETERS with DDPPlugin

Reviewed By: kazhang

Differential Revision: D29567013

fbshipit-source-id: f3ffac566a2ff046f55e692b3b24f9531913d4d4
parent 80c18641
...@@ -22,6 +22,7 @@ from pytorch_lightning.callbacks import Callback ...@@ -22,6 +22,7 @@ from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks import LearningRateMonitor from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin
from torch.distributed import get_rank from torch.distributed import get_rank
...@@ -75,6 +76,11 @@ def _get_accelerator(use_cpu: bool) -> str: ...@@ -75,6 +76,11 @@ def _get_accelerator(use_cpu: bool) -> str:
def get_trainer_params(cfg: CfgNode, num_machines: int, num_processes: int) -> Dict[str, Any]: def get_trainer_params(cfg: CfgNode, num_machines: int, num_processes: int) -> Dict[str, Any]:
use_cpu = cfg.MODEL.DEVICE.lower() == "cpu" use_cpu = cfg.MODEL.DEVICE.lower() == "cpu"
accelerator = _get_accelerator(use_cpu)
plugins = []
if accelerator:
plugins.append(DDPPlugin(find_unused_parameters=cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS))
return { return {
# training loop is bounded by max steps, use a large max_epochs to make # training loop is bounded by max steps, use a large max_epochs to make
# sure max_steps is met first # sure max_steps is met first
...@@ -86,13 +92,14 @@ def get_trainer_params(cfg: CfgNode, num_machines: int, num_processes: int) -> D ...@@ -86,13 +92,14 @@ def get_trainer_params(cfg: CfgNode, num_machines: int, num_processes: int) -> D
"num_nodes": num_machines, "num_nodes": num_machines,
"gpus": None if use_cpu else num_processes, "gpus": None if use_cpu else num_processes,
"num_processes": num_processes, "num_processes": num_processes,
"accelerator": _get_accelerator(use_cpu), "accelerator": accelerator,
"callbacks": _get_trainer_callbacks(cfg), "callbacks": _get_trainer_callbacks(cfg),
"logger": TensorBoardLogger(save_dir=cfg.OUTPUT_DIR), "logger": TensorBoardLogger(save_dir=cfg.OUTPUT_DIR),
"num_sanity_val_steps": 0, "num_sanity_val_steps": 0,
"progress_bar_refresh_rate": 10, "progress_bar_refresh_rate": 10,
"terminate_on_nan": True, "terminate_on_nan": True,
"replace_sampler_ddp": False, "replace_sampler_ddp": False,
"plugins": plugins,
} }
def do_train( def do_train(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment