Commit 87649f4f authored by Wei Sun, committed by Facebook GitHub Bot

Use the consolidated snapshot API in Unitrace to support Zoomer

Summary: Similar to D48210543. Update training_hooks to use the Unitrace memory snapshot APIs. This lets us maintain a single path for the memory snapshot APIs and also collect important details, such as the snapshot location, for Zoomer.
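For context, the hook being consolidated here builds on PyTorch's CUDA memory snapshot facility. The sketch below illustrates only that underlying public mechanism, not the Unitrace API itself (which is not shown in this diff); the output filename and max_entries value are illustrative assumptions, and the private torch.cuda.memory helpers require a recent PyTorch release.

# Minimal sketch of the underlying PyTorch CUDA memory snapshot mechanism.
# This is NOT the Unitrace API; the file name and max_entries are assumptions.
import torch

def capture_memory_snapshot(output_path: str = "memory_snapshot.pickle") -> None:
    # Begin recording allocation/free events with stack traces.
    torch.cuda.memory._record_memory_history(max_entries=100_000)

    # ... run the training iterations to be profiled here ...

    # Write the recorded history to disk; the pickle can be inspected with
    # the viewer at https://pytorch.org/memory_viz.
    torch.cuda.memory._dump_snapshot(output_path)

    # Stop recording to avoid further overhead.
    torch.cuda.memory._record_memory_history(enabled=None)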

Pulled By: HugeEngine

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/636

Reviewed By: frabu6, aaronenyeshi, jackiexu1992, mengluy0125

Differential Revision: D48368150

fbshipit-source-id: b279adfa29d390e615d2c32a7ab9e05d95b4f164
parent 8d072ebf
@@ -38,11 +38,7 @@ from d2go.runner.config_defaults import (
     get_generalized_rcnn_runner_default_cfg,
 )
-from d2go.runner.training_hooks import (
-    D2GoGpuMemorySnapshot,
-    TRAINER_HOOKS_REGISTRY,
-    update_hooks_from_registry,
-)
+from d2go.runner.training_hooks import update_hooks_from_registry
 from d2go.trainer.fsdp import get_grad_scaler
 from d2go.trainer.helper import parse_precision_from_string
 from d2go.utils.abnormal_checker import (
@@ -150,20 +146,6 @@ def default_scale_quantization_configs(cfg, new_world_size):
     )
-
-
-@TRAINER_HOOKS_REGISTRY.register()
-def add_memory_profiler_hook(hooks, cfg: CfgNode):
-    # Add GPU memory snapshot profiler to diagnose GPU OOM issues and benchmark memory usage during model training
-    if cfg.get("MEMORY_PROFILER", CfgNode()).get("ENABLED", False):
-        hooks.append(
-            D2GoGpuMemorySnapshot(
-                cfg.OUTPUT_DIR,
-                log_n_steps=cfg.MEMORY_PROFILER.LOG_N_STEPS,
-                log_during_train_at=cfg.MEMORY_PROFILER.LOG_DURING_TRAIN_AT,
-                trace_max_entries=cfg.MEMORY_PROFILER.TRACE_MAX_ENTRIES,
-            )
-        )
 @fb_overwritable()
 def prepare_fb_model(cfg: CfgNode, model: torch.nn.Module) -> torch.nn.Module:
     return model
...
@@ -5,8 +5,6 @@ from typing import List
 from d2go.config import CfgNode
-from d2go.utils.gpu_memory_profiler import log_memory_snapshot, record_memory_history
 from detectron2.engine.train_loop import HookBase
 from detectron2.utils.registry import Registry
@@ -23,47 +21,3 @@ def update_hooks_from_registry(hooks: List[HookBase], cfg: CfgNode):
     for name, hook_func in TRAINER_HOOKS_REGISTRY:
         logger.info(f"Update trainer hooks from {name}...")
         hook_func(hooks, cfg)
-
-
-class D2GoGpuMemorySnapshot(HookBase):
-    """
-    A profiler that logs GPU memory snapshot during training.
-    There are three places that logging could happen:
-    1. start of training
-        d2go records memory snapshots before model instantiation and logs snapshots after `log_n_steps` iterations.
-        This is to capture the typical memory peak at model instantiation and the first few iterations
-    2. during training
-        d2go records memory snapshots at `log_during_train_at` iteration and logs snapshots after `log_n_steps` iterations.
-        This is to capture the stabilized memory utilization during training.
-    3. OOM
-        Right before OOM, the GPU memory snapshot will be logged to help diagnose OOM issues.
-    """
-
-    def __init__(
-        self,
-        output_dir,
-        log_n_steps: int = 3,
-        log_during_train_at: int = 550,
-        trace_max_entries: int = 1000000,
-    ) -> None:
-        self.output_dir = output_dir
-        self.step = 0
-        self.log_n_steps = log_n_steps
-        self.log_during_train_at = log_during_train_at
-        self.trace_max_entries = trace_max_entries
-        logger.warning(
-            "WARNING: Memory snapshot profiler is enabled. This may cause ranks to die and training jobs to get stuck. Please use with caution."
-        )
-
-    def before_step(self):
-        if self.trainer.iter == self.log_during_train_at:
-            record_memory_history(self.trace_max_entries)
-
-    def after_step(self):
-        if self.step == self.log_n_steps - 1:
-            log_memory_snapshot(self.output_dir, file_prefix=f"iter{self.trainer.iter}")
-        if self.trainer.iter == self.log_during_train_at + self.log_n_steps - 1:
-            log_memory_snapshot(self.output_dir, file_prefix=f"iter{self.trainer.iter}")
-        self.step += 1
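For reference, a hypothetical config fragment that would have enabled the now-removed hook through add_memory_profiler_hook above; the key names come from that function and the values are simply the __init__ defaults of D2GoGpuMemorySnapshot, shown here for illustration only.

# Hypothetical sketch: enabling the removed memory profiler hook via cfg.
# Key names are from add_memory_profiler_hook; values are the __init__ defaults.
from d2go.config import CfgNode

cfg = CfgNode()  # in practice this would be the runner's full config
cfg.OUTPUT_DIR = "./output"  # assumed output directory for snapshot files
cfg.MEMORY_PROFILER = CfgNode()
cfg.MEMORY_PROFILER.ENABLED = True
cfg.MEMORY_PROFILER.LOG_N_STEPS = 3
cfg.MEMORY_PROFILER.LOG_DURING_TRAIN_AT = 550
cfg.MEMORY_PROFILER.TRACE_MAX_ENTRIES = 1000000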