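"""Intel Extension for PyTorch (IPEX) attention backend for text-generation-server.

This module mirrors the interface of the other attention backends: `attention`
for variable-length prefill, `reshape_and_cache` for writing new key/value
tensors into the paged KV cache, and `paged_attention` for single-token decode.
"""
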
import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
from typing import Optional

# Sliding-window attention is not supported by this backend, and prefill does
# not read keys/values back from the KV cache.
SUPPORTS_WINDOWING = False
PREFILL_IN_KV_CACHE = False


def attention(
    q: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    softmax_scale,
    window_size_left=-1,
    causal=True,
    softcap: Optional[float] = None,
):
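    """Variable-length (prefill) attention via IPEX.

    `block_tables`, `window_size_left`, and `softcap` are accepted for
    interface parity with the other attention backends but are unused here.
    """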
    out = torch.empty_like(q)

    # window_size_left does not need to be checked here (windowing is not
    # supported): it is already validated ahead of time, at model load.
    ipex.llm.functional.varlen_attention(
        # The XPU kernels expect contiguous inputs; on CPU the copies are skipped.
        q.contiguous() if q.device.type == "xpu" else q,
        key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
        value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
        out,
        seqlen.cu_seqlen_q,  # cumulative query sequence lengths
        seqlen.cu_seqlen_q,  # key lengths equal query lengths during prefill
        seqlen.max_q,  # max query sequence length
        seqlen.max_q,  # max key sequence length
        0.0,  # pdropout
        softmax_scale,
        False,  # zero_tensors
        causal,
        False,  # return_softmax
        None,  # generator (gen_)
    )

    return out


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
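    """Write the new key/value tensors into the paged KV cache at `slots`."""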
    ipex.llm.modules.PagedAttention.reshape_and_cache(
        key, value, key_cache, value_cache, slots
    )


def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    softcap: Optional[float] = None,
):
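    """Single-query (decode) attention over the paged KV cache.

    `softcap` is accepted for interface parity with the other backends but is
    unused by this implementation.
    """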
    out = torch.empty_like(query)
    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
        out,
        query,
        key_cache,
        value_cache,
        kv_head_mapping,
        softmax_scale,
        block_tables,
        seqlen.input_lengths,
        BLOCK_SIZE,
        max_s,
        None,  # alibi_slopes
    )
    return out
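

# Typical decode-step call pattern (a sketch; tensor construction and shapes
# follow the other TGI attention backends and are assumptions, not part of
# this module):
#
#   reshape_and_cache(key, value, key_cache, value_cache, slots)
#   out = paged_attention(
#       query, key_cache, value_cache, kv_head_mapping, softmax_scale,
#       block_tables, seqlen, max_s,
#   )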