utils.py 3.78 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10

import argparse
import json
import math
import os
from typing import Any


11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def extract_field(
    args: argparse.Namespace, extra_info: dict[str, Any], field_name: str
) -> str:
    """
    Look up a benchmark field, preferring extra_info over the parsed args.

    Args:
        args: Parsed command-line arguments; may hold nested objects.
        extra_info: Additional metadata. A key equal to ``field_name``
            takes precedence over anything found on ``args``.
        field_name: Field to look up. Dots denote nested attribute access
            on ``args``, for example ``compilation_config.mode``.

    Returns:
        The value stored in ``extra_info`` if present, otherwise the value
        reached by walking the dotted attribute path on ``args``, or an
        empty string when any attribute along the path does not exist.
    """
    if field_name in extra_info:
        return extra_info[field_name]

    # Unique marker so that attributes legitimately set to None (or any
    # other falsy value) are not mistaken for missing attributes.
    missing = object()
    current = args
    for attr_name in field_name.split("."):
        current = getattr(current, attr_name, missing)
        if current is missing:
            return ""
    return current


def use_compile(args: argparse.Namespace, extra_info: dict[str, Any]) -> bool:
    """
    Check if the benchmark is run with torch.compile

    The run is treated as eager (returns False) when either the
    ``output_json`` or ``result_filename`` argument contains the substring
    "eager", or when the ``compilation_config.mode`` field resolves to the
    string "0". In every other case this returns True.

    Args:
        args: Parsed command-line arguments.
        extra_info: Additional metadata consulted for
            ``compilation_config.mode`` before falling back to ``args``.
    """
    for attr_name in ("output_json", "result_filename"):
        # The attribute may exist but be None (the usual argparse default
        # for an option that was not passed); coerce it to "" so the
        # substring test cannot raise TypeError.
        if "eager" in (getattr(args, attr_name, "") or ""):
            return False

    return extract_field(args, extra_info, "compilation_config.mode") != "0"


37
38
39
def convert_to_pytorch_benchmark_format(
    args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any]
) -> list:
    """
    Save the benchmark results in the format used by PyTorch OSS benchmark with
    one metric per record
    https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database

    Records are only produced when the SAVE_TO_PYTORCH_BENCHMARK_FORMAT
    environment variable is set to a non-empty value; otherwise an empty
    list is returned.

    Args:
        args: Parsed command-line arguments. Must expose ``model``; a copy of
            all of its attributes is stored in every record.
        metrics: Mapping from metric name to its list of benchmark values.
        extra_info: Additional metadata attached to every record. It may
            also supply ``tensor_parallel_size`` when ``args`` has no truthy
            value for it.

    Returns:
        A list with one record (dict) per entry in ``metrics``.

    Raises:
        TypeError: If any value in ``metrics`` is not a list.
    """
    records = []
    if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False):
        return records

    for name, benchmark_values in metrics.items():
        if not isinstance(benchmark_values, list):
            raise TypeError(
                f"benchmark_values for metric '{name}' must be a list, "
                f"but got {type(benchmark_values).__name__}"
            )

        # vars(args) is the namespace's live __dict__. Take a shallow copy so
        # the tensor_parallel_size backfill below does not write back into
        # the caller's args, and so records do not share one mutable dict.
        args_snapshot = dict(vars(args))

        # Save tensor_parallel_size parameter if it's part of the metadata
        if (
            not args_snapshot.get("tensor_parallel_size")
            and "tensor_parallel_size" in extra_info
        ):
            args_snapshot["tensor_parallel_size"] = extra_info["tensor_parallel_size"]

        record = {
            "benchmark": {
                "name": "vLLM benchmark",
                "extra_info": {
                    "args": args_snapshot,
                    "compilation_config.mode": extract_field(
                        args, extra_info, "compilation_config.mode"
                    ),
                    "optimization_level": extract_field(
                        args, extra_info, "optimization_level"
                    ),
                    # A boolean field used by vLLM benchmark HUD dashboard
                    "use_compile": use_compile(args, extra_info),
                },
            },
            "model": {
                "name": args.model,
            },
            "metric": {
                "name": name,
                "benchmark_values": benchmark_values,
                "extra_info": extra_info,
            },
        }

        records.append(record)

    return records


class InfEncoder(json.JSONEncoder):
    """
    JSON encoder that replaces infinite floats with the strings "inf" and
    "-inf" before encoding, and stringifies dictionary keys that are not of
    a type the json module accepts as a key.
    """

    def clear_inf(self, o: Any):
        """
        Return a copy of ``o`` that is safe to hand to the base encoder.

        Dicts, lists and tuples are traversed recursively (tuples come back
        as lists, which is how JSON represents them anyway). Any other value
        is returned unchanged.
        """
        if isinstance(o, dict):
            return {
                # Keys that are not str/int/float/bool/None are converted
                # with str(); the others are passed through untouched.
                (
                    k
                    if isinstance(k, (str, int, float, bool, type(None)))
                    else str(k)
                ): self.clear_inf(v)
                for k, v in o.items()
            }
        if isinstance(o, (list, tuple)):
            return [self.clear_inf(v) for v in o]
        if isinstance(o, float) and math.isinf(o):
            # Keep the sign so negative infinity is distinguishable.
            return "inf" if o > 0 else "-inf"
        return o

    def iterencode(self, o: Any, *args, **kwargs) -> Any:
        """Sanitize ``o`` with clear_inf, then defer to the base encoder."""
        return super().iterencode(self.clear_inf(o), *args, **kwargs)


def write_to_json(filename: str, records: list) -> None:
    """
    Serialize ``records`` as JSON into ``filename``, overwriting the file.

    Encoding is done with InfEncoder. Objects the encoder cannot serialize
    are replaced with a placeholder string naming their type, so they do
    not raise.
    """

    def describe_unserializable(obj: Any) -> str:
        return f"<{type(obj).__name__} is not JSON serializable>"

    with open(filename, "w") as out_file:
        json.dump(records, out_file, cls=InfEncoder, default=describe_unserializable)