plot_benchmarks.py 5.29 KB
Newer Older
mashun1's avatar
veros  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import json
import click
import numpy as np

import matplotlib as mpl
import seaborn as sns

mpl.use("agg")

import matplotlib.pyplot as plt  # noqa: E402

sns.set_style("ticks")

# Fixed line/label color for each benchmark component, keyed by the component
# names that appear in the benchmark JSON files. MPI variants get a shade close
# to their single-process counterpart so related backends read as a family.
COMPONENT_COLORS = {
    # NumPy backends: warm reds
    "numpy": "orangered",
    "numpy-mpi": "coral",
    # JAX backends: blues / teal
    "jax": "dodgerblue",
    "jax-mpi": "steelblue",
    "jax-gpu": "teal",
    # Fortran backends: matplotlib grayscale strings ("0" = black, "1" = white)
    "fortran": "0.4",
    "fortran-mpi": "0.2",
}


def _load_results(infiles):
    """Read and decode every benchmark result file exactly once.

    Returns the decoded JSON documents as a list, in the same order as ``infiles``.
    """
    results = []
    for infile in infiles:
        with open(infile, encoding="utf-8") as f:
            results.append(json.load(f))
    return results


def _component_colors(components):
    """Map each component name to a matplotlib color.

    Components listed in COMPONENT_COLORS keep their fixed color. Unknown ones
    fall back to the default color cycle ("C0", "C1", ...) instead of raising
    KeyError. Iterating in sorted order keeps the fallback assignment stable
    between runs.
    """
    return {comp: COMPONENT_COLORS.get(comp, f"C{i}") for i, comp in enumerate(sorted(components))}


def _collect_timings(datasets, benchmarks, components, xvals, xaxis):
    """Arrange mean per-iteration times into one array per benchmark and component.

    Returns ``{benchmark: {component: array}}``. Each array is aligned with
    ``xvals`` and holds NaN wherever no measurement was found.
    """
    timings = {
        benchmark: {comp: np.full(len(xvals), np.nan) for comp in components} for benchmark in benchmarks
    }
    # position of each process count on the x axis (only used when xaxis == "nproc")
    nproc_index = {nproc: i for i, nproc in enumerate(xvals.tolist())}

    for data in datasets:
        for benchmark, bench_res in data["benchmarks"].items():
            for res in bench_res:
                if xaxis == "size":
                    # sizes are approximate, take the closest one
                    x_idx = int(np.argmin(np.abs(xvals - res["size"])))
                else:
                    x_idx = nproc_index[data["settings"]["nproc"]]

                timings[benchmark][res["component"]][x_idx] = float(res["per_iteration"]["mean"])

    return timings


def _plot_benchmark(benchmark, timings, xvals, xaxis, norm_component, colors, size=None, nproc=None):
    """Draw one benchmark's figure and save it as a PNG in the working directory.

    ``timings`` maps component -> array aligned with ``xvals``. ``size`` is the
    single problem size shared by all inputs and is used in "nproc" mode.
    ``nproc`` is the single process count and is used in "size" mode.
    """
    fig, ax = plt.subplots(1, 1, figsize=(5.5, 4), dpi=150)
    try:
        last_coords = {}
        # sorted() keeps draw order (and thus z-order) deterministic across runs
        for component in sorted(timings):
            if norm_component:
                # compute rel. speedup
                yvals = timings[norm_component] / timings[component]
            else:
                yvals = timings[component]

            ax.plot(xvals, yvals, ".--", color=colors[component], lw=1)

            # remember the right-most finite point; the component label is placed next to it
            finite_mask = np.isfinite(yvals)
            if finite_mask.any():
                last_coords[component] = (xvals[finite_mask][-1], yvals[finite_mask][-1])
            else:
                last_coords[component] = (xvals[0], 1)

        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)

        title_kwargs = dict(fontdict=dict(weight="bold", size=11), ha="left", x=0.05, y=1.05)
        if xaxis == "nproc":
            ax.set_xlabel("Number of MPI processes")
            mantissa, exponent = f"{size:.1e}".split("e")
            # "+06" -> "6", "-05" -> "-5", "+00" -> "0" (lstrip("+0") would leave "" for the last case)
            exponent = str(int(exponent))
            ax.set_title(f"Benchmark '{benchmark}' for size {mantissa} $\\times$ 10$^{{{exponent}}}$", **title_kwargs)

        elif xaxis == "size":
            ax.set_xlabel("Problem size (# elements)")
            ax.set_title(f"Benchmark '{benchmark}' on {nproc} processes", **title_kwargs)

            if norm_component:
                ax.axhline(nproc, linestyle="dashed", alpha=0.5, lw=1, color="C0")
                ax.annotate(
                    "Perfect CPU scaling",
                    (xvals.min(), nproc),
                    xytext=(0, -2),
                    textcoords="offset points",
                    alpha=0.5,
                    color="C0",
                    va="top",
                    fontsize=8,
                )

        if norm_component:
            ax.set_ylabel("Relative speedup")
            ax.text(0.05, 1.05, "Speedup (higher is better)", transform=ax.transAxes, va="top", color="0.4")
        else:
            ax.set_ylabel("Time per iteration (s)")
            ax.text(0.05, 1.05, "Wall time (lower is better)", transform=ax.transAxes, va="top", color="0.4")

        ax.set_xscale("log")
        ax.set_yscale("log")

        # render once so the data->pixel transform reflects the log scales and final limits
        fig.canvas.draw()

        # add annotations, make sure they don't overlap: each label sits at least
        # 20 px (display coordinates) above the previous one, bottom to top
        last_text_pos = 0
        for component, (x, y) in sorted(last_coords.items(), key=lambda item: item[1][1]):
            # use the real x (not 0) so the log x-transform never sees a non-positive value
            _, text_pos = ax.transData.transform((x, y))
            text_pos = max(text_pos, last_text_pos + 20)
            _, y = ax.transData.inverted().transform((x, text_pos))

            ax.annotate(
                component,
                (x, y),
                xytext=(10, 0),
                textcoords="offset points",
                annotation_clip=False,
                color=colors[component],
                va="center",
                weight="bold",
            )

            last_text_pos = text_pos

        fig.tight_layout()

        suffix = ""
        if norm_component:
            suffix = f"-norm_{norm_component}"

        fig.savefig(f"{benchmark}-{xaxis}{suffix}.png")
    finally:
        # release the figure even if drawing or saving fails
        plt.close(fig)


@click.command()
@click.argument("INFILES", nargs=-1, type=click.Path(dir_okay=False, exists=True))
@click.option("--xaxis", type=click.Choice(["nproc", "size"]), required=True)
@click.option("--norm-component", default=None)
def plot_benchmarks(infiles, xaxis, norm_component):
    """Plot benchmark timings from one or more JSON result files.

    Writes one PNG per benchmark, named BENCHMARK-XAXIS[-norm_COMPONENT].png,
    into the current directory. With --xaxis nproc all inputs must share one
    problem size. With --xaxis size all inputs must share one process count.
    With --norm-component, the speedup relative to that component is plotted
    instead of wall time.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not infiles:
        raise click.UsageError("At least one benchmark result file is required.")

    datasets = _load_results(infiles)

    benchmarks = set()
    components = set()
    sizes = set()
    nprocs = set()
    for data in datasets:
        meta = data["settings"]
        benchmarks.update(meta["only"])
        components.update(meta["components"])
        sizes.update(meta["sizes"])
        nprocs.add(meta["nproc"])

    size = nproc = None
    if xaxis == "nproc":
        if len(sizes) != 1:
            raise click.UsageError(f"--xaxis nproc requires all inputs to use one problem size, got {sorted(sizes)}")
        (size,) = sizes
        xvals = np.array(sorted(nprocs))
    elif xaxis == "size":
        if len(nprocs) != 1:
            raise click.UsageError(f"--xaxis size requires all inputs to use one process count, got {sorted(nprocs)}")
        (nproc,) = nprocs
        xvals = np.array(sorted(sizes))
    else:
        raise click.BadParameter(f"unknown x axis {xaxis!r}", param_hint="--xaxis")

    if norm_component is not None and norm_component not in components:
        raise ValueError(f"Did not find norm component {norm_component} in data")

    component_data = _collect_timings(datasets, benchmarks, components, xvals, xaxis)
    colors = _component_colors(components)

    for benchmark in sorted(benchmarks):
        _plot_benchmark(
            benchmark,
            component_data[benchmark],
            xvals,
            xaxis,
            norm_component,
            colors,
            size=size,
            nproc=nproc,
        )


if __name__ == "__main__":
    plot_benchmarks()