Unverified Commit 9f00ec44 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Fix and enhance dumper (#8725)

parent 8e85ee88
import argparse import argparse
import functools import functools
import re
from pathlib import Path from pathlib import Path
import polars as pl import polars as pl
import torch import torch
from sglang.srt.debug_utils.dump_loader import find_row, read_meta
from sglang.srt.debug_utils.dumper import get_truncated_value from sglang.srt.debug_utils.dumper import get_truncated_value
...@@ -26,66 +26,77 @@ def main(args): ...@@ -26,66 +26,77 @@ def main(args):
print("df_baseline", df_baseline) print("df_baseline", df_baseline)
for row in df_target.iter_rows(named=True): for row in df_target.iter_rows(named=True):
rows_baseline = df_baseline.filter( path_target = Path(args.target_path) / row["filename"]
(
pl.col("forward_pass_id") row_baseline = find_row(
== row["forward_pass_id"] - args.start_id + args.baseline_start_id df_baseline,
) conditions=dict(
& functools.reduce( forward_pass_id=row["forward_pass_id"]
lambda a, b: a & b, - args.start_id
[ + args.baseline_start_id,
pl.col(col) == row[col] **{
for col in row.keys() k: v
if col not in ["forward_pass_id", "dump_index", "filename"] for k, v in row.items()
], if k not in ["forward_pass_id", "dump_index", "filename"]
) },
),
) )
assert len(rows_baseline) == 1, f"{rows_baseline=}"
row_baseline = rows_baseline.to_dicts()[0] if row_baseline is None:
print(f"Skip: target={str(path_target)} since no baseline")
x_target = _load_object(path_target)
if x_target is not None:
print(f"x_target(sample)={get_truncated_value(x_target)}")
continue
path_baseline = Path(args.baseline_path) / row_baseline["filename"] path_baseline = Path(args.baseline_path) / row_baseline["filename"]
path_target = Path(args.target_path) / row["filename"]
print(f"Check: target={str(path_target)} baseline={str(path_baseline)}") print(f"Check: target={str(path_target)} baseline={str(path_baseline)}")
check_tensor_pair(path_baseline=path_baseline, path_target=path_target) check_tensor_pair(
path_baseline=path_baseline, path_target=path_target, name=row["name"]
)
print() print()
def read_meta(directory): def check_tensor_pair(path_baseline, path_target, name=""):
directory = Path(directory) x_baseline = _load_object(path_baseline)
assert directory.is_dir(), f"{directory=} should be a directory" x_target = _load_object(path_target)
rows = []
for p in directory.glob("*.pt"):
full_kwargs = {}
for kv in p.stem.split("___"):
k, v = kv.split("=")
full_kwargs[k] = v
rows.append(
{
"filename": str(p.name),
**full_kwargs,
}
)
df = pl.DataFrame(rows) print(
df = df.with_columns( f"Raw "
pl.col("forward_pass_id").cast(int), f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
pl.col("rank").cast(int), f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
) )
return df
x_baseline, x_target = _comparison_preprocessor(x_baseline, x_target, name=name)
def check_tensor_pair(path_baseline, path_target): x_baseline = _try_unify_shape(x_baseline, target_shape=x_target.shape)
x_baseline = torch.load(path_baseline, weights_only=True)
x_target = torch.load(path_target, weights_only=True)
print( print(
f"After preprocessor "
f"[shape] {x_baseline.shape} vs {x_target.shape}\t" f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
f"[dtype] {x_baseline.dtype} vs {x_target.dtype}" f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
) )
x_target = x_target.float()
x_baseline = x_baseline.float()
for name, fn in (
("mean", torch.mean),
("std", torch.std),
("min", torch.min),
("max", torch.max),
("p1", functools.partial(torch.quantile, q=0.01)),
("p5", functools.partial(torch.quantile, q=0.05)),
("p95", functools.partial(torch.quantile, q=0.95)),
("p99", functools.partial(torch.quantile, q=0.99)),
):
value_baseline = fn(x_baseline).item()
value_target = fn(x_target).item()
print(
f"[{name}] {value_baseline :.4f} vs {value_target:.4f} (diff: {value_target - value_baseline:.4f})"
)
if x_baseline.shape != x_target.shape: if x_baseline.shape != x_target.shape:
print(f" Shape mismatch") print(f"⚠️ Shape mismatch")
return return
raw_abs_diff = (x_target - x_baseline).abs() raw_abs_diff = (x_target - x_baseline).abs()
...@@ -112,6 +123,19 @@ def check_tensor_pair(path_baseline, path_target): ...@@ -112,6 +123,19 @@ def check_tensor_pair(path_baseline, path_target):
print(f"x_target(sample)={get_truncated_value(x_target)}") print(f"x_target(sample)={get_truncated_value(x_target)}")
def _try_unify_shape(x: torch.Tensor, target_shape):
x_shape = x.shape
num_dim_to_remove = len(x_shape) - len(target_shape)
if (x_shape[num_dim_to_remove:] == target_shape) and all(
val == 1 for val in x_shape[:num_dim_to_remove]
):
out = functools.reduce(lambda a, _: a.squeeze(0), range(num_dim_to_remove), x)
print(f"Unify shape: {x_shape} -> {out.shape} (to match {target_shape})")
return out
return x
# Copied from DeepGEMM # Copied from DeepGEMM
def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor): def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
x, y = x.double(), y.double() x, y = x.double(), y.double()
...@@ -120,6 +144,19 @@ def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor): ...@@ -120,6 +144,19 @@ def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
return 1 - sim return 1 - sim
def _comparison_preprocessor(x_baseline, x_target, name):
# can insert arbitrary adhoc postprocessing logic here
return x_baseline, x_target
def _load_object(path):
    """Load a dump file onto the GPU; return None when it is not a tensor.

    NOTE(review): ``weights_only=False`` un-pickles arbitrary objects —
    only use on dumps produced by a trusted dumper run.
    """
    x = torch.load(path, weights_only=False)
    if isinstance(x, torch.Tensor):
        return x.cuda()
    print(f"Skip load {path} since {type(x)=} is not a Tensor")
    return None
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--baseline-path", type=str) parser.add_argument("--baseline-path", type=str)
......
import functools
import os
from pathlib import Path
from typing import Any, Dict
import polars as pl
import torch
class DumpLoader:
    """Looks up and loads objects previously written by the dumper.

    Enabled only when the ``SGLANG_DUMP_LOADER_DIR`` environment variable
    points at a directory of ``*.pt`` dumps; otherwise it stays inert.
    """

    def __init__(self):
        directory = os.environ.get("SGLANG_DUMP_LOADER_DIR")
        self._enable = directory is not None
        if not self._enable:
            return
        self._directory = Path(directory)
        # Metadata index parsed from the dump filenames.
        self._df = read_meta(directory)

    @property
    def enable(self):
        # True iff SGLANG_DUMP_LOADER_DIR was set at construction time.
        return self._enable

    def load(self, name, **kwargs):
        """Load the dump matching ``name``, the current forward pass, and ``kwargs``."""
        assert self._enable, "Please call DumpLoader.load only when it is enabled"

        # Imported lazily to avoid a circular import with the dumper module.
        from sglang.srt.debug_utils.dumper import dumper

        query = dict(
            name=name, forward_pass_id=dumper._forward_pass_id, **kwargs
        )
        row = find_row(self._df, conditions=query)
        assert (
            row is not None
        ), f"DumpLoader cannot find row given query {name=} {kwargs=} {self._directory=}"

        path = self._directory / row["filename"]
        loaded = torch.load(path, weights_only=False)
        print(
            f"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(loaded)})"
        )
        return loaded
def read_meta(directory):
    """Build a metadata DataFrame from the dump filenames in ``directory``.

    Each dump is named ``k1=v1___k2=v2___....pt``; every key becomes a
    string column, plus a ``filename`` column. The ``forward_pass_id``,
    ``rank`` and ``dump_index`` columns are cast to integers so callers
    can do arithmetic/equality on them.
    """
    directory = Path(directory)
    assert directory.is_dir(), f"{directory=} should be a directory"

    rows = []
    for p in directory.glob("*.pt"):
        full_kwargs = {}
        for kv in p.stem.split("___"):
            # Split on the first "=" only, so values that themselves
            # contain "=" do not raise an unpack error.
            k, v = kv.split("=", 1)
            full_kwargs[k] = v
        rows.append(
            {
                "filename": str(p.name),
                **full_kwargs,
            }
        )

    df = pl.DataFrame(rows)
    df = df.with_columns(
        pl.col("forward_pass_id").cast(int),
        pl.col("rank").cast(int),
        pl.col("dump_index").cast(int),
    )
    return df
def find_row(df, conditions: Dict[str, Any]):
    """Return the single row of ``df`` matching all ``conditions``, or None.

    Condition values are coerced to each column's dtype before comparison.
    Asserts that at most one row matches.
    """
    exprs = [
        pl.col(key) == _cast_to_polars_dtype(value, df.schema[key])
        for key, value in conditions.items()
    ]
    matches = df.filter(functools.reduce(lambda acc, e: acc & e, exprs))
    assert len(matches) <= 1
    return matches.to_dicts()[0] if len(matches) > 0 else None
def _cast_to_polars_dtype(value, target_dtype):
    """Coerce ``value`` to the Python type matching a polars column dtype.

    Unknown dtypes pass the value through unchanged.
    """
    if target_dtype == pl.Boolean:
        return bool(value)
    if target_dtype == pl.String:
        return str(value)
    if target_dtype in (pl.Float64, pl.Float32):
        return float(value)
    if target_dtype in (pl.Int64, pl.Int32, pl.UInt64, pl.UInt32):
        return int(value)
    return value
dump_loader = DumpLoader()
...@@ -53,7 +53,7 @@ class _Dumper: ...@@ -53,7 +53,7 @@ class _Dumper:
if self._partial_name is None: if self._partial_name is None:
self._partial_name = _get_partial_name() self._partial_name = _get_partial_name()
rank = dist.get_rank() rank = _get_rank()
full_kwargs = dict( full_kwargs = dict(
forward_pass_id=self._forward_pass_id, forward_pass_id=self._forward_pass_id,
rank=rank, rank=rank,
...@@ -80,12 +80,20 @@ class _Dumper: ...@@ -80,12 +80,20 @@ class _Dumper:
def _get_partial_name(): def _get_partial_name():
rank = dist.get_rank() rank = _get_rank()
object_list = [str(time.time()) if rank == 0 else None] object_list = [str(time.time()) if rank == 0 else None]
if dist.is_initialized():
dist.broadcast_object_list(object_list, device="cuda") dist.broadcast_object_list(object_list, device="cuda")
return object_list[0] return object_list[0]
def _get_rank():
if dist.is_initialized():
return dist.get_rank()
else:
return 0
def get_truncated_value(value): def get_truncated_value(value):
if value is None: if value is None:
return None return None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment