graph_machete_bench.py 1.96 KB
Newer Older
raojy's avatar
raojy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
import pickle
from collections import defaultdict

import matplotlib.pyplot as plt
import pandas as pd
import regex as re
import seaborn as sns
from torch.utils.benchmark import Measurement as TMeasurement

from vllm.utils.argparse_utils import FlexibleArgumentParser

def parse_mkn(sub_label: str) -> tuple[str, str]:
    """Split a benchmark sub-label into its batch size and ``KxN`` shape.

    Args:
        sub_label: Text containing a ``MKN=(<M>x<K>x<N>)`` marker, for
            example ``"MKN=(16x4096x8192)"``.

    Returns:
        A ``(M, "KxN")`` pair of strings, exactly as written in the label.

    Raises:
        ValueError: If the label contains no ``MKN=(...)`` marker.
    """
    # One search with two groups replaces two near-identical searches over
    # the same label.
    match = re.search(r"MKN=\((\d+)x(\d+x\d+)\)", sub_label)
    if match is None:
        raise ValueError(f"MKN not found in sub_label: {sub_label!r}")
    return match.group(1), match.group(2)


def group_by_shape(measurements) -> dict[str, list[dict]]:
    """Bucket benchmark measurements by their ``KxN`` shape.

    Args:
        measurements: Iterable of objects exposing ``task_spec.sub_label``,
            ``task_spec.description`` and ``median``.

    Returns:
        Mapping from ``"KxN"`` to a list of row dicts with the keys
        ``"kernel"``, ``"batch_size"`` and ``"median"``, in input order.

    Raises:
        ValueError: If any measurement's sub-label lacks an ``MKN=(...)``
            marker.
    """
    grouped = defaultdict(list)
    for measurement in measurements:
        spec = measurement.task_spec
        batch_size, shape = parse_mkn(spec.sub_label)
        grouped[shape].append(
            {
                "kernel": spec.description,
                # Kept as the matched string (not converted to int) so the
                # plotted x-axis is unchanged.
                # NOTE(review): confirm a non-numeric x-axis is intended.
                "batch_size": batch_size,
                "median": measurement.median,
            }
        )
    return grouped


def main() -> None:
    """Plot pickled benchmark results to ``graph_machete_bench.pdf``.

    Reads the pickle file named on the command line, groups its
    ``"results"`` entries by ``KxN`` shape and draws one line plot per shape
    (median time against batch size, one line per kernel) on a two-column
    grid of subplots.
    """
    parser = FlexibleArgumentParser(
        description="Plot machete benchmark results from a pickled results "
        "file, one subplot per KxN problem shape."
    )
    parser.add_argument("filename", type=str)

    args = parser.parse_args()

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only pass result files that come from a trusted benchmark run.
    with open(args.filename, "rb") as f:
        payload = pickle.load(f)
    raw_results: list[TMeasurement] = payload["results"]

    results = group_by_shape(raw_results)
    if not results:
        # plt.subplots(0, 2) would otherwise fail with an unrelated-looking
        # matplotlib error.
        raise SystemExit(f"No benchmark results found in {args.filename}")

    rows = math.ceil(len(results) / 2)
    fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
    axs = axs.flatten()
    for ax, (shape, records) in zip(axs, results.items()):
        sns.lineplot(
            data=pd.DataFrame(records),
            x="batch_size",
            y="median",
            hue="kernel",
            style="kernel",
            markers=True,
            dashes=False,
            palette="Dark2",
            ax=ax,
        )
        ax.set_title(f"Shape: {shape}")
        ax.set_ylabel("time (median, s)")

    # With an odd number of shapes the last grid cell is unused; hide it
    # instead of leaving an empty frame in the output.
    for unused_ax in axs[len(results):]:
        unused_ax.set_visible(False)

    fig.tight_layout()
    fig.savefig("graph_machete_bench.pdf")


if __name__ == "__main__":
    main()