from typing import Optional

import intel_extension_for_pytorch as ipex
import torch

from text_generation_server.layers.attention.kv_cache import KVCache
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen

# Sliding-window attention is not supported by this IPEX backend;
# `attention` accepts `window_size_left` but never reads it.
SUPPORTS_WINDOWING: bool = False


def attention(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
    """Compute attention over ``query``/``key``/``value`` with IPEX's varlen kernel.

    Delegates to ``ipex.llm.functional.varlen_attention``. Sequence
    boundaries come from ``seqlen.cu_seqlen_q`` and ``seqlen.max_q``, which
    are used for both the query side and the key side.

    Args:
        query: Query tensor. Made contiguous first when on an XPU device,
            otherwise passed through unchanged.
        key: Key tensor; same XPU contiguity handling as ``query``.
        value: Value tensor; same XPU contiguity handling as ``query``.
        kv_cache: Accepted for interface compatibility; not read here.
        seqlen: Supplies ``cu_seqlen_q`` and ``max_q``.
        block_tables: Accepted for interface compatibility; not read here.
        softmax_scale: Softmax scale forwarded to the kernel.
        window_size_left: Ignored. Windowing is unsupported by this backend
            and is expected to be rejected earlier, at model load.
        causal: Forwarded to the kernel as its causal flag.
        softcap: Must be ``None``; softcapping is not supported.

    Returns:
        ``out``, allocated with ``torch.empty_like(query)`` and handed to the
        kernel, which presumably fills it in place (confirm with IPEX docs).

    Raises:
        NotImplementedError: If ``softcap`` is not ``None``.
    """
    if softcap is not None:
        raise NotImplementedError("softcap is not available in IPEX")

    out = torch.empty_like(query)

    # window_size_left is intentionally not checked here: windowing is not
    # supported by this backend, and that is validated ahead of time at
    # model load.
    ipex.llm.functional.varlen_attention(
        query.contiguous() if query.device.type == "xpu" else query,
        key.contiguous() if key.device.type == "xpu" else key,
        value.contiguous() if value.device.type == "xpu" else value,
        out,
        seqlen.cu_seqlen_q,  # query-side cumulative sequence lengths
        seqlen.cu_seqlen_q,  # key side reuses the query boundaries (kv_cache is not read)
        seqlen.max_q,
        seqlen.max_q,
        0.0,  # presumably the dropout probability -- confirm against IPEX signature
        softmax_scale,
        False,  # presumably `zero_tensors` -- confirm
        causal,
        False,  # presumably `return_softmax` -- confirm
        None,  # presumably the RNG generator (no dropout) -- confirm
    )

    return out

def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
    """Store ``key``/``value`` into ``key_cache``/``value_cache`` via IPEX.

    Thin wrapper over ``ipex.llm.modules.PagedAttention.reshape_and_cache``.
    All five tensors are forwarded positionally, in the same order. Nothing
    is returned, so any cache update presumably happens in place inside the
    IPEX call. ``slots`` presumably gives the destination cache slot for each
    token; confirm against the caller.
    """
    write_to_cache = ipex.llm.modules.PagedAttention.reshape_and_cache
    write_to_cache(
        key,
        value,
        key_cache,
        value_cache,
        slots,
    )


def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    softcap: Optional[float] = None,
):
    """Attend ``query`` against the paged KV cache with IPEX.

    Delegates to
    ``ipex.llm.modules.PagedAttention.single_query_cached_kv_attention``,
    reading keys and values from ``kv_cache.key`` / ``kv_cache.value``.

    Args:
        query: Query tensor; the output is allocated with the same
            shape/dtype/device via ``torch.empty_like``.
        kv_cache: Supplies ``key`` and ``value`` cache tensors.
        kv_head_mapping: Forwarded to the kernel (presumably maps each
            query head to its KV head; confirm).
        softmax_scale: Softmax scale forwarded to the kernel.
        block_tables: Forwarded to the kernel.
        seqlen: Per-sequence context lengths are computed as
            ``seqlen.input_lengths + seqlen.cache_lengths``.
        max_s: Forwarded to the kernel (presumably the maximum context
            length in the batch; confirm).
        softcap: Must be ``None``; softcapping is not supported.

    Returns:
        ``out``, which is passed to the kernel and presumably filled in place.

    Raises:
        NotImplementedError: If ``softcap`` is not ``None``.
    """
    if softcap is not None:
        raise NotImplementedError("softcap is not available in IPEX")

    out = torch.empty_like(query)
    # Per-sequence context length handed to the kernel: presumably the new
    # tokens of this step plus the tokens already held in the cache.
    context_lengths = seqlen.input_lengths + seqlen.cache_lengths
    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
        out,
        query,
        kv_cache.key,
        kv_cache.value,
        kv_head_mapping,
        softmax_scale,
        block_tables,
        context_lengths,
        BLOCK_SIZE,
        max_s,
        None,  # presumably `alibi_slopes` (no ALiBi) -- confirm against IPEX signature
    )
    return out


# Names exported by `from ... import *`: the public interface of this backend.
__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "reshape_and_cache"]