ipex.py 2.2 KB
Newer Older
1
2
import intel_extension_for_pytorch as ipex
import torch
Wang, Yi's avatar
Wang, Yi committed
3
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
4
from text_generation_server.layers.attention import Seqlen
drbh's avatar
drbh committed
5
from typing import Optional
6
7

# Sliding-window (local) attention is not implemented by this IPEX backend;
# attention() does not act on a non-default window_size_left.
SUPPORTS_WINDOWING: bool = False
# NOTE(review): presumably tells callers that prefill attention() is handed the
# current step's key/value tensors rather than the paged KV cache — confirm
# against the caller.
PREFILL_IN_KV_CACHE: bool = False
9
10
11


def attention(
    q: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
    """Run variable-length (prefill) attention with IPEX ``varlen_attention``.

    Args:
        q: Query tensor. The output is allocated with ``torch.empty_like(q)``.
        key_cache: Key tensor handed directly to ``varlen_attention``.
            NOTE(review): because ``cu_seqlen_q``/``max_q`` are also used for
            the key side, this presumably holds the current step's keys laid
            out like ``q`` rather than the paged cache — confirm with callers.
        value_cache: Value tensor; same caveat as ``key_cache``.
        seqlen: Supplies ``cu_seqlen_q`` and ``max_q``, used for both the
            query and the key side.
        block_tables: Unused by this backend.
        softmax_scale: Scale forwarded to ``varlen_attention``.
        window_size_left: Ignored; sliding-window attention is not supported
            and is checked ahead of time at model load.
        causal: Whether ``varlen_attention`` applies a causal mask.
        softcap: Logit soft-capping is not supported by IPEX; must be ``None``.

    Returns:
        A new tensor like ``q`` holding the attention output.

    Raises:
        NotImplementedError: If ``softcap`` is not ``None``.
    """
    # Fail loudly instead of silently computing attention without
    # soft-capping, which would yield wrong outputs for models that need it.
    if softcap is not None:
        raise NotImplementedError("softcap is not available in IPEX")

    out = torch.empty_like(q)

    # We do not need to check window_size_left (not supported) here, so it is
    # already checked ahead of time at model load.
    # NOTE(review): inputs are made contiguous only on XPU, presumably because
    # the XPU kernel requires contiguous memory — confirm.
    ipex.llm.functional.varlen_attention(
        q.contiguous() if q.device.type == "xpu" else q,
        key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
        value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
        out,  # filled in place; returned below
        seqlen.cu_seqlen_q,  # seqlen_q (cumulative offsets)
        seqlen.cu_seqlen_q,  # seqlen_k: keys share the query layout
        seqlen.max_q,  # max_seqlen_q
        seqlen.max_q,  # max_seqlen_k
        0.0,  # pdropout
        softmax_scale,
        False,  # zero_tensors
        causal,  # is_causal
        False,  # return_softmax
        None,  # gen_
    )

    return out

44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
    """Store ``key``/``value`` into the paged KV cache through IPEX.

    Delegates to ``ipex.llm.modules.PagedAttention.reshape_and_cache`` and
    returns nothing.

    NOTE(review): the caches are presumably updated in place at the positions
    named by ``slots`` — confirm against the IPEX documentation.
    """
    paged_ops = ipex.llm.modules.PagedAttention
    paged_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)


def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    softcap: Optional[float] = None,
):
    """Run single-query (decode) attention over the paged KV cache on IPEX.

    Args:
        query: Query tensor. The output is allocated with
            ``torch.empty_like(query)``.
        key_cache: Paged key cache, forwarded unchanged to IPEX.
        value_cache: Paged value cache, forwarded unchanged to IPEX.
        kv_head_mapping: Forwarded as IPEX's head mapping (presumably query
            head -> KV head, for grouped/multi-query attention — confirm).
        softmax_scale: Scale forwarded to IPEX.
        block_tables: Block tables of the paged cache, forwarded to IPEX.
        seqlen: Supplies ``input_lengths`` and ``cache_lengths``; their sum is
            passed to IPEX as the per-sequence context lengths.
        max_s: Maximum context length, forwarded to IPEX.
        softcap: Logit soft-capping is not supported by IPEX; must be ``None``.

    Returns:
        A new tensor like ``query`` holding the attention output.

    Raises:
        NotImplementedError: If ``softcap`` is not ``None``.
    """
    # Fail loudly instead of silently computing attention without
    # soft-capping, which would yield wrong outputs for models that need it.
    if softcap is not None:
        raise NotImplementedError("softcap is not available in IPEX")

    out = torch.empty_like(query)
    # Tokens attended to = new tokens of this step + tokens already cached.
    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
        out,  # filled in place; returned below
        query,
        key_cache,
        value_cache,
        kv_head_mapping,
        softmax_scale,
        block_tables,
        input_lengths,  # context_lens
        BLOCK_SIZE,  # page size shared with flash_causal_lm
        max_s,  # max_context_len
        None,  # alibi_slopes
    )
    return out
84
85
86
87
88
89
90
91
92


# Public surface of this attention backend: capability flags first, then the
# kernel wrappers.
__all__ = [
    # capability flags
    "PREFILL_IN_KV_CACHE",
    "SUPPORTS_WINDOWING",
    # kernel wrappers
    "attention",
    "paged_attention",
    "reshape_and_cache",
]