# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
import pickle
from collections import defaultdict

import matplotlib.pyplot as plt
import pandas as pd
10
import regex as re
zhuwenwen's avatar
zhuwenwen committed
11
12
13
14
15
16
17
import seaborn as sns
from torch.utils.benchmark import Measurement as TMeasurement

from vllm.utils import FlexibleArgumentParser

if __name__ == "__main__":
    # Plot median kernel latency against batch size (M), one subplot per KxN
    # shape, from a pickle whose "results" entry is a list of
    # torch.utils.benchmark Measurements.
    parser = FlexibleArgumentParser(
        description="Plot median kernel latency vs. batch size for each KxN "
        "shape in a pickled benchmark results file."
    )
    parser.add_argument("filename", type=str)
    parser.add_argument(
        "--output",
        type=str,
        default="graph_machete_bench.pdf",
        help="Path of the figure to write (format inferred from extension).",
    )
    args = parser.parse_args()

    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only pass result files you generated yourself.
    with open(args.filename, "rb") as f:
        raw_results: list[TMeasurement] = pickle.load(f)["results"]

    # Each sub_label must contain "MKN=(<M>x<K>x<N>)". Measurements are
    # grouped by the "KxN" part; M becomes the x-axis value.
    mkn_pattern = re.compile(r"MKN=\((\d+)x(\d+x\d+)\)")
    results: defaultdict[str, list[dict]] = defaultdict(list)
    for measurement in raw_results:
        sub_label = measurement.task_spec.sub_label
        match = mkn_pattern.search(sub_label)
        if match is None:
            raise ValueError(f"MKN not found in sub_label {sub_label!r}")
        m, kn = match.groups()
        results[kn].append(
            {
                "kernel": measurement.task_spec.description,
                # Numeric (not string) so the x-axis is ordered by value
                # instead of being treated as unordered categories.
                "batch_size": int(m),
                "median": measurement.median,
            }
        )

    if not results:
        # plt.subplots(0, ...) would otherwise fail with an opaque error.
        raise SystemExit(f"No measurements found in {args.filename}")

    ncols = 2
    nrows = math.ceil(len(results) / ncols)
    fig, axs = plt.subplots(
        nrows, ncols, figsize=(6 * ncols, 5 * nrows), squeeze=False
    )
    axs = axs.flatten()
    for ax, (shape, records) in zip(axs, results.items()):
        sns.lineplot(
            data=pd.DataFrame(records),
            x="batch_size",
            y="median",
            hue="kernel",
            style="kernel",
            markers=True,
            dashes=False,
            palette="Dark2",
            ax=ax,
        )
        # Batch sizes are now plotted at their numeric positions; a log2
        # axis keeps widely spread values legible.
        ax.set_xscale("log", base=2)
        ax.set_title(f"Shape: {shape}")
        ax.set_ylabel("time (median, s)")
    # With an odd number of shapes the last grid cell is unused; drop it.
    for ax in axs[len(results):]:
        fig.delaxes(ax)
    fig.tight_layout()
    fig.savefig(args.output)