"tests/python/common/data/test_serialize.py" did not exist on "0fb13f7b9d7d59fdf7eaf03a3b00d6f31801cea5"
Commit 20e18edc authored by Anthony Chen, committed by Facebook GitHub Bot

Add a GPU memory snapshot profiler in d2go

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/542

## Overview
Add an option to enable the GPU memory snapshot profiler in d2go. The profiler is natively supported by PyTorch and records stack traces for all CUDA memory allocation/free events, allowing users to understand which parts of the code contribute to memory bottlenecks. It also provides an interactive web tool that visualizes memory utilization over time:
{F978609840}
Each colored block represents an allocated CUDA memory block. Users can click on a block to see the Python stack trace that allocated it.
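
For context, the d2go hooks in this diff wrap PyTorch's native snapshot APIs. A minimal standalone sketch of those APIs (the output file name is illustrative, and a CUDA device is assumed):

```python
import torch
from torch.cuda._memory_viz import trace_plot

# Start recording CUDA allocation/free events together with their stack traces.
torch.cuda.memory._record_memory_history(enabled="all", max_entries=100000)

# ... run the GPU workload to be profiled ...

# Take a snapshot and render the interactive timeline as a standalone HTML page.
snapshot = torch.cuda.memory._snapshot()
with open("trace_plot.html", "w") as f:
    f.write(trace_plot(snapshot))
```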

## d2go integration
This diff integrates the profiler as a hook controlled by the `MEMORY_PROFILER.ENABLED` config key. The profiler writes snapshots and interactive HTML plots to the output directory. Logging can happen at three points: the start of training, during training, and on OOM. Please read the docstring of `D2GoGpuMemorySnapshot` for more information.
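
As a usage sketch, assuming `cfg` is a d2go runner config that already includes the `MEMORY_PROFILER` defaults added in this diff:

```python
# Enable the GPU memory snapshot profiler (values shown are the defaults added in this diff).
cfg.MEMORY_PROFILER.ENABLED = True
cfg.MEMORY_PROFILER.TRACE_MAX_ENTRIES = 1000000  # max entries kept in the allocation trace
cfg.MEMORY_PROFILER.LOG_N_STEPS = 3              # log snapshots after this many iterations
cfg.MEMORY_PROFILER.LOG_DURING_TRAIN_AT = 550    # iteration at which to record again mid-training
```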

Reviewed By: tglik, jaconey

Differential Revision: D45673764

fbshipit-source-id: 8900484a2266d94421fe3ee7a85a4dea3a9f6b72
parent 876c6756
...@@ -17,6 +17,7 @@ from d2go.modeling.subclass import add_subclass_configs
from d2go.quantization.modeling import add_quantization_default_configs
from d2go.registry.builtin import CONFIG_UPDATER_REGISTRY
from d2go.trainer.fsdp import add_fsdp_configs
from d2go.utils.gpu_memory_profiler import add_memory_profiler_configs
from d2go.utils.visualization import add_tensorboard_default_configs
from detectron2.config import get_cfg as get_d2_cfg
from mobile_cv.common.misc.oss_utils import fb_overwritable
...@@ -112,6 +113,9 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None:
# Profiler
_C.PROFILERS = ["default_flop_counter"]
# GPU memory profiler
add_memory_profiler_configs(_C)
# Checkpointing-specific config
_C.LOAD_CKPT_TO_GPU = False
...
...@@ -35,10 +35,16 @@ from d2go.runner.config_defaults import (
get_detectron2go_runner_default_cfg,
get_generalized_rcnn_runner_default_cfg,
)
from d2go.runner.training_hooks import update_hooks_from_registry
from d2go.runner.training_hooks import (
D2GoGpuMemorySnapshot,
TRAINER_HOOKS_REGISTRY,
update_hooks_from_registry,
)
from d2go.trainer.fsdp import get_grad_scaler
from d2go.trainer.helper import parse_precision_from_string
from d2go.utils.flop_calculator import attach_profilers
from d2go.utils.gpu_memory_profiler import attach_oom_logger
from d2go.utils.helper import D2Trainer, TensorboardXWriter
from d2go.utils.misc import get_tensorboard_log_dir
from d2go.utils.visualization import DataLoaderVisWrapper, VisualizationEvaluator
...@@ -136,6 +142,20 @@ def default_scale_quantization_configs(cfg, new_world_size):
)
@TRAINER_HOOKS_REGISTRY.register()
def add_memory_profiler_hook(hooks, cfg: CfgNode):
# Add GPU memory snapshot profiler to diagnose GPU OOM issues and benchmark memory usage during model training
if cfg.get("MEMORY_PROFILER", CfgNode()).get("ENABLED", False):
hooks.append(
D2GoGpuMemorySnapshot(
cfg.OUTPUT_DIR,
log_n_steps=cfg.MEMORY_PROFILER.LOG_N_STEPS,
log_during_train_at=cfg.MEMORY_PROFILER.LOG_DURING_TRAIN_AT,
trace_max_entries=cfg.MEMORY_PROFILER.TRACE_MAX_ENTRIES,
)
)
@fb_overwritable()
def prepare_fb_model(cfg: CfgNode, model: torch.nn.Module) -> torch.nn.Module:
return model
...@@ -315,6 +335,12 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
return model
def build_model(self, cfg, eval_only=False):
# Attach memory profiler to GPU OOM events
if cfg.get("MEMORY_PROFILER", CfgNode()).get("ENABLED", False):
attach_oom_logger(
cfg.OUTPUT_DIR, trace_max_entries=cfg.MEMORY_PROFILER.TRACE_MAX_ENTRIES
)
model = self._build_model(cfg, eval_only)
model = prepare_fb_model(cfg, model)
...
...@@ -5,6 +5,8 @@ from typing import List
from d2go.config import CfgNode
from d2go.utils.gpu_memory_profiler import log_memory_snapshot, record_memory_history
from detectron2.engine import HookBase
from detectron2.utils.registry import Registry
...@@ -21,3 +23,44 @@ def update_hooks_from_registry(hooks: List[HookBase], cfg: CfgNode):
for name, hook_func in TRAINER_HOOKS_REGISTRY:
logger.info(f"Update trainer hooks from {name}...")
hook_func(hooks, cfg)
class D2GoGpuMemorySnapshot(HookBase):
"""
A profiler that logs GPU memory snapshot during training.
There are three places that logging could happen:
1. start of training
d2go records memory snapshots before model instantiation and logs snapshots after `log_n_steps` iterations.
This is to capture the typical memory peak at model instantiation and the first few iterations
2. during training
d2go records memory snapshots at `log_during_train_at` iteration and logs snapshots after `log_n_steps` iterations.
This is to capture the stabilized memory utilization during training.
3. OOM
Right before OOM, the GPU memory snapshot will be logged to help diagnose OOM issues.
"""
def __init__(
self,
output_dir,
log_n_steps: int = 3,
log_during_train_at: int = 550,
trace_max_entries: int = 1000000,
) -> None:
self.output_dir = output_dir
self.step = 0
self.log_n_steps = log_n_steps
self.log_during_train_at = log_during_train_at
self.trace_max_entries = trace_max_entries
def before_step(self):
if self.trainer.iter == self.log_during_train_at:
record_memory_history(self.trace_max_entries)
def after_step(self):
if self.step == self.log_n_steps - 1:
log_memory_snapshot(self.output_dir, file_prefix=f"iter{self.trainer.iter}")
if self.trainer.iter == self.log_during_train_at + self.log_n_steps - 1:
log_memory_snapshot(self.output_dir, file_prefix=f"iter{self.trainer.iter}")
self.step += 1
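# Note on the resulting schedule (a sketch, assuming the default config values): recording starts
# when the model is built (see build_model, which calls attach_oom_logger) and again at iteration
# LOG_DURING_TRAIN_AT via before_step; after_step then writes snapshots LOG_N_STEPS iterations
# later, i.e. with the defaults at iteration 2 and at iteration 552 (550 + 3 - 1), each under an
# "iter<N>" file prefix in OUTPUT_DIR/memory_snapshot/.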
import logging
import os
import pickle
import torch
from d2go.config import CfgNode as CN
from detectron2.utils.file_io import PathManager
from mobile_cv.torch.utils_pytorch import comm
from torch.cuda._memory_viz import segment_plot, trace_plot
logger: logging.Logger = logging.getLogger(__name__)
def add_memory_profiler_configs(_C: CN):
_C.MEMORY_PROFILER = CN()
_C.MEMORY_PROFILER.ENABLED = False
# max number of trace entries in memory snapshot
_C.MEMORY_PROFILER.TRACE_MAX_ENTRIES = 1000000
# Configs to be used by d2go.utils.gpu_memory_profiler.D2GoGpuMemorySnapshot
# determine the number of iterations to log memory snapshots for
_C.MEMORY_PROFILER.LOG_N_STEPS = 3
# determine at what iteration to start recording gpu memory
_C.MEMORY_PROFILER.LOG_DURING_TRAIN_AT = 550
def omm_logger_wrapper(output_dir):
def oom_logger(
device: int, alloc: int, device_alloc: int, device_free: int
) -> None:
"""
Log memory snapshot in the event of CUDA OOM.
"""
logger.info(
f"Saving memory snapshot device: {device}, alloc: {alloc}, device_alloc: {device_alloc}, device_free: {device_free}"
)
try:
log_memory_snapshot(output_dir, file_prefix="oom")
except Exception as e:
logger.error(f"Failed to log memory snapshot during OOM {e}")
return oom_logger
def log_memory_snapshot(output_dir: str, file_prefix: str = "") -> None:
"""
Log memory snapshots to output_dir
"""
if not torch.cuda.is_available():
logger.info("CUDA unavailable. Not logging snapshot")
return
try:
rank = comm.get_rank()
save_dir = os.path.join(
output_dir, "memory_snapshot", f"{file_prefix}_rank{rank}"
)
logger.info(f"Logging memory snapshot to {save_dir}")
snapshot = torch.cuda.memory._snapshot()
dump_snapshot(save_dir, snapshot)
except Exception as e:
logger.error(f"Failed to log memory snapshot to {save_dir}: {e}")
def dump_snapshot(save_dir: str, snapshot):
"""
Dump memory snapshot and useful plots to save_dir.
This is a rewrite of torch.cuda.memory._dump_snapshot() with PathManager.
"""
if not PathManager.exists(save_dir):
PathManager.mkdirs(save_dir)
with PathManager.open(os.path.join(save_dir, "snapshot.pickle"), "wb") as f:
pickle.dump(snapshot, f)
with PathManager.open(os.path.join(save_dir, "trace_plot.html"), "w") as f:
f.write(trace_plot(snapshot))
with PathManager.open(os.path.join(save_dir, "segment_plot.html"), "w") as f:
f.write(segment_plot(snapshot))
logger.info(f"Saved memory snapshot to {save_dir}")
def record_memory_history(trace_max_entries=1000000) -> None:
"""
Start recording memory history and stack traces.
"""
if not torch.cuda.is_available():
logger.info("CUDA unavailable. Not recording memory history")
return
torch.cuda.memory._record_memory_history(
enabled="all", max_entries=trace_max_entries
)
logger.info("Started recording memory history")
def attach_oom_logger(output_dir, trace_max_entries=1000000) -> None:
"""
Start recording memory history and attach the OOM logger.
"""
if not torch.cuda.is_available():
logger.info("CUDA unavailable. Not attaching OOM logger")
return
record_memory_history(trace_max_entries)
torch._C._cuda_attach_out_of_memory_observer(omm_logger_wrapper(output_dir))
logger.info("Attached GPU OOM logger")
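# Typical call site (mirroring build_model in this diff): attach the OOM logger before the model
# is instantiated so that early allocations are captured in the snapshot, e.g.
#   attach_oom_logger(cfg.OUTPUT_DIR, trace_max_entries=cfg.MEMORY_PROFILER.TRACE_MAX_ENTRIES)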