# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import json
import math
import os
8
from contextlib import contextmanager
9
10
11
from typing import Any


def extract_field(
    args: argparse.Namespace, extra_info: dict[str, Any], field_name: str
) -> Any:
    """
    Look up a (possibly dotted) field, preferring ``extra_info`` over ``args``.

    ``field_name`` is first looked up verbatim as a key of ``extra_info``. If it
    is not there, the name is split on ``"."`` and resolved one component at a
    time starting from ``args``. At each step a key lookup is used when the
    current value is a ``dict`` that contains the component; otherwise an
    attribute lookup is used.

    Args:
        args: Parsed benchmark CLI arguments.
        extra_info: Extra benchmark metadata; its entries take precedence.
        field_name: Field to fetch, e.g. ``"compilation_config.mode"``.

    Returns:
        The resolved value (of whatever type it has), or ``""`` if any
        component along the path is missing.
    """
    if field_name in extra_info:
        return extra_info[field_name]

    value: Any = args
    # For example, args.compilation_config.mode
    for part in field_name.split("."):
        if isinstance(value, dict) and part in value:
            # Configs may end up as plain dicts rather than objects; walk them
            # by key instead of silently reporting the field as missing.
            value = value[part]
        elif hasattr(value, part):
            value = getattr(value, part)
        else:
            return ""
    return value


def use_compile(args: argparse.Namespace, extra_info: dict[str, Any]) -> bool:
    """
    Check if the benchmark is run with torch.compile.

    The run is treated as *not* compiled when any of the following hold:

    - ``compilation_config.mode`` (taken from ``extra_info`` first, then from
      ``args``) is ``"0"`` or the integer ``0``;
    - ``args.output_json`` contains the substring ``"eager"``;
    - ``args.result_filename`` contains the substring ``"eager"``.

    Args:
        args: Parsed benchmark CLI arguments.
        extra_info: Extra benchmark metadata; may override fields of ``args``.

    Returns:
        ``True`` unless one of the conditions above marks the run as eager.
    """
    mode = extract_field(args, extra_info, "compilation_config.mode")
    # The mode may be recorded as the string "0" (e.g. via extra_info) or as a
    # numeric value / IntEnum on a config object; treat both as "no compile".
    mode_disabled = mode == "0" or (
        isinstance(mode, int) and not isinstance(mode, bool) and mode == 0
    )
    # A getattr default is not enough: argparse stores unset options as None,
    # and `"eager" in None` raises TypeError. str() also covers Path values.
    output_json = str(getattr(args, "output_json", None) or "")
    result_filename = str(getattr(args, "result_filename", None) or "")
    return not (
        mode_disabled or "eager" in output_json or "eager" in result_filename
    )


def convert_to_pytorch_benchmark_format(
    args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any]
) -> list:
    """
    Convert benchmark results into the format used by the PyTorch OSS benchmark
    database, with one metric per record.
    https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database

    Records are only produced when the ``SAVE_TO_PYTORCH_BENCHMARK_FORMAT``
    environment variable is set to a non-empty value; otherwise an empty list
    is returned.

    Args:
        args: Parsed benchmark CLI arguments; a copy of ``vars(args)`` is
            embedded in every record.
        metrics: Mapping of metric name to the list of measured values.
        extra_info: Extra metadata attached to each metric; also consulted for
            ``compilation_config.mode``, ``optimization_level`` and
            ``tensor_parallel_size``.

    Returns:
        One record dict per entry in ``metrics`` (empty when disabled).

    Raises:
        TypeError: If the values for a metric are not a ``list``.
    """
    records: list = []
    if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False):
        return records

    for name, benchmark_values in metrics.items():
        if not isinstance(benchmark_values, list):
            raise TypeError(
                f"benchmark_values for metric '{name}' must be a list, "
                f"but got {type(benchmark_values).__name__}"
            )

        # vars() returns the namespace's live __dict__; copy it so that adding
        # tensor_parallel_size below never mutates the caller's `args` and
        # records do not share one mutable dict.
        record_args = dict(vars(args))

        # Save tensor_parallel_size parameter if it's part of the metadata
        if (
            not record_args.get("tensor_parallel_size")
            and "tensor_parallel_size" in extra_info
        ):
            record_args["tensor_parallel_size"] = extra_info["tensor_parallel_size"]

        record = {
            "benchmark": {
                "name": "vLLM benchmark",
                "extra_info": {
                    "args": record_args,
                    "compilation_config.mode": extract_field(
                        args, extra_info, "compilation_config.mode"
                    ),
                    "optimization_level": extract_field(
                        args, extra_info, "optimization_level"
                    ),
                    # A boolean field used by vLLM benchmark HUD dashboard
                    "use_compile": use_compile(args, extra_info),
                },
            },
            "model": {
                "name": args.model,
            },
            "metric": {
                "name": name,
                "benchmark_values": benchmark_values,
                "extra_info": extra_info,
            },
        }
        records.append(record)

    return records


class InfEncoder(json.JSONEncoder):
    """JSON encoder that replaces infinite floats with strings.

    By default ``json`` emits ``Infinity`` / ``-Infinity`` for infinite floats,
    which strict JSON parsers reject. Before encoding, this encoder walks
    dicts, lists and tuples. It replaces ``+inf`` with ``"inf"`` and ``-inf``
    with ``"-inf"``. Dict keys whose type JSON cannot use as a key are
    converted with ``str()``.
    """

    def clear_inf(self, o: Any):
        """Return ``o`` with infinities replaced, recursing into containers."""
        if isinstance(o, dict):
            return {self._clean_key(k): self.clear_inf(v) for k, v in o.items()}
        if isinstance(o, (list, tuple)):
            # json encodes tuples as arrays anyway, so returning a list is
            # output-equivalent and lets infinities inside tuples be cleaned.
            return [self.clear_inf(v) for v in o]
        if isinstance(o, float) and math.isinf(o):
            # Keep the sign: -inf must not be reported as +inf.
            return "inf" if o > 0 else "-inf"
        return o

    @staticmethod
    def _clean_key(k: Any) -> Any:
        """Stringify dict keys that json cannot encode natively."""
        if isinstance(k, (str, int, float, bool, type(None))):
            return k
        return str(k)

    def iterencode(self, o: Any, *args, **kwargs) -> Any:
        return super().iterencode(self.clear_inf(o), *args, **kwargs)


def write_to_json(filename: str, records: list) -> None:
    """Serialize ``records`` as JSON into ``filename``, overwriting it.

    Encoding goes through ``InfEncoder``. Any object that ``json`` cannot
    serialize is written as the placeholder string
    ``"<TypeName is not JSON serializable>"`` instead of raising.
    """

    def _unserializable(obj: object) -> str:
        # Placeholder for values json has no encoding for.
        return f"<{type(obj).__name__} is not JSON serializable>"

    with open(filename, "w") as out_file:
        json.dump(records, out_file, cls=InfEncoder, default=_unserializable)


@contextmanager
def default_vllm_config():
    """Install a freshly constructed ``VllmConfig`` for the duration of the block.

    Intended for code that tests CustomOps directly, or other paths that call
    ``get_current_vllm_config()`` without a full engine context.
    """
    # Imported lazily so this module does not depend on vllm.config at import time.
    from vllm.config import VllmConfig, set_current_vllm_config

    fresh_config = VllmConfig()
    with set_current_vllm_config(fresh_config):
        yield