Unverified Commit 4d84f886 authored by Yuhong Guo's avatar Yuhong Guo Committed by GitHub
Browse files

Refactor `--debug-tensor-dump-layers` to list (#12691)

parent dc4f5418
......@@ -13,6 +13,7 @@ The file contains a series of key-value pairs, where the keys correspond to oper
import logging
import os
from pathlib import Path
from typing import List, Optional
import torch
......@@ -24,7 +25,12 @@ logger = logging.getLogger(__name__)
class TensorDumper:
def __init__(
self, dump_dir: str, dump_layers: int, tp_size: int, tp_rank: int, pp_rank: int
self,
dump_dir: str,
dump_layers: Optional[List[int]],
tp_size: int,
tp_rank: int,
pp_rank: int,
):
self._dump_layers = dump_layers
self._forward_pass_id = 0
......@@ -94,11 +100,15 @@ class TensorDumper:
top_level_model = True
else:
cur_name = prefix + "." + name
if self._dump_layers > 0 and name.isdigit() and prefix == layers_prefix:
if (
self._dump_layers is not None
and name.isdigit()
and prefix == layers_prefix
):
            # If we only need specific layers, skip the rest of the layers.
# Most models' layout is like model.layers.0.
cur_layer = int(name)
if cur_layer >= self._dump_layers:
if cur_layer not in self._dump_layers:
continue
if module is not None:
_, sub_count = self._add_hook_recursive(
......@@ -129,7 +139,12 @@ class TensorDumper:
def register_forward_hook_for_model(
model, dump_dir: str, dump_layers: int, tp_size: int, tp_rank: int, pp_rank: int
model,
dump_dir: str,
dump_layers: Optional[List[int]],
tp_size: int,
tp_rank: int,
pp_rank: int,
):
tensor_dumper = TensorDumper(dump_dir, dump_layers, tp_size, tp_rank, pp_rank)
    # Most models have the layout like:
......
......@@ -519,8 +519,8 @@ class ServerArgs:
# Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None
    # -1 means dump all layers.
debug_tensor_dump_layers: int = -1
# None means dump all layers.
debug_tensor_dump_layers: Optional[List[int]] = None
# TODO(guoyuhong): clean the old dumper code.
debug_tensor_dump_input_file: Optional[str] = None
debug_tensor_dump_inject: bool = False
......@@ -3424,8 +3424,8 @@ class ServerArgs:
parser.add_argument(
"--debug-tensor-dump-layers",
type=int,
default=-1,
help="The layer number for dumping tensors.",
nargs="+",
help="The layer ids to dump. Dump all layers if not specified.",
)
parser.add_argument(
"--debug-tensor-dump-input-file",
......
......@@ -13,6 +13,7 @@ from sglang.srt.distributed.parallel_state import (
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import LinearBase
from sglang.srt.models.qwen2 import Qwen2MLP
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
from sglang.srt.utils import add_prefix
TEST_HIDDEN_SIZE = 32
......@@ -63,6 +64,7 @@ def init_weights(module):
def test_model_forward_dump(tmp_path):
set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
init_distributed_environment(
backend="nccl",
world_size=1,
......@@ -75,7 +77,7 @@ def test_model_forward_dump(tmp_path):
model.apply(init_weights)
model = model.cuda().bfloat16()
dumper = register_forward_hook_for_model(
model, tmp_path / "sglang_dump", -1, 0, 0, 0
model, tmp_path / "sglang_dump", [0], 0, 0, 0
)
dir_path = dumper.get_dump_dir()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment