# attention.py
import torch


def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
    """Run one lightning-attention decode step via the sgl_kernel custom op.

    The op writes its results in place into ``output`` and ``new_kv``.
    """
    torch.ops.sgl_kernel.lightning_attention_decode.default(
        q, k, v, past_kv, slope, output, new_kv
    )
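

# Usage sketch (illustrative only): the tensor shapes below are assumptions
# about a single-token decode-step layout, not taken from the kernel's docs,
# and running it requires the compiled sgl_kernel extension plus a CUDA device.
#
#     import sgl_kernel  # registers torch.ops.sgl_kernel.*
#
#     batch, heads, head_dim = 2, 8, 64
#     q = torch.randn(batch, heads, 1, head_dim, device="cuda")
#     k = torch.randn(batch, heads, 1, head_dim, device="cuda")
#     v = torch.randn(batch, heads, 1, head_dim, device="cuda")
#     past_kv = torch.zeros(batch, heads, head_dim, head_dim, device="cuda")
#     slope = torch.rand(heads, 1, 1, device="cuda")
#     output = torch.empty_like(q)       # filled by the kernel
#     new_kv = torch.empty_like(past_kv)  # filled by the kernel
#     lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)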