Unverified commit 9f00ec44, authored by fzyzcjy and committed by GitHub
Browse files

Fix and enhance dumper (#8725)

parent 8e85ee88
import argparse
import functools
import re
from pathlib import Path
import polars as pl
import torch
from sglang.srt.debug_utils.dump_loader import find_row, read_meta
from sglang.srt.debug_utils.dumper import get_truncated_value
......@@ -26,66 +26,77 @@ def main(args):
print("df_baseline", df_baseline)
for row in df_target.iter_rows(named=True):
rows_baseline = df_baseline.filter(
(
pl.col("forward_pass_id")
== row["forward_pass_id"] - args.start_id + args.baseline_start_id
)
& functools.reduce(
lambda a, b: a & b,
[
pl.col(col) == row[col]
for col in row.keys()
if col not in ["forward_pass_id", "dump_index", "filename"]
],
)
path_target = Path(args.target_path) / row["filename"]
row_baseline = find_row(
df_baseline,
conditions=dict(
forward_pass_id=row["forward_pass_id"]
- args.start_id
+ args.baseline_start_id,
**{
k: v
for k, v in row.items()
if k not in ["forward_pass_id", "dump_index", "filename"]
},
),
)
assert len(rows_baseline) == 1, f"{rows_baseline=}"
row_baseline = rows_baseline.to_dicts()[0]
if row_baseline is None:
print(f"Skip: target={str(path_target)} since no baseline")
x_target = _load_object(path_target)
if x_target is not None:
print(f"x_target(sample)={get_truncated_value(x_target)}")
continue
path_baseline = Path(args.baseline_path) / row_baseline["filename"]
path_target = Path(args.target_path) / row["filename"]
print(f"Check: target={str(path_target)} baseline={str(path_baseline)}")
check_tensor_pair(path_baseline=path_baseline, path_target=path_target)
check_tensor_pair(
path_baseline=path_baseline, path_target=path_target, name=row["name"]
)
print()
def read_meta(directory):
directory = Path(directory)
assert directory.is_dir(), f"{directory=} should be a directory"
rows = []
for p in directory.glob("*.pt"):
full_kwargs = {}
for kv in p.stem.split("___"):
k, v = kv.split("=")
full_kwargs[k] = v
rows.append(
{
"filename": str(p.name),
**full_kwargs,
}
)
def check_tensor_pair(path_baseline, path_target, name=""):
x_baseline = _load_object(path_baseline)
x_target = _load_object(path_target)
df = pl.DataFrame(rows)
df = df.with_columns(
pl.col("forward_pass_id").cast(int),
pl.col("rank").cast(int),
print(
f"Raw "
f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
)
return df
def check_tensor_pair(path_baseline, path_target):
x_baseline = torch.load(path_baseline, weights_only=True)
x_target = torch.load(path_target, weights_only=True)
x_baseline, x_target = _comparison_preprocessor(x_baseline, x_target, name=name)
x_baseline = _try_unify_shape(x_baseline, target_shape=x_target.shape)
print(
f"After preprocessor "
f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
)
x_target = x_target.float()
x_baseline = x_baseline.float()
for name, fn in (
("mean", torch.mean),
("std", torch.std),
("min", torch.min),
("max", torch.max),
("p1", functools.partial(torch.quantile, q=0.01)),
("p5", functools.partial(torch.quantile, q=0.05)),
("p95", functools.partial(torch.quantile, q=0.95)),
("p99", functools.partial(torch.quantile, q=0.99)),
):
value_baseline = fn(x_baseline).item()
value_target = fn(x_target).item()
print(
f"[{name}] {value_baseline :.4f} vs {value_target:.4f} (diff: {value_target - value_baseline:.4f})"
)
if x_baseline.shape != x_target.shape:
print(f" Shape mismatch")
print(f"⚠️ Shape mismatch")
return
raw_abs_diff = (x_target - x_baseline).abs()
......@@ -112,6 +123,19 @@ def check_tensor_pair(path_baseline, path_target):
print(f"x_target(sample)={get_truncated_value(x_target)}")
def _try_unify_shape(x: torch.Tensor, target_shape):
x_shape = x.shape
num_dim_to_remove = len(x_shape) - len(target_shape)
if (x_shape[num_dim_to_remove:] == target_shape) and all(
val == 1 for val in x_shape[:num_dim_to_remove]
):
out = functools.reduce(lambda a, _: a.squeeze(0), range(num_dim_to_remove), x)
print(f"Unify shape: {x_shape} -> {out.shape} (to match {target_shape})")
return out
return x
# Copied from DeepGEMM
def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
x, y = x.double(), y.double()
......@@ -120,6 +144,19 @@ def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
return 1 - sim
def _comparison_preprocessor(x_baseline, x_target, name):
# can insert arbitrary adhoc postprocessing logic here
return x_baseline, x_target
def _load_object(path):
x = torch.load(path, weights_only=False)
if not isinstance(x, torch.Tensor):
print(f"Skip load {path} since {type(x)=} is not a Tensor")
return None
return x.cuda()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--baseline-path", type=str)
......
import functools
import os
from pathlib import Path
from typing import Any, Dict
import polars as pl
import torch
class DumpLoader:
    """Loads previously dumped objects, looked up by their dump metadata.

    Enabled only when the SGLANG_DUMP_LOADER_DIR environment variable points
    at a directory of dump files; otherwise `load` must not be called.
    """

    def __init__(self):
        dump_dir = os.environ.get("SGLANG_DUMP_LOADER_DIR")
        self._enable = dump_dir is not None
        if self._enable:
            self._directory = Path(dump_dir)
            self._df = read_meta(dump_dir)

    @property
    def enable(self):
        # True iff a dump directory was configured at construction time
        return self._enable

    def load(self, name, **kwargs):
        """Load the object dumped as `name` (+ extra metadata `kwargs`) for the current forward pass."""
        assert self._enable, "Please call DumpLoader.load only when it is enabled"

        # Imported lazily to avoid a circular import with the dumper module
        from sglang.srt.debug_utils.dumper import dumper

        forward_pass_id = dumper._forward_pass_id
        conditions = dict(name=name, forward_pass_id=forward_pass_id, **kwargs)
        row = find_row(self._df, conditions=conditions)
        assert (
            row is not None
        ), f"DumpLoader cannot find row given query {name=} {kwargs=} {self._directory=}"

        path = self._directory / row["filename"]
        output = torch.load(path, weights_only=False)
        print(
            f"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(output)})"
        )
        return output
def read_meta(directory):
    """Scan a dump directory and return a polars DataFrame of per-file metadata.

    Each *.pt file's stem encodes key=value pairs joined by '___'; these become
    columns alongside a 'filename' column. The id columns are cast to int so
    they can be compared numerically.
    """
    directory = Path(directory)
    assert directory.is_dir(), f"{directory=} should be a directory"

    rows = []
    for path in directory.glob("*.pt"):
        record = {"filename": str(path.name)}
        for pair in path.stem.split("___"):
            key, value = pair.split("=")
            record[key] = value
        rows.append(record)

    return pl.DataFrame(rows).with_columns(
        pl.col("forward_pass_id").cast(int),
        pl.col("rank").cast(int),
        pl.col("dump_index").cast(int),
    )
def find_row(df, conditions: Dict[str, Any]):
    """Return the unique row of `df` matching every condition, or None if absent.

    Condition values are coerced to the column's dtype before comparison.
    Asserts that at most one row matches.
    """
    predicate = functools.reduce(
        lambda acc, expr: acc & expr,
        [
            pl.col(key) == _cast_to_polars_dtype(value, df.schema[key])
            for key, value in conditions.items()
        ],
    )
    matches = df.filter(predicate)
    assert len(matches) <= 1
    if len(matches) == 0:
        return None
    return matches.to_dicts()[0]
def _cast_to_polars_dtype(value, target_dtype):
    """Coerce `value` to the Python type matching a polars column dtype.

    Values typically originate from filename parsing (strings) or caller
    kwargs; unknown dtypes pass the value through unchanged.
    NOTE(review): bool("False") is True — if boolean columns are ever fed
    string values, this branch needs a real parser; confirm against callers.
    """
    if target_dtype == pl.Boolean:
        return bool(value)
    if target_dtype == pl.String:
        return str(value)
    if target_dtype in (pl.Float64, pl.Float32):
        return float(value)
    if target_dtype in (pl.Int64, pl.Int32, pl.UInt64, pl.UInt32):
        return int(value)
    return value
# Module-level singleton; enabled only when SGLANG_DUMP_LOADER_DIR is set.
dump_loader = DumpLoader()
......@@ -53,7 +53,7 @@ class _Dumper:
if self._partial_name is None:
self._partial_name = _get_partial_name()
rank = dist.get_rank()
rank = _get_rank()
full_kwargs = dict(
forward_pass_id=self._forward_pass_id,
rank=rank,
......@@ -80,12 +80,20 @@ class _Dumper:
def _get_partial_name():
    """Generate a run-unique name component shared by all ranks.

    Rank 0 produces a timestamp string; when torch.distributed is initialized
    it is broadcast so every rank uses the same value, otherwise (single
    process) the local timestamp is used directly.

    Fix: the scraped diff left both the old unconditional `dist.get_rank()` /
    `broadcast_object_list` lines and the new guarded versions in place; this
    is the resolved post-commit body.
    """
    rank = _get_rank()
    object_list = [str(time.time()) if rank == 0 else None]
    if dist.is_initialized():
        dist.broadcast_object_list(object_list, device="cuda")
    return object_list[0]
def _get_rank():
if dist.is_initialized():
return dist.get_rank()
else:
return 0
def get_truncated_value(value):
if value is None:
return None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment