"backend/vscode:/vscode.git/clone" did not exist on "ae97a96379ed85d14e2b70cbd5ede233f1178534"
dump_loader.py 2.57 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import functools
import os
from pathlib import Path
from typing import Any, Dict

import polars as pl
import torch


class DumpLoader:
    """Load previously saved ``.pt`` dumps from a directory given by an env var.

    The loader is enabled only when ``SGLANG_DUMP_LOADER_DIR`` is set. When it
    is, the directory's file names are indexed once (via ``read_meta``) so that
    ``load`` can look a dump up by its name and extra key/value conditions.
    """

    def __init__(self):
        dump_dir = os.environ.get("SGLANG_DUMP_LOADER_DIR")

        if dump_dir is None:
            # Disabled: ``_directory`` and ``_df`` are intentionally left unset.
            self._enable = False
            return

        self._enable = True
        self._directory = Path(dump_dir)
        self._df = read_meta(dump_dir)

    @property
    def enable(self):
        """Whether a dump directory was configured when the loader was built."""
        return self._enable

    def load(self, name, **kwargs):
        """Return the object stored in the dump matching *name* and *kwargs*.

        The query is always narrowed by the dumper's current forward pass id.

        Args:
            name: Value the dump's ``name`` field must equal.
            **kwargs: Additional ``column=value`` conditions for the lookup.

        Returns:
            Whatever ``torch.load`` deserializes from the matching file.

        Raises:
            AssertionError: If the loader is disabled or no dump matches.
        """
        assert self._enable, "Please call DumpLoader.load only when it is enabled"

        # Imported lazily, inside the method, exactly as the lookup needs it.
        from sglang.srt.debug_utils.dumper import dumper

        # ``dict(...)`` (not a literal) so a clashing ``forward_pass_id`` kwarg
        # raises instead of silently overriding the dumper's value.
        query = dict(name=name, forward_pass_id=dumper._forward_pass_id, **kwargs)
        match = find_row(self._df, conditions=query)
        assert (
            match is not None
        ), f"DumpLoader cannot find row given query {name=} {kwargs=} {self._directory=}"

        path = self._directory / match["filename"]
        # NOTE: weights_only=False allows arbitrary unpickling; only point
        # SGLANG_DUMP_LOADER_DIR at directories whose contents are trusted.
        output = torch.load(path, weights_only=False)

        print(
            f"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(output)})"
        )
        return output


def read_meta(directory):
    """Index every ``*.pt`` file in *directory* by the fields in its file name.

    Each file stem is expected to consist of ``key=value`` segments joined by
    ``___``. Every segment becomes a column of the returned frame and the file
    name itself is stored under ``filename``. The ``forward_pass_id``, ``rank``
    and ``dump_index`` columns are cast to integers.

    Args:
        directory: Path (``str`` or ``Path``) of the directory to scan.

    Returns:
        A ``pl.DataFrame`` with one row per ``*.pt`` file found.

    Raises:
        AssertionError: If *directory* is not an existing directory.
        ValueError: If a ``___``-separated segment contains no ``=``.
    """
    directory = Path(directory)
    assert directory.is_dir(), f"{directory=} should be a directory"

    rows = []
    for p in directory.glob("*.pt"):
        # Split on the first "=" only, so a value that itself contains "="
        # stays intact instead of breaking the two-element unpacking.
        full_kwargs = {
            k: v for k, v in (kv.split("=", 1) for kv in p.stem.split("___"))
        }
        rows.append({"filename": p.name, **full_kwargs})

    # NOTE(review): with no matching files the frame has no columns, so the
    # casts below presumably fail on the missing columns - confirm whether an
    # empty dump directory should be supported.
    df = pl.DataFrame(rows)
    df = df.with_columns(
        pl.col("forward_pass_id").cast(int),
        pl.col("rank").cast(int),
        pl.col("dump_index").cast(int),
    )
    return df


def find_row(df, conditions: Dict[str, Any]):
    """Return the single row of *df* satisfying every entry of *conditions*.

    Args:
        df: Frame to search; each key of *conditions* must be one of its columns.
        conditions: Mapping of column name to the value that column must equal.
            Values are converted to match the column's dtype before comparing.

    Returns:
        The matching row as a ``dict``, or ``None`` when nothing matches.

    Raises:
        AssertionError: If more than one row matches.
    """
    clauses = []
    for column, wanted in conditions.items():
        comparable = _cast_to_polars_dtype(wanted, df.schema[column])
        clauses.append(pl.col(column) == comparable)

    # AND all per-column equality clauses into one filter expression.
    combined = functools.reduce(lambda lhs, rhs: lhs & rhs, clauses)
    matches = df.filter(combined)

    assert len(matches) <= 1
    if len(matches) > 0:
        return matches.to_dicts()[0]
    return None


def _cast_to_polars_dtype(value, target_dtype):
    """Convert *value* to the Python type that corresponds to *target_dtype*.

    Handles the 32/64-bit signed and unsigned integer dtypes, the 32/64-bit
    float dtypes, ``pl.Boolean`` and ``pl.String``; for any other dtype the
    value is returned unchanged. The boolean case uses plain ``bool(value)``,
    i.e. Python truthiness.
    """
    if target_dtype in (pl.Int64, pl.Int32, pl.UInt64, pl.UInt32):
        convert = int
    elif target_dtype in (pl.Float64, pl.Float32):
        convert = float
    elif target_dtype == pl.Boolean:
        convert = bool
    elif target_dtype == pl.String:
        convert = str
    else:
        # Unrecognized dtype: hand the value back untouched.
        return value

    return convert(value)


# Module-level shared instance, created when this module is imported;
# DumpLoader.__init__ reads SGLANG_DUMP_LOADER_DIR at that moment to decide
# whether the loader is enabled.
dump_loader = DumpLoader()