Unverified Commit ab95d35f authored by Yuhong Guo, committed by GitHub

feat: Add Non-intrusive Tensor Dumping for Model Inference (#10566)

parent 34c286b8
"""
This file provides a function `register_forward_hook_for_model` that registers a forward hook on every operator of the model.
After registration, during model inference, all tensors generated throughout the forward pass will be recorded.
Usage:
Specify the output directory for dumping tensors using the argument `--debug-tensor-dump-output-folder`.
A separate directory will be created for each GPU rank, named in the format `f"TP{tp_rank}_PP{pp_rank}_Rank{rank}_pid{pid}"`.
Each complete forward pass of the model generates a `.pt` file named `f"Pass{pass_num}.pt"`, which can be loaded using `torch.load`.
The file contains a series of key-value pairs, where the keys correspond to operator names in the model
(similar to those in model.safetensors.index.json), and the values are the outputs produced by the respective operators.
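Example (illustrative only; the dump folder, rank directory, pid, and pass number below are placeholders):
    import torch
    tensors = torch.load("/path/to/dump/TP0_PP0_Rank0_pid12345/Pass00000.pt")
    for name, value in tensors.items():
        # Values are tensors, or lists of tensors for operators that return multiple outputs.
        print(name, value.shape if isinstance(value, torch.Tensor) else type(value))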
"""
import logging
import os
from pathlib import Path
import torch
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
logger = logging.getLogger(__name__)
class TensorDumper:
def __init__(
self, dump_dir: str, dump_layers: int, tp_size: int, tp_rank: int, pp_rank: int
):
self._dump_layers = dump_layers
self._forward_pass_id = 0
self._pid = os.getpid()
self._current_tensors = {}
self._base_dir = Path(dump_dir)
rank = tp_size * pp_rank + tp_rank
self._process_dir = (
self._base_dir / f"TP{tp_rank}_PP{pp_rank}_Rank{rank}_pid{self._pid}"
)
self._process_dir.mkdir(parents=True, exist_ok=True)
def get_dump_dir(self):
return str(self._process_dir)
def add_tensor(self, name, tensor_item):
if isinstance(tensor_item, (tuple, list)):
tensors = [t.cpu() for t in tensor_item if t is not None]
if len(tensors) == 1:
self._current_tensors[name] = tensors[0]
else:
self._current_tensors[name] = tensors
elif isinstance(tensor_item, torch.Tensor):
self._current_tensors[name] = tensor_item.cpu()
elif isinstance(tensor_item, LogitsProcessorOutput):
self._current_tensors[name] = tensor_item.next_token_logits.cpu()
elif isinstance(tensor_item, ForwardBatch):
self._current_tensors[name + ".forward_batch_info.input_ids"] = (
tensor_item.input_ids.cpu()
)
self._current_tensors[name + ".forward_batch_info.seq_lens"] = (
tensor_item.seq_lens.cpu()
)
self._current_tensors[name + ".forward_batch_info.positions"] = (
tensor_item.positions.cpu()
)
elif isinstance(tensor_item, PPProxyTensors):
for tensor_name in tensor_item.tensors.keys():
self._current_tensors[name + ".pp_proxy_tensors." + tensor_name] = (
tensor_item.tensors[tensor_name].cpu()
)
else:
logger.warning(f"Unsupported type: {type(tensor_item)}: {tensor_item}")
def dump_current_tensors(self):
if len(self._current_tensors) == 0:
return
tensor_file_for_pass = self._process_dir / f"Pass{self._forward_pass_id:05d}.pt"
logger.info(
f"Dump {self._forward_pass_id:05d}th pass to {tensor_file_for_pass}"
)
torch.save(self._current_tensors, str(tensor_file_for_pass))
self._current_tensors = {}
self._forward_pass_id += 1
def _add_hook_recursive(
self, model, prefix, top_level_module_name, layers_module_name
):
model_top_level_module_matched = False
layers_prefix = top_level_module_name + "." + layers_module_name
for name, module in model._modules.items():
top_level_model = False
if len(prefix) == 0:
cur_name = name
if cur_name == top_level_module_name:
model_top_level_module_matched = True
top_level_model = True
else:
cur_name = prefix + "." + name
if self._dump_layers > 0 and name.isdigit() and prefix == layers_prefix:
# If we only need the first n layers, skip the remaining ones.
# Most models' layout is like model.layers.0.
cur_layer = int(name)
if cur_layer >= self._dump_layers:
continue
if module is not None:
_, sub_count = self._add_hook_recursive(
module, cur_name, top_level_module_name, layers_module_name
)
if sub_count == 0 or top_level_model:
# Avoid duplicated output hooks, e.g. self_attn may contain:
# self_attn.qkv_proj, self_attn.attn & self_attn.o_proj.
# Therefore, we do not need to add output hooks for self_attn,
# since the output of self_attn should be the same as that of self_attn.o_proj.
module.register_forward_hook(
self._dump_hook(cur_name, top_level_model)
)
return model_top_level_module_matched, len(model._modules.items())
def _dump_hook(self, tensor_name, do_dump):
def inner_dump_hook(module, input, output):
if do_dump:
# This is the top-level model, so we will record the input for it.
for item in input:
if isinstance(item, ForwardBatch):
self.add_tensor(tensor_name, item)
self.dump_current_tensors()
if output is not None:
self.add_tensor(tensor_name, output)
return inner_dump_hook
def register_forward_hook_for_model(
model, dump_dir: str, dump_layers: int, tp_size: int, tp_rank: int, pp_rank: int
):
tensor_dumper = TensorDumper(dump_dir, dump_layers, tp_size, tp_rank, pp_rank)
# Most models have a layout like:
# XxxxForCausalLM
# (model): XxxxModel
# (layers): ModuleList
# If the model is not constructed with this layout,
# environment variables can be used to specify the module names.
top_level_module_name = os.getenv("TENSOR_DUMP_TOP_LEVEL_MODULE_NAME", "model")
layers_module_name = os.getenv("TENSOR_DUMP_LAYERS_MODULE_NAME", "layers")
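# Illustrative example (hypothetical layout): for a model organized as
# XxxxForCausalLM -> (transformer) -> (h): ModuleList, one would set
# TENSOR_DUMP_TOP_LEVEL_MODULE_NAME=transformer and TENSOR_DUMP_LAYERS_MODULE_NAME=h.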
model_top_level_module_matched, _ = tensor_dumper._add_hook_recursive(
model, "", top_level_module_name, layers_module_name
)
assert (
model_top_level_module_matched
), f"model should have a module named {top_level_module_name}"
return tensor_dumper
@@ -38,7 +38,6 @@ from sglang.srt.layers.dp_attention import (
     get_dp_device,
     get_dp_dtype,
     get_dp_hidden_size,
-    get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import (
@@ -47,7 +46,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
+from sglang.srt.utils import is_npu, use_intel_amx_backend
 logger = logging.getLogger(__name__)
@@ -252,10 +251,6 @@ class LogitsProcessor(nn.Module):
         ):
             self.final_logit_softcapping = None
-        self.debug_tensor_dump_output_folder = (
-            get_global_server_args().debug_tensor_dump_output_folder
-        )
     def compute_logprobs_for_multi_item_scoring(
         self,
         input_ids,
@@ -463,14 +458,6 @@ class LogitsProcessor(nn.Module):
             logits[sample_indices] if sample_indices is not None else logits
         )
-        if self.debug_tensor_dump_output_folder:
-            assert (
-                not self.do_tensor_parallel_all_gather
-                or get_local_attention_dp_size() == 1
-            ), "dp attention + sharded lm_head doesn't support full logits"
-            full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
-            dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
         hidden_states_to_store: Optional[torch.Tensor] = None
         if logits_metadata.capture_hidden_mode.need_capture():
             if logits_metadata.capture_hidden_mode.is_full():
...
@@ -40,6 +40,9 @@ from sglang.srt.configs.model_config import (
 )
 from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
+from sglang.srt.debug_utils.tensor_dump_forward_hook import (
+    register_forward_hook_for_model,
+)
 from sglang.srt.distributed import (
     get_pp_group,
     get_tp_group,
@@ -791,6 +794,15 @@ class ModelRunner:
             f"avail mem={after_avail_memory:.2f} GB, "
             f"mem usage={self.weight_load_mem_usage:.2f} GB."
         )
+        if self.server_args.debug_tensor_dump_output_folder is not None:
+            register_forward_hook_for_model(
+                self.model,
+                self.server_args.debug_tensor_dump_output_folder,
+                self.server_args.debug_tensor_dump_layers,
+                self.tp_size,
+                self.tp_rank,
+                self.pp_rank,
+            )
         if self.server_args.elastic_ep_backend == "mooncake":
             # Mooncake does not support `monitored_barrier`
...
@@ -511,6 +511,9 @@ class ServerArgs:
     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None
+    # -1 means dump all layers.
+    debug_tensor_dump_layers: int = -1
+    # TODO(guoyuhong): clean the old dumper code.
     debug_tensor_dump_input_file: Optional[str] = None
     debug_tensor_dump_inject: bool = False
@@ -1784,7 +1787,13 @@ class ServerArgs:
         )
     def _handle_other_validations(self):
-        pass
+        # Handle model inference tensor dump.
+        if self.debug_tensor_dump_output_folder is not None:
+            logger.warning(
+                "CUDA graph and server warmup are disabled because tensor dump mode is enabled."
+            )
+            self.disable_cuda_graph = True
+            self.skip_server_warmup = True
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -3375,6 +3384,12 @@ class ServerArgs:
             default=ServerArgs.debug_tensor_dump_output_folder,
             help="The output folder for dumping tensors.",
         )
+        parser.add_argument(
+            "--debug-tensor-dump-layers",
+            type=int,
+            default=-1,
+            help="The number of layers to dump tensors for; -1 means dump all layers.",
+        )
         parser.add_argument(
             "--debug-tensor-dump-input-file",
             type=str,
...
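For reference, a minimal offline-engine sketch combining the two server arguments above (an illustration only: it assumes sglang.Engine forwards these keyword arguments to ServerArgs, and the model path and dump folder are placeholders):

import sglang as sgl

# Hypothetical run: dump tensors for the first two layers of every forward pass.
llm = sgl.Engine(
    model_path="<your-model-path>",
    debug_tensor_dump_output_folder="/tmp/sglang_dump",
    debug_tensor_dump_layers=2,
)
llm.generate("Hello", {"max_new_tokens": 8})
llm.shutdown()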
import torch
from torch import nn
from sglang.srt.debug_utils.tensor_dump_forward_hook import (
register_forward_hook_for_model,
)
from sglang.srt.distributed.parallel_state import (
init_distributed_environment,
initialize_model_parallel,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import LinearBase
from sglang.srt.models.qwen2 import Qwen2MLP
from sglang.srt.utils import add_prefix
TEST_HIDDEN_SIZE = 32
class SimpleModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.hidden_size = TEST_HIDDEN_SIZE
self.rms_norm_eps = 1e-5
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
intermediate_size=self.hidden_size,
hidden_act="silu",
quant_config=None,
prefix=add_prefix("mlp", ""),
)
self.layernorm = RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
@torch.no_grad()
def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
return hidden_states
class MockCausalLM(nn.Module):
def __init__(self) -> None:
super().__init__()
self.model = SimpleModel()
@torch.no_grad()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.model(hidden_states)
def init_weights(module):
if isinstance(module, LinearBase):
torch.nn.init.uniform_(module.weight)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, RMSNorm):
torch.nn.init.ones_(module.weight)
def test_model_forward_dump(tmp_path):
init_distributed_environment(
backend="nccl",
world_size=1,
rank=0,
local_rank=0,
distributed_init_method="tcp://127.0.0.1:2646",
)
initialize_model_parallel()
model = MockCausalLM()
model.apply(init_weights)
model = model.cuda().bfloat16()
dumper = register_forward_hook_for_model(
model, tmp_path / "sglang_dump", -1, 0, 0, 0
)
dir_path = dumper.get_dump_dir()
inp = torch.randn(4, TEST_HIDDEN_SIZE, dtype=torch.bfloat16) * 0.01
result = model(inp.cuda())
data = torch.load(f"{dir_path}/Pass00000.pt")
assert "model.layernorm" in data
assert "model.mlp.down_proj" in data
assert torch.allclose(
data["model.mlp.down_proj"], result.cpu(), rtol=1e-5, atol=1e-5
)
if __name__ == "__main__":
    import sys

    import pytest

    sys.exit(pytest.main([__file__]))
@@ -14,6 +14,7 @@ class TestFile:
 # NOTE: please sort the test cases alphabetically by the test file name
 suites = {
     "per-commit-1-gpu": [
+        TestFile("debug_utils/test_tensor_dump_forward_hook.py", 15),
         TestFile("function_call/test_json_schema_constraint.py", 30),
         TestFile("hicache/test_hicache.py", 116),
         TestFile("hicache/test_hicache_eagle.py", 150),
...