make_plot_with_jsonl.py 5.4 KB
Newer Older
Mitchell Wortsman's avatar
Mitchell Wortsman committed
1
import matplotlib.gridspec as gridspec
Aarni Koskela's avatar
Aarni Koskela committed
2
3
import matplotlib.pyplot as plt
import pandas as pd
Mitchell Wortsman's avatar
Mitchell Wortsman committed
4

Ruff's avatar
Ruff committed
5
# Colormap that gives each batch-size curve of the speed-up subplot its own colour.
cmap = plt.get_cmap(name="cool")
Mitchell Wortsman's avatar
Mitchell Wortsman committed
6

Ruff's avatar
Ruff committed
7
8
def _shape_rows(df, dim_in, dim_out):
    """Return the rows of ``df`` whose ``dim_in``/``dim_out`` columns match the given dims.

    Raises:
        ValueError: if no row matches. Indexing ``.values[0]`` on an empty
            selection would otherwise fail with an uninformative IndexError.
    """
    rows = df[(df.dim_in == dim_in) & (df.dim_out == dim_out)]
    if rows.empty:
        raise ValueError(f"no benchmark row with dim_in={dim_in}, dim_out={dim_out}")
    return rows


def _mean_layer_time(df, key, embed_dim):
    """Return the summed time of the columns in ``key``, averaged over two layer shapes.

    ``key`` is a ``"+"``-separated list of timing column names. The columns are
    summed for the ``embed_dim -> 4 * embed_dim`` shape and for the
    ``4 * embed_dim -> embed_dim`` shape, and half of the grand total (the mean
    of the two shapes) is returned. Only the first matching row of each shape
    is read.

    Raises:
        ValueError: if either shape is missing from ``df``.
    """
    ops = key.split("+")
    total = 0
    for dim_in, dim_out in ((embed_dim, 4 * embed_dim), (4 * embed_dim, embed_dim)):
        rows = _shape_rows(df, dim_in, dim_out)
        for op in ops:
            total += rows[op].values[0]
    return total * 0.5


def _percent_speedup(baseline, candidate):
    """Return, element-wise, how much less time ``candidate`` takes than ``baseline``, in percent.

    Positive values mean ``candidate`` is faster.
    """
    return [-((c - b) / b) * 100 for b, c in zip(baseline, candidate)]


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Plot linear-layer timings from a speed-benchmark JSONL file.")
    parser.add_argument(
        "--input",
        default="speed_benchmark/info_a100_py2.jsonl",
        help="JSONL file with one benchmark record per line.",
    )
    parser.add_argument(
        "--output",
        default="speed_benchmark/plot_with_info.pdf",
        help="Path the figure is saved to.",
    )
    args = parser.parse_args()

    fig = plt.figure(tight_layout=True, figsize=(12, 3.5))
    gs = gridspec.GridSpec(1, 2)

    dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096]
    batch_size_for_plot1 = 32768
    batch_sizes_for_plot2 = [2**14, 2**15, 2**16, 2**17]
    dims_to_xtick = [1024, 2048, 4096]
    logscale_plot1 = True

    # Timing columns summed to get the end-to-end time of each variant; shared by both subplots.
    standard_total = "+".join(["standard_gx", "standard_gw", "standard_fwd"])
    switchback_total = "+".join(
        [
            "x_quantize_rowwise",
            "g_quantize_rowwise",
            "w_quantize_global",
            "w_quantize_global_transpose",
            "standard_gw",
            "global_fwd",
            "global_bwd",
        ]
    )

    rdf = pd.read_json(args.input, lines=True)

    # ---- Left subplot: time occupied by the different operations at one batch size ----
    ax = fig.add_subplot(gs[0, 0])
    df = rdf[rdf.batch_size == batch_size_for_plot1]
    series = [
        (standard_total, "s", "-", "C2", "Standard fp16 (sum of parts)"),
        (switchback_total, "o", "-", "C4", "SwitchBack int8 (sum of parts)"),
        ("standard_fwd", "^", "--", "C2", "Matmul XW (standard)"),
        ("standard_gw", "^", "-.", "C2", "Matmul GW (standard)"),
        ("standard_gx", "^", ":", "gray", "Matmul GX (both)"),
        ("global_fwd", "^", "--", "C4", "Int8 Matmul XW (switchback)"),
        ("global_bwd", "^", "-.", "C4", "Int8 Matmul GW (switchback)"),
        ("x_quantize_rowwise", "P", "--", "C4", "Quantize rowwise X (switchback)"),
        ("g_quantize_rowwise", "P", "-.", "C4", "Quantize rowwise G (switchback)"),
        ("w_quantize_global", ".", "--", "C4", "Quantize global W (switchback)"),
        ("w_quantize_global_transpose", ".", "-.", "C4", "Quantize global and\ntranspose W (switchback)"),
    ]
    for key, marker, ls, color, name in series:
        ax.plot(
            dims_to_consider,
            [_mean_layer_time(df, key, embed_dim) for embed_dim in dims_to_consider],
            color=color,
            label=name,
            marker=marker,
            markersize=5,
            linestyle=ls,
            # Summed totals (multi-column keys) are drawn thicker than single operations.
            linewidth=2 if "+" in key else 1.0,
        )

    ax.set_xlabel("dim", fontsize=13)
    ax.set_ylabel("time (ms)", fontsize=13)
    ax.grid()
    ax.set_xscale("log")
    if logscale_plot1:
        ax.set_yscale("log")
    ax.tick_params(axis="x", labelsize=11)
    ax.tick_params(axis="y", labelsize=11)
    ax.set_xticks(dims_to_xtick)
    ax.set_xticklabels(dims_to_xtick)
    ax.set_xticks([], minor=True)

    # The legend sits to the left of the axes; its first two entries are the summed totals.
    leg = ax.legend(loc="upper center", bbox_to_anchor=(-0.64, 1.0), ncol=1, fontsize=10)
    leg.get_texts()[0].set_fontweight("bold")
    leg.get_texts()[1].set_fontweight("bold")

    # Derive the title from the configured batch size instead of hard-coding it.
    if batch_size_for_plot1 % 1024 == 0:
        tokens_label = f"{batch_size_for_plot1 // 1024}k"
    else:
        tokens_label = str(batch_size_for_plot1)
    ax.set_title(
        f"  Linear layer, batch * sequence length = {tokens_label}", fontsize=10, loc="left", y=1.05, pad=-20
    )

    # ---- Right subplot: % speed-up of SwitchBack over standard, for several batch sizes ----
    ax = fig.add_subplot(gs[0, 1])
    speedup_markers = ["^", "v", "P", "o"]
    for j, batch_size in enumerate(batch_sizes_for_plot2):
        df = rdf[rdf.batch_size == batch_size]
        standard = [_mean_layer_time(df, standard_total, embed_dim) for embed_dim in dims_to_consider]
        switchback = [_mean_layer_time(df, switchback_total, embed_dim) for embed_dim in dims_to_consider]
        ax.plot(
            dims_to_consider,
            _percent_speedup(standard, switchback),
            color=cmap(j * 0.25),
            label=f"batch * sequence length = {batch_size}",
            # Cycle markers so extra batch sizes do not raise IndexError.
            marker=speedup_markers[j % len(speedup_markers)],
            markersize=5,
        )

    ax.legend()
    ax.set_xlabel("dim", fontsize=13)
    ax.set_xscale("log")
    ax.grid()
    ax.set_ylabel(r"% speedup", fontsize=13)
    ax.tick_params(axis="x", labelsize=11)
    ax.tick_params(axis="y", labelsize=11)
    ax.set_xticks(dims_to_xtick)
    ax.set_xticklabels(dims_to_xtick)
    ax.set_xticks([], minor=True)
    ax.set_title("  Linear layer summary, varying dimensions", fontsize=10, loc="left", y=1.05, pad=-20)

    plt.savefig(args.output, bbox_inches="tight")