benchmark_polynorm.py 3.95 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools

import torch

from vllm import _custom_ops as vllm_ops
from vllm.triton_utils import triton


def polynorm_naive(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-6,
):
    """Reference PolyNorm written with plain PyTorch ops.

    The input is viewed as rows over its last dimension and promoted to
    float32. Each of ``x**3``, ``x**2`` and ``x`` is divided by its own
    root-mean-square along that last dimension (``eps`` is added under the
    square root). The three normalized terms are scaled by ``weight[0]``,
    ``weight[1]`` and ``weight[2]`` respectively, summed, and offset by
    ``bias``.

    Returns:
        A tensor with the shape of ``x`` and the dtype of ``weight``.
    """
    input_shape = x.shape
    rows = x.view(-1, input_shape[-1]).float()

    def rms_normalize(t: torch.Tensor) -> torch.Tensor:
        # Divide by sqrt(mean(t^2) + eps), reduced over the last dimension.
        mean_square = t.pow(2).mean(-1, keepdim=True)
        return t / torch.sqrt(mean_square + eps)

    cubic_term = weight[0] * rms_normalize(rows**3)
    quadratic_term = weight[1] * rms_normalize(rows**2)
    linear_term = weight[2] * rms_normalize(rows)

    combined = cubic_term + quadratic_term + linear_term + bias
    return combined.to(weight.dtype).view(input_shape)


def polynorm_vllm(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-6,
):
    """Run PolyNorm through the ``vllm_ops.poly_norm`` custom op.

    ``x`` is viewed as rows over its last dimension, a buffer with the same
    shape and dtype is allocated and handed to the op together with the
    inputs, and that buffer is returned viewed back to the shape of ``x``.
    """
    input_shape = x.shape
    rows = x.view(-1, input_shape[-1])

    # The op's return value is not used; the result is read from ``result``,
    # so the op is expected to write its output into that tensor.
    result = torch.empty_like(rows)
    vllm_ops.poly_norm(result, rows, weight, bias, eps)

    return result.view(input_shape)


def calculate_diff(batch_size, seq_len, hidden_dim):
    """Compare the naive and vLLM PolyNorm outputs on random bfloat16 input.

    Builds a random ``(batch_size, seq_len, hidden_dim)`` tensor on the
    ``"cuda"`` device with an all-ones weight (3 elements) and bias
    (1 element), runs both implementations, and prints whether the outputs
    agree within ``atol=1e-2, rtol=1e-2``. Nothing is returned.
    """
    tensor_opts = {"dtype": torch.bfloat16, "device": "cuda"}
    x = torch.randn(batch_size, seq_len, hidden_dim, **tensor_opts)
    weight = torch.ones(3, **tensor_opts)
    bias = torch.ones(1, **tensor_opts)

    reference = polynorm_naive(x, weight, bias)
    candidate = polynorm_vllm(x, weight, bias)

    outputs_match = torch.allclose(reference, candidate, atol=1e-2, rtol=1e-2)
    message = (
        "✅ All implementations match"
        if outputs_match
        else "❌ Implementations differ"
    )
    print(message)


# Benchmark sweep axes; ``configs`` holds every (dim, batch_size, seq_len)
# combination, with ``dim`` varying slowest and ``seq_len`` fastest.
batch_size_range = [1 << exponent for exponent in range(0, 7, 2)]  # 1, 4, 16, 64
seq_length_range = [1 << exponent for exponent in range(6, 11)]  # 64 .. 1024
dim_range = [2048, 4096]
configs = [*itertools.product(dim_range, batch_size_range, seq_length_range)]


def get_benchmark():
    """Build the Triton perf-report benchmark comparing both PolyNorm paths.

    The returned object sweeps every ``(dim, batch_size, seq_len)`` entry of
    the module-level ``configs`` list and times the ``"naive"`` and
    ``"vllm"`` providers on each. Call ``.run(...)`` on it to execute the
    sweep.
    """

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["dim", "batch_size", "seq_len"],
            x_vals=[list(config) for config in configs],
            line_arg="provider",
            line_vals=["naive", "vllm"],
            line_names=["Naive", "vLLM"],
            styles=[("blue", "-"), ("red", "-")],
            ylabel="us",
            plot_name="polynorm-perf",
            args={},
        )
    )
    def benchmark(dim, batch_size, seq_len, provider):
        """Time one provider on one configuration, in microseconds."""
        dtype = torch.bfloat16
        # The benchmarked tensor's last dimension is 4x the swept ``dim``.
        hidden_dim = dim * 4

        x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
        weight = torch.ones(3, dtype=dtype, device="cuda")
        bias = torch.ones(1, dtype=dtype, device="cuda")

        # Median first, then the 20th and 80th percentiles as the spread.
        quantiles = [0.5, 0.2, 0.8]

        # Any provider other than "naive" is timed with the vLLM op, as before.
        impl = polynorm_naive if provider == "naive" else polynorm_vllm
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: impl(x, weight, bias),
            quantiles=quantiles,
        )

        # perf_report unpacks the result as (value, min, max). The metric here
        # is time itself (ms -> us), so the lower quantile is the minimum and
        # the upper quantile is the maximum; they must not be swapped the way
        # throughput-style benchmarks do.
        return 1000 * ms, 1000 * min_ms, 1000 * max_ms

    return benchmark


if __name__ == "__main__":
    # Script entry point: run one correctness check with the CLI-provided
    # shape, then the full performance sweep defined by ``get_benchmark``.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Batch size",
    )
    parser.add_argument(
        "--seq-len",
        type=int,
        default=128,
        help="Sequence length",
    )
    parser.add_argument(
        "--hidden-dim",
        type=int,
        default=8192,
        help="Intermediate size of MLP",
    )
    # "polnorm" was a typo for "polynorm" in both the default directory and
    # the help text; the plot name written into this directory is
    # "polynorm-perf".
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/polynorm/",
        help="Path to save polynorm benchmark results",
    )

    args = parser.parse_args()

    # Run correctness test (these three arguments only affect this check; the
    # performance sweep below uses its own fixed configuration list).
    calculate_diff(
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        hidden_dim=args.hidden_dim,
    )

    benchmark = get_benchmark()
    # Run performance benchmark
    benchmark.run(print_data=True, save_path=args.save_path)