Unverified Commit ed0fdbf3 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Tool to dump and compare internal activation tensors (#7976)

parent b602f423
import argparse
import functools
import re
from pathlib import Path
import polars as pl
import torch
from sglang.srt.debug_utils.dumper import get_truncated_value
def main(args):
    """Compare dumped activation tensors under `args.target_path` against
    those under `args.baseline_path`.

    For every dump row in the target directory whose forward_pass_id lies in
    [args.start_id, args.end_id], find the single matching baseline row
    (forward_pass_id shifted by `args.baseline_start_id - args.start_id`,
    all other metadata columns equal) and compare the two tensor files.
    """
    df_target = read_meta(args.target_path)
    df_target = df_target.sort("rank", "dump_index")
    # Restrict target rows to the requested forward-pass window (inclusive).
    df_target = df_target.filter(
        (pl.col("forward_pass_id") >= args.start_id)
        & (pl.col("forward_pass_id") <= args.end_id)
    )
    assert all(
        c in df_target.columns
        for c in ["rank", "forward_pass_id", "dump_index", "name"]
    )

    df_baseline = read_meta(args.baseline_path)
    print("df_target", df_target)
    print("df_baseline", df_baseline)

    for row in df_target.iter_rows(named=True):
        # Match on the shifted forward_pass_id plus every metadata column
        # except the ones that legitimately differ between the two runs.
        # `filter(*predicates)` ANDs the predicates natively, replacing the
        # manual functools.reduce(lambda a, b: a & b, ...) chain.
        conditions = [
            pl.col("forward_pass_id")
            == row["forward_pass_id"] - args.start_id + args.baseline_start_id,
        ]
        conditions += [
            pl.col(col) == row[col]
            for col in row.keys()
            if col not in ["forward_pass_id", "dump_index", "filename"]
        ]
        rows_baseline = df_baseline.filter(*conditions)
        assert len(rows_baseline) == 1, f"{rows_baseline=}"
        row_baseline = rows_baseline.to_dicts()[0]

        path_baseline = Path(args.baseline_path) / row_baseline["filename"]
        path_target = Path(args.target_path) / row["filename"]
        print(f"Check: target={str(path_target)} baseline={str(path_baseline)}")

        check_tensor_pair(path_baseline=path_baseline, path_target=path_target)
        print()
def read_meta(directory):
    """Collect dump metadata from the filenames of all `*.pt` files in `directory`.

    Each dump filename encodes key/value metadata as
    `key1=value1___key2=value2___....pt` (see the dumper that writes them).

    Returns a polars DataFrame with one row per dump file: a `filename`
    column plus one (string) column per encoded key, with `forward_pass_id`
    and `rank` cast to integers.
    """
    directory = Path(directory)
    assert directory.is_dir(), f"{directory=} should be a directory"

    rows = []
    for p in directory.glob("*.pt"):
        full_kwargs = {}
        for kv in p.stem.split("___"):
            # maxsplit=1 so a value that itself contains "=" does not
            # raise "too many values to unpack".
            k, v = kv.split("=", 1)
            full_kwargs[k] = v
        rows.append(
            {
                "filename": str(p.name),
                **full_kwargs,
            }
        )

    df = pl.DataFrame(rows)
    df = df.with_columns(
        pl.col("forward_pass_id").cast(int),
        pl.col("rank").cast(int),
    )
    return df
def check_tensor_pair(path_baseline, path_target):
    """Load two dumped tensors and print shape/dtype plus difference metrics.

    Reports relative difference, max absolute difference and mean absolute
    difference, each flagged ❌ when above 1e-3. When any metric is flagged,
    also prints truncated samples of both tensors for manual inspection.
    """
    x_baseline = torch.load(path_baseline, weights_only=True)
    x_target = torch.load(path_target, weights_only=True)
    print(
        f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
        f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
    )

    if x_baseline.shape != x_target.shape:
        print("❌ Shape mismatch")
        return

    raw_abs_diff = (x_target - x_baseline).abs()

    max_abs_diff = raw_abs_diff.max().item()
    mean_abs_diff = raw_abs_diff.mean().item()
    rel_diff = _calc_rel_diff(x_target, x_baseline)

    # Print samples when *any* reported metric is flagged — previously only
    # max_abs_diff was checked, so a bad rel_diff would show ❌ without the
    # tensor samples that help diagnose it.
    needs_print = max_abs_diff > 1e-3 or rel_diff > 1e-3

    print(
        "\t".join(
            f"{'❌' if value > 1e-3 else '✅'} {name}={value}"
            for name, value in [
                ("rel_diff", rel_diff),
                ("max_abs_diff", max_abs_diff),
                ("mean_abs_diff", mean_abs_diff),
            ]
        )
    )

    if needs_print:
        print(f"x_baseline(sample)={get_truncated_value(x_baseline)}")
        print(f"x_target(sample)={get_truncated_value(x_target)}")
# Copied from DeepGEMM
def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Directories containing the dumped `*.pt` files for the two runs
    # being compared (written by sglang.srt.debug_utils.dumper).
    parser.add_argument("--baseline-path", type=str)
    parser.add_argument("--target-path", type=str)
    # Inclusive range of target forward-pass ids to compare.
    parser.add_argument("--start-id", type=int, default=0)
    parser.add_argument("--end-id", type=int, default=1000000)
    # Baseline forward-pass id corresponding to `--start-id` in the target,
    # so runs that started counting at different ids can be aligned.
    parser.add_argument("--baseline-start-id", type=int, default=0)
    args = parser.parse_args()
    main(args)
import os import os
import time import time
from pathlib import Path from pathlib import Path
from typing import Optional
import torch import torch
import torch.distributed as dist
from sglang.srt.utils import get_bool_env_var
class _Dumper: class _Dumper:
"""Utility to dump tensors, which can be useful when comparison checking models. """Utility to dump tensors, which can be useful when comparison checking models.
Example usage: Example usage:
debug_utils.dumper.dump("layer_start_hidden_states", hidden_states, layer_id=self.layer_id) dumper.on_forward_pass_start()
dumper.dump("layer_start__hidden_states", hidden_states, layer_id=self.layer_id)
Import from non-SGLang system:
```
import sys
sys.path.append("/YOUR_PATH/sglang/python/sglang/srt/debug_utils")
from dumper import dumper
```
Related: `sglang.srt.debug_utils.dump_comparator` for dump comparison
""" """
def __init__(self): def __init__(self):
self._enable = get_bool_env_var("SGLANG_DUMPER_ENABLE", "true") # Do not import `sglang` to make this file standalone
self._enable = bool(int(os.environ.get("SGLANG_DUMPER_ENABLE", "1")))
self._base_dir = Path(os.environ.get("SGLANG_DUMPER_DIR", "/tmp")) self._base_dir = Path(os.environ.get("SGLANG_DUMPER_DIR", "/tmp"))
self._enable_write_file = get_bool_env_var("SGLANG_DUMPER_WRITE_FILE", "1") self._enable_write_file = bool(
self._partial_name = str(time.time()) int(os.environ.get("SGLANG_DUMPER_WRITE_FILE", "1"))
self.forward_pass_id = None )
self._partial_name: Optional[str] = None
self._dump_index = 0
self._forward_pass_id = 0
def on_forward_pass_start(self):
self._forward_pass_id += 1
print(
f"[Dumper] [{time.time()}] on_forward_pass_start id={self._forward_pass_id}"
)
def dump(self, name, value, **kwargs): def dump(self, name, value, **kwargs):
if not self._enable: if not self._enable:
return return
from sglang.srt.distributed import get_tensor_model_parallel_rank assert (
self._forward_pass_id >= 1
), "Do you forget to call `dumper.on_forward_pass_start()`?"
self._dump_index += 1
if self._partial_name is None:
self._partial_name = _get_partial_name()
rank = get_tensor_model_parallel_rank() rank = dist.get_rank()
full_kwargs = dict( full_kwargs = dict(
forward_pass_id=self.forward_pass_id, forward_pass_id=self._forward_pass_id,
rank=rank,
name=name, name=name,
dump_index=self._dump_index,
**kwargs, **kwargs,
) )
full_filename = "___".join(f"{k}={v}" for k, v in full_kwargs.items()) + ".pt" full_filename = "___".join(f"{k}={v}" for k, v in full_kwargs.items()) + ".pt"
path = ( path = self._base_dir / f"sglang_dump_{self._partial_name}" / full_filename
self._base_dir / f"sglang_dump_{self._partial_name}_{rank}" / full_filename
)
sample_value = self._get_sample_value(name, value) sample_value = get_truncated_value(value)
print( print(
f"[{rank}, {time.time()}] {path} " f"[Dumper] [{rank}, {time.time()}] {path} "
f"type={type(value)} " f"type={type(value)} "
f"shape={value.shape if isinstance(value, torch.Tensor) else None} " f"shape={value.shape if isinstance(value, torch.Tensor) else None} "
f"dtype={value.dtype if isinstance(value, torch.Tensor) else None} " f"dtype={value.dtype if isinstance(value, torch.Tensor) else None} "
...@@ -52,23 +78,31 @@ class _Dumper: ...@@ -52,23 +78,31 @@ class _Dumper:
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
torch.save(value, str(path)) torch.save(value, str(path))
def _get_sample_value(self, name, value):
if value is None:
return None
if isinstance(value, tuple): def _get_partial_name():
return [self._get_sample_value(name, x) for x in value] rank = dist.get_rank()
object_list = [str(time.time()) if rank == 0 else None]
dist.broadcast_object_list(object_list, device="cuda")
return object_list[0]
def get_truncated_value(value):
if value is None:
return None
if isinstance(value, tuple):
return [get_truncated_value(x) for x in value]
if not isinstance(value, torch.Tensor): if not isinstance(value, torch.Tensor):
return None return None
if value.numel() < 200: if value.numel() < 200:
return value return value
slices = [ slices = [
slice(0, 5) if dim_size > 200 else slice(None) for dim_size in value.shape slice(0, 5) if dim_size > 200 else slice(None) for dim_size in value.shape
] ]
return value[tuple(slices)] return value[tuple(slices)]
dumper = _Dumper() dumper = _Dumper()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment