Unverified Commit 4d84f886 authored by Yuhong Guo's avatar Yuhong Guo Committed by GitHub
Browse files

Refactor `--debug-tensor-dump-layers` to list (#12691)

parent dc4f5418
...@@ -13,6 +13,7 @@ The file contains a series of key-value pairs, where the keys correspond to oper ...@@ -13,6 +13,7 @@ The file contains a series of key-value pairs, where the keys correspond to oper
import logging import logging
import os import os
from pathlib import Path from pathlib import Path
from typing import List, Optional
import torch import torch
...@@ -24,7 +25,12 @@ logger = logging.getLogger(__name__) ...@@ -24,7 +25,12 @@ logger = logging.getLogger(__name__)
class TensorDumper: class TensorDumper:
def __init__( def __init__(
self, dump_dir: str, dump_layers: int, tp_size: int, tp_rank: int, pp_rank: int self,
dump_dir: str,
dump_layers: Optional[List[int]],
tp_size: int,
tp_rank: int,
pp_rank: int,
): ):
self._dump_layers = dump_layers self._dump_layers = dump_layers
self._forward_pass_id = 0 self._forward_pass_id = 0
...@@ -94,11 +100,15 @@ class TensorDumper: ...@@ -94,11 +100,15 @@ class TensorDumper:
top_level_model = True top_level_model = True
else: else:
cur_name = prefix + "." + name cur_name = prefix + "." + name
if self._dump_layers > 0 and name.isdigit() and prefix == layers_prefix: if (
self._dump_layers is not None
and name.isdigit()
and prefix == layers_prefix
):
# If we only need n layers, skip the reset layers. # If we only need n layers, skip the reset layers.
# Most models' layout is like model.layers.0. # Most models' layout is like model.layers.0.
cur_layer = int(name) cur_layer = int(name)
if cur_layer >= self._dump_layers: if cur_layer not in self._dump_layers:
continue continue
if module is not None: if module is not None:
_, sub_count = self._add_hook_recursive( _, sub_count = self._add_hook_recursive(
...@@ -129,7 +139,12 @@ class TensorDumper: ...@@ -129,7 +139,12 @@ class TensorDumper:
def register_forward_hook_for_model( def register_forward_hook_for_model(
model, dump_dir: str, dump_layers: int, tp_size: int, tp_rank: int, pp_rank: int model,
dump_dir: str,
dump_layers: Optional[List[int]],
tp_size: int,
tp_rank: int,
pp_rank: int,
): ):
tensor_dumper = TensorDumper(dump_dir, dump_layers, tp_size, tp_rank, pp_rank) tensor_dumper = TensorDumper(dump_dir, dump_layers, tp_size, tp_rank, pp_rank)
# Most models have the layerout like: # Most models have the layerout like:
......
...@@ -519,8 +519,8 @@ class ServerArgs: ...@@ -519,8 +519,8 @@ class ServerArgs:
# Debug tensor dumps # Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None debug_tensor_dump_output_folder: Optional[str] = None
# -1 mean dump all layers. # None means dump all layers.
debug_tensor_dump_layers: int = -1 debug_tensor_dump_layers: Optional[List[int]] = None
# TODO(guoyuhong): clean the old dumper code. # TODO(guoyuhong): clean the old dumper code.
debug_tensor_dump_input_file: Optional[str] = None debug_tensor_dump_input_file: Optional[str] = None
debug_tensor_dump_inject: bool = False debug_tensor_dump_inject: bool = False
...@@ -3424,8 +3424,8 @@ class ServerArgs: ...@@ -3424,8 +3424,8 @@ class ServerArgs:
parser.add_argument( parser.add_argument(
"--debug-tensor-dump-layers", "--debug-tensor-dump-layers",
type=int, type=int,
default=-1, nargs="+",
help="The layer number for dumping tensors.", help="The layer ids to dump. Dump all layers if not specified.",
) )
parser.add_argument( parser.add_argument(
"--debug-tensor-dump-input-file", "--debug-tensor-dump-input-file",
......
...@@ -13,6 +13,7 @@ from sglang.srt.distributed.parallel_state import ( ...@@ -13,6 +13,7 @@ from sglang.srt.distributed.parallel_state import (
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.linear import LinearBase
from sglang.srt.models.qwen2 import Qwen2MLP from sglang.srt.models.qwen2 import Qwen2MLP
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
from sglang.srt.utils import add_prefix from sglang.srt.utils import add_prefix
TEST_HIDDEN_SIZE = 32 TEST_HIDDEN_SIZE = 32
...@@ -63,6 +64,7 @@ def init_weights(module): ...@@ -63,6 +64,7 @@ def init_weights(module):
def test_model_forward_dump(tmp_path): def test_model_forward_dump(tmp_path):
set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
init_distributed_environment( init_distributed_environment(
backend="nccl", backend="nccl",
world_size=1, world_size=1,
...@@ -75,7 +77,7 @@ def test_model_forward_dump(tmp_path): ...@@ -75,7 +77,7 @@ def test_model_forward_dump(tmp_path):
model.apply(init_weights) model.apply(init_weights)
model = model.cuda().bfloat16() model = model.cuda().bfloat16()
dumper = register_forward_hook_for_model( dumper = register_forward_hook_for_model(
model, tmp_path / "sglang_dump", -1, 0, 0, 0 model, tmp_path / "sglang_dump", [0], 0, 0, 0
) )
dir_path = dumper.get_dump_dir() dir_path = dumper.get_dump_dir()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment