# xpu.py
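# Attention kernels for Intel XPU devices, built on intel_extension_for_pytorch
# (IPEX): a varlen prefill attention, a paged-KV-cache writer, and a
# single-query (decode) paged attention.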
import intel_extension_for_pytorch as ipex
import torch

# Sliding-window attention is not implemented by the XPU kernels used here;
# a requested window size is rejected ahead of time at model load.
SUPPORTS_WINDOWING = False


def attention(
    q,
    k,
    v,
    out,
    cu_seqlens,
    max_s,
    softmax_scale,
    window_size_left=-1,
):
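    """Variable-length (ragged) attention over packed q/k/v, computed with the
    IPEX varlen kernel and written into ``out``.

    ``cu_seqlens`` gives the cumulative sequence-length offsets of each batch
    entry in the packed token dimension; ``max_s`` is the longest sequence
    length in the batch.
    """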
    # window_size_left is not supported and does not need to be checked here:
    # it is already validated ahead of time at model load.
    return ipex.llm.functional.varlen_attention(
        q,
        k,
        v,
        out,
        cu_seqlens,  # cumulative query sequence lengths
        cu_seqlens,  # cumulative key sequence lengths (same packing as the queries)
        max_s,  # max query sequence length
        max_s,  # max key sequence length
        0.0,  # attention dropout probability
        softmax_scale,
        False,  # zero_tensors
        True,  # causal masking
        False,  # do not return the softmax
        None,  # no explicit RNG generator
    )
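
# Illustrative usage sketch for `attention` (a comment-only example, not
# executed by this module). The sequence lengths, head count, head size, and
# the "xpu" device below are assumptions for demonstration:
#
#     seq_lens = torch.tensor([5, 3], dtype=torch.int32, device="xpu")
#     cu_seqlens = torch.nn.functional.pad(
#         torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
#     )  # -> tensor([0, 5, 8]): offsets of each sequence in the packed tokens
#     num_heads, head_size = 8, 64
#     q = torch.randn(int(seq_lens.sum()), num_heads, head_size,
#                     dtype=torch.float16, device="xpu")
#     k, v = torch.randn_like(q), torch.randn_like(q)
#     out = torch.empty_like(q)
#     attention(q, k, v, out, cu_seqlens, max_s=int(seq_lens.max()),
#               softmax_scale=head_size ** -0.5)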


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
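    """Scatter the new key/value tensors into the paged KV cache at the
    positions given by ``slots``."""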
    ipex.llm.modules.PagedAttention.reshape_and_cache(
        key, value, key_cache, value_cache, slots
    )


def paged_attention(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    input_lengths: torch.Tensor,
    max_s: int,
):
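    """Single-query (decode-time) attention against the paged KV cache.

    ``block_tables`` maps each sequence to its cache blocks, ``input_lengths``
    holds the current length of each sequence, and ``max_s`` is the maximum of
    those lengths.
    """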
    query = query.contiguous()
    block_size = value_cache.shape[3]
    return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
        out,
        query,
        key_cache,
        value_cache,
        kv_head_mapping,
        softmax_scale,
        block_tables,
        input_lengths,
        block_size,
        max_s,
        None,  # no alibi slopes
    )
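
# Illustrative usage sketch for decode with `paged_attention` (comment-only,
# not executed by this module). Cache layout, block size, shapes, and the
# "xpu" device below are assumptions for demonstration:
#
#     num_blocks, num_heads, head_size, block_size = 128, 8, 64, 16
#     key_cache = torch.zeros(num_blocks, num_heads, head_size, block_size,
#                             dtype=torch.float16, device="xpu")
#     value_cache = torch.zeros_like(key_cache)
#     query = torch.randn(2, num_heads, head_size,  # one token per sequence
#                         dtype=torch.float16, device="xpu")
#     out = torch.empty_like(query)
#     kv_head_mapping = torch.arange(num_heads, dtype=torch.int32, device="xpu")
#     block_tables = torch.tensor([[0, 1], [2, 3]], dtype=torch.int32, device="xpu")
#     input_lengths = torch.tensor([20, 17], dtype=torch.int32, device="xpu")
#     paged_attention(out, query, key_cache, value_cache, kv_head_mapping,
#                     head_size ** -0.5, block_tables, input_lengths, max_s=20)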