Commit 82c6a50b authored by Haroun Habeeb's avatar Haroun Habeeb Committed by Facebook GitHub Bot
Browse files

upgrade lightning's API for Trainer in d2go

Summary:
see https://fb.workplace.com/notes/3006074566389155

----
did the integration test not catch this?

Reviewed By: ananthsub, tangbinh

Differential Revision: D34665501

fbshipit-source-id: ff2cbfa9462f131455dce46a0c413c4c69105f48
parent 56a2eda1
......@@ -18,20 +18,12 @@ from d2go.setup import basic_argument_parser
from d2go.utils.misc import dump_trained_model_configs
from detectron2.utils.events import EventStorage
from detectron2.utils.file_io import PathManager
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import Callback, TQDMProgressBar, LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.strategies.ddp import DDPStrategy
from torch.distributed import get_rank
try:
from pytorch_lightning.strategies import DDPStrategy
except ImportError:
assert os.getenv("OSSRUN") == "1"
# FIXME: DDPStrategy has been renamed to DDPStrategy, however internal version is
# not updated yet, temporally skipping the import in oss env in order to unblock
# CI where DPP is not used.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("detectron2go.lightning.train_net")
......@@ -67,6 +59,7 @@ def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
A list of configured Callbacks to be used by the Lightning Trainer.
"""
callbacks: List[Callback] = [
TQDMProgressBar(refresh_rate=10), # Arbitrary refresh_rate.
LearningRateMonitor(logging_interval="step"),
ModelCheckpoint(
dirpath=cfg.OUTPUT_DIR,
......@@ -78,39 +71,35 @@ def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
return callbacks
def _get_strategy(cfg: CfgNode) -> DDPStrategy:
return DDPStrategy(find_unused_parameters=cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS)
def _get_accelerator(use_cpu: bool) -> str:
return "ddp_cpu" if use_cpu else "ddp"
return "cpu" if use_cpu else "gpu"
def get_trainer_params(
cfg: CfgNode, num_machines: int, num_processes: int
) -> Dict[str, Any]:
use_cpu = cfg.MODEL.DEVICE.lower() == "cpu"
strategy = _get_strategy(cfg)
accelerator = _get_accelerator(use_cpu)
plugins = []
if accelerator:
plugins.append(
DDPStrategy(find_unused_parameters=cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS)
)
return {
# training loop is bounded by max steps, use a large max_epochs to make
# sure max_steps is met first
"max_epochs": 10 ** 8,
"max_epochs": -1,
"max_steps": cfg.SOLVER.MAX_ITER,
"val_check_interval": cfg.TEST.EVAL_PERIOD
if cfg.TEST.EVAL_PERIOD > 0
else cfg.SOLVER.MAX_ITER,
"num_nodes": num_machines,
"gpus": None if use_cpu else num_processes,
"num_processes": num_processes,
"devices": num_processes,
"strategy": strategy,
"accelerator": accelerator,
"callbacks": _get_trainer_callbacks(cfg),
"logger": TensorBoardLogger(save_dir=cfg.OUTPUT_DIR),
"num_sanity_val_steps": 0,
"progress_bar_refresh_rate": 10,
"replace_sampler_ddp": False,
"plugins": plugins,
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment