# "examples/community/test_onnx_controlnet.py" did not exist on "7447f75b9f8badb073636ed163417b0947c59e9f"
# bench_lightning_attention_decode.py
#
# Benchmark and correctness check for lightning attention decode: compares a
# naive PyTorch reference, the sgl_kernel implementation and a Triton kernel.

import itertools
import math

import torch
import triton
import triton.language as tl
from sgl_kernel import lightning_attention_decode


def next_power_of_2(n):
    """Return the smallest power of two that is greater than or equal to ``n``.

    The result is computed with integer bit arithmetic instead of
    ``math.log``/``math.ceil``: floating-point logarithms can round an exact
    power of two slightly upwards, which would make the result twice as large
    as it should be.

    Args:
        n: A positive number (typically a tensor dimension). Non-integer
            values are rounded up to the next integer first.

    Returns:
        An ``int`` power of two ``p`` with ``p >= n``.

    Raises:
        ValueError: If ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError("next_power_of_2 requires a positive value")
    # (m - 1).bit_length() is the exponent of the smallest power of two >= m.
    return 1 << (math.ceil(n) - 1).bit_length()


# Triton kernel performing one lightning-attention decode step.
#
# Each program instance handles one flattened (batch, head) index taken from
# program axis 0. It updates that index's KV state in place
# (KV <- exp(-s) * KV + outer(k, v)) and writes o = q @ KV into Out.
#
# Q, K, V, KV, Out and S are pointers; the offset arithmetic below indexes them
# as flat buffers, so it presumably expects contiguous row-major storage of
# shapes [b, h, n, d] / [b, h, n, e] / [b, h, d, e] -- confirm against callers.
# `d` and `e` are the (padded) extents used for indexing, while `d_original`
# and `e_original` bound the lanes that are actually loaded and stored.
@triton.jit
def _decode_kernel(
    Q,
    K,
    V,
    KV,
    Out,
    S,
    b: tl.constexpr,  # not referenced in the kernel body
    h: tl.constexpr,
    n: tl.constexpr,
    d: tl.constexpr,
    d_original: tl.constexpr,
    e: tl.constexpr,
    e_original: tl.constexpr,
):
    # Flattened (batch, head) index of this program and the head it belongs to.
    off_bh = tl.program_id(0)
    off_h = off_bh % h

    # Element offsets of this (batch, head) slab inside each flat buffer.
    qk_offset = off_bh * n * d
    v_offset = off_bh * n * e
    o_offset = off_bh * n * e
    kv_offset = off_bh * d * e

    # One scalar is read per head from S; the KV state is scaled by exp(-s).
    s = tl.load(S + off_h)
    ratio = tl.exp(-s)

    d_idx = tl.arange(0, d)
    e_idx = tl.arange(0, e)

    # Create masks for original dimensions
    d_mask = d_idx < d_original
    e_mask = e_idx < e_original

    # Load with masking; lanes beyond the original extents read as 0.0.
    # Only the first `d` / `e` elements of each slab are addressed here.
    q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
    k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
    v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)

    # Load KV with 2D masking (row stride is the padded width `e`)
    kv = tl.load(
        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
        mask=(d_mask[:, None] & e_mask[None, :]),
        other=0.0,
    )

    # Compute outer product using element-wise operations
    k_v_prod = k[:, None] * v[None, :]
    # Scale the previous state and add the new outer product.
    kv = ratio * kv + k_v_prod

    # Store KV with 2D masking; this overwrites the KV buffer in place.
    tl.store(
        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
        kv.to(KV.dtype.element_ty),
        mask=(d_mask[:, None] & e_mask[None, :]),
    )

    # Compute matrix-vector multiplication using element-wise operations and reduction
    # (sum over the d axis of q[:, None] * kv, leaving a vector of length e).
    o = tl.sum(q[:, None] * kv, axis=0)

    # Store output with masking
    tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)


def triton_lightning_attn_decode(q, k, v, kv, s):
    """Run one lightning-attention decode step through the Triton kernel.

    The inputs are copied into freshly allocated buffers whose trailing
    dimensions are rounded up to powers of two; the kernel masks out the
    padding lanes, so the padding itself is left uninitialised.

    Args:
        q, k: 4-D tensors unpacked as ``(b, h, n, d)``; ``n`` must be 1.
        v: Tensor whose last dimension ``e`` is the value width.
        kv: Previous KV state, copied into a float32 ``(b, h, d, e)`` region.
        s: Per-head decay values, passed to the kernel unchanged.

    Returns:
        ``(o, kv_out)``: the output and updated KV state, both sliced back to
        the unpadded sizes (views into the padded buffers).
    """
    batch, heads, seq, qk_dim = q.shape
    v_dim = v.shape[-1]
    assert seq == 1, "Sequence length must be 1 in decode mode"

    # Power-of-two extents used for the kernel's index ranges.
    qk_pad = next_power_of_2(qk_dim)
    v_pad = next_power_of_2(v_dim)

    def _padded_copy(src, width, padded_width):
        # Allocate a (batch, heads, seq, padded_width) buffer and fill its
        # leading `width` columns from `src`.
        buf = torch.empty(
            batch, heads, seq, padded_width, dtype=src.dtype, device=src.device
        )
        buf[..., :width] = src
        return buf

    out_buf = torch.empty(batch, heads, seq, v_pad, dtype=v.dtype, device=v.device)
    q_buf = _padded_copy(q, qk_dim, qk_pad)
    k_buf = _padded_copy(k, qk_dim, qk_pad)
    v_buf = _padded_copy(v, v_dim, v_pad)

    # The KV state is always held in float32, regardless of the input dtype.
    kv_buf = torch.empty(
        batch, heads, qk_pad, v_pad, dtype=torch.float32, device=kv.device
    )
    kv_buf[..., :qk_dim, :v_dim] = kv

    # One program per (batch, head) pair.
    launch_grid = (batch * heads, 1)
    _decode_kernel[launch_grid](
        q_buf,
        k_buf,
        v_buf,
        kv_buf,
        out_buf,
        s,
        b=batch,
        h=heads,
        n=seq,
        d=qk_pad,
        d_original=qk_dim,
        e=v_pad,
        e_original=v_dim,
    )

    # Strip the padding again before handing results back.
    return out_buf[..., :v_dim], kv_buf[..., :qk_dim, :v_dim]


def lightning_attention_decode_naive(q, k, v, past_kv, slope):
    """Naive PyTorch reference implementation of lightning attention decode.

    For every position ``i`` along the third axis of ``q`` the KV state is
    updated as ``kv = exp(-slope) * kv + k_i^T v_i`` and the output for that
    position is ``q_i @ kv``, both evaluated in float32.

    Args:
        q, k, v: 4-D tensors; ``q`` is unpacked as ``(b, h, n, d)``.
        past_kv: Previous KV state; cast to float32 before being updated.
        slope: Decay values; ``exp(-slope)`` multiplies the KV state
            (callers in this file pass shape ``[h, 1, 1]``).

    Returns:
        ``(output, kv)``: the per-position outputs concatenated along
        ``dim=-2`` and cast back to ``q``'s dtype, and the final KV state.
    """
    original_dtype = q.dtype
    ratio = torch.exp(-slope)

    kv = past_kv
    # Only the sequence length is needed; unpacking still requires q to be 4-D.
    _, _, n, _ = q.shape

    steps = []
    for i in range(n):
        # Decay the running state and add the outer product of k_i and v_i.
        kv = ratio * kv.to(torch.float32) + torch.einsum(
            "... n d, ... n e -> ... d e",
            k[:, :, i : i + 1],
            v[:, :, i : i + 1],
        )
        # Read the output for this position from the updated state.
        qkv = torch.einsum(
            "... n e, ... e d -> ... n d",
            q[:, :, i : i + 1].to(torch.float32),
            kv.to(torch.float32),
        )
        steps.append(qkv)

    output = torch.cat(steps, dim=-2)
    return output.to(original_dtype), kv


def lightning_attention_decode_kernel(q, k, v, past_kv, slope, output, new_kv):
    """Forward all arguments, in order, to ``lightning_attention_decode``.

    ``output`` and ``new_kv`` are caller-provided tensors handed through to the
    imported ``sgl_kernel`` function; whatever that function returns is
    returned unchanged.
    """
    forwarded = (q, k, v, past_kv, slope, output, new_kv)
    return lightning_attention_decode(*forwarded)


def calculate_diff(batch_size):
    """Check that the naive, sgl_kernel and Triton implementations agree.

    Random bfloat16 ``q``/``k``/``v`` and a random KV state and slope are
    created on the CUDA device, every implementation is run on its own clones
    of those inputs, and a single pass/fail line is printed. Outputs and KV
    states are compared against the naive results with ``atol=rtol=1e-2``.

    Args:
        batch_size: Leading (batch) dimension of the generated inputs.
    """
    cuda = torch.device("cuda")
    bf16 = torch.bfloat16
    heads, dim, steps = 64, 96, 1

    # q, k and v share one shape; they are drawn in that order.
    qkv_shape = (batch_size, heads, steps, dim)
    q, k, v = (torch.randn(*qkv_shape, device=cuda, dtype=bf16) for _ in range(3))
    past_kv = torch.randn(batch_size, heads, dim, dim, device=cuda)
    slope = torch.randn(heads, 1, 1, device=cuda)

    def fresh_inputs():
        # Each implementation gets its own copies of the inputs.
        return tuple(t.clone() for t in (q, k, v, past_kv, slope))

    ref_out, ref_kv = lightning_attention_decode_naive(*fresh_inputs())

    sgl_out = torch.empty_like(ref_out)
    sgl_kv = torch.empty_like(ref_kv)
    lightning_attention_decode_kernel(*fresh_inputs(), sgl_out, sgl_kv)

    tri_out, tri_kv = triton_lightning_attn_decode(*fresh_inputs())

    # (reference, candidate) pairs, checked in order until one disagrees.
    comparisons = (
        (ref_out, sgl_out),
        (ref_out, tri_out),
        (ref_kv, sgl_kv),
        (ref_kv, tri_kv),
    )
    everything_matches = all(
        torch.allclose(expected, actual, atol=1e-2, rtol=1e-2)
        for expected, actual in comparisons
    )

    if not everything_matches:
        print("❌ Implementations differ")
    else:
        print("✅ All implementations match")


# Batch sizes swept by the benchmark: 1 through 64 inclusive.
batch_size_range = list(range(1, 65))
# One single-element tuple per batch size (the benchmark has one x-axis name).
configs = [(bs,) for bs in batch_size_range]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["naive", "kernel", "triton"],
        line_names=["PyTorch Naive", "SGL Kernel", "Triton"],
        styles=[("blue", "-"), ("red", "-"), ("green", "-")],
        ylabel="us",
        plot_name="lightning-attention-decode-performance",
        args={},
    )
)
def benchmark(batch_size, provider):
    """Time one decode step of the selected implementation.

    Random bfloat16 ``q``/``k``/``v`` plus a random KV state and slope are
    created on the CUDA device; the chosen implementation is then timed with
    ``triton.testing.do_bench_cudagraph``. Input cloning happens inside the
    timed callable for every provider.

    Args:
        batch_size: Leading (batch) dimension of the generated inputs.
        provider: One of ``"naive"``, ``"kernel"`` or ``"triton"``.

    Returns:
        ``(median, q20, q80)`` timings multiplied by 1000 (plotted as "us"),
        in the ``(value, min, max)`` order that the perf-report harness
        unpacks.

    Raises:
        ValueError: If ``provider`` is not one of the known names.
    """
    dtype = torch.bfloat16
    device = torch.device("cuda")
    num_heads = 64
    head_dim = 96
    seq_len = 1

    q = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    k = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    v = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
    slope = torch.randn(num_heads, 1, 1, device=device)

    # Median first, then the 20th and 80th percentiles.
    quantiles = [0.5, 0.2, 0.8]

    if provider == "naive":
        run = lambda: lightning_attention_decode_naive(  # noqa: E731
            q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
        )
    elif provider == "kernel":
        # Output buffers are allocated once, outside the timed callable.
        output = torch.empty(
            batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
        )
        new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device)
        run = lambda: lightning_attention_decode_kernel(  # noqa: E731
            q.clone(),
            k.clone(),
            v.clone(),
            past_kv.clone(),
            slope.clone(),
            output,
            new_kv,
        )
    elif provider == "triton":
        run = lambda: triton_lightning_attn_decode(  # noqa: E731
            q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
        )
    else:
        raise ValueError(f"Unknown provider: {provider!r}")

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(run, quantiles=quantiles)

    # Report as (value, min, max); for a time metric the lower quantile is the min.
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms


if __name__ == "__main__":
    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/lightning_attention_decode_sgl/",
        help="Path to save lightning attention decode benchmark results",
    )
    args = parser.parse_args()

    # Run correctness test
    calculate_diff(batch_size=4)

    # Run performance benchmark and store the results under --save_path.
    # The directory is created up front in case the harness does not create it;
    # an empty path disables saving.
    if args.save_path:
        os.makedirs(args.save_path, exist_ok=True)
    benchmark.run(print_data=True, save_path=args.save_path)