Commit 53748d9d authored by Francisc Bungiu's avatar Francisc Bungiu Committed by Facebook GitHub Bot
Browse files

Add profiler to d2go lightning

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/589

Allow attaching GPU profiler to lightning d2go tasks.

Reviewed By: miqueljubert

Differential Revision: D47190798

fbshipit-source-id: b10269d25de6b5f977633796e77b0d6d912a873a
parent a2b9a523
...@@ -25,7 +25,9 @@ from d2go.runner.default_runner import ( ...@@ -25,7 +25,9 @@ from d2go.runner.default_runner import (
) )
from d2go.utils.ema_state import EMAState from d2go.utils.ema_state import EMAState
from d2go.utils.misc import get_tensorboard_log_dir from d2go.utils.misc import get_tensorboard_log_dir
from detectron2.engine.train_loop import HookBase
from detectron2.solver import build_lr_scheduler as d2_build_lr_scheduler from detectron2.solver import build_lr_scheduler as d2_build_lr_scheduler
from mobile_cv.common.misc.oss_utils import fb_overwritable
from pytorch_lightning.strategies import DDPStrategy, SingleDeviceStrategy from pytorch_lightning.strategies import DDPStrategy, SingleDeviceStrategy
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.logger import _flatten_dict from pytorch_lightning.utilities.logger import _flatten_dict
...@@ -94,6 +96,11 @@ class ModelTag(str, Enum): ...@@ -94,6 +96,11 @@ class ModelTag(str, Enum):
EMA = "ema" EMA = "ema"
@fb_overwritable()
def get_gpu_profiler(cfg: CfgNode) -> Optional[HookBase]:
return None
class DefaultTask(D2GoDataAPIMixIn, pl.LightningModule): class DefaultTask(D2GoDataAPIMixIn, pl.LightningModule):
def __init__(self, cfg: CfgNode): def __init__(self, cfg: CfgNode):
super().__init__() super().__init__()
...@@ -119,6 +126,7 @@ class DefaultTask(D2GoDataAPIMixIn, pl.LightningModule): ...@@ -119,6 +126,7 @@ class DefaultTask(D2GoDataAPIMixIn, pl.LightningModule):
device=cfg.MODEL_EMA.DEVICE or cfg.MODEL.DEVICE, device=cfg.MODEL_EMA.DEVICE or cfg.MODEL.DEVICE,
) )
self.dataset_evaluators[ModelTag.EMA] = [] self.dataset_evaluators[ModelTag.EMA] = []
self.gpu_profiler: Optional[HookBase] = get_gpu_profiler(cfg)
def _build_model(self) -> torch.nn.Module: def _build_model(self) -> torch.nn.Module:
model = build_meta_arch(self.cfg) model = build_meta_arch(self.cfg)
...@@ -376,9 +384,16 @@ class DefaultTask(D2GoDataAPIMixIn, pl.LightningModule): ...@@ -376,9 +384,16 @@ class DefaultTask(D2GoDataAPIMixIn, pl.LightningModule):
device=self.cfg.MODEL_EMA.DEVICE or self.cfg.MODEL.DEVICE, device=self.cfg.MODEL_EMA.DEVICE or self.cfg.MODEL.DEVICE,
) )
def on_train_batch_start(self, *_) -> None:
if self.gpu_profiler is not None:
self.gpu_profiler.before_step()
def on_train_batch_end(self, *_) -> None: def on_train_batch_end(self, *_) -> None:
if self.ema_state: if self.ema_state:
self.ema_state.update(self.model) self.ema_state.update(self.model)
if self.gpu_profiler is not None:
# NOTE: keep this last in function to include all ops in this iteration of the trace
self.gpu_profiler.after_step()
def on_test_epoch_start(self): def on_test_epoch_start(self):
self._on_evaluation_epoch_start() self._on_evaluation_epoch_start()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment