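# IPEX (Intel Extension for PyTorch) attention backend: varlen attention for
# prefill and paged attention for decode. Windowed attention is not supported.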
import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.layers.attention.kv_cache import KVCache
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
from typing import Optional

SUPPORTS_WINDOWING = False


def attention(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
    if softcap is not None:
        raise NotImplementedError("softcap is not available in IPEX")

    out = torch.empty_like(query)

    # We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
    ipex.llm.functional.varlen_attention(
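        # On XPU the kernel requires contiguous inputs; other devices take the
        # tensors as-is.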
        query.contiguous() if query.device.type == "xpu" else query,
        key.contiguous() if key.device.type == "xpu" else key,
        value.contiguous() if value.device.type == "xpu" else value,
        out,
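        # Query and key share cumulative and max sequence lengths in this
        # varlen (prefill) call.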
        seqlen.cu_seqlen_q,
        seqlen.cu_seqlen_q,
        seqlen.max_q,
        seqlen.max_q,
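        # Trailing positional args, assumed to map to: dropout probability,
        # softmax scale, zero_tensors, is_causal, return_softmax, generator.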
        0.0,
        softmax_scale,
        False,
        causal,
        False,
        None,
    )

    return out


def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    softcap: Optional[float] = None,
):
    if softcap is not None:
        raise NotImplementedError("softcap is not available in IPEX")

    out = torch.empty_like(query)
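    # Total context length per sequence: new input tokens plus tokens already
    # present in the KV cache.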
    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
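    # Decode step: single-query attention over the paged KV cache.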
    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
        out,
        query,
        kv_cache.key,
        kv_cache.value,
        kv_head_mapping,
        softmax_scale,
        block_tables,
        input_lengths,
        BLOCK_SIZE,
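        # Max context length; the final argument is assumed to be alibi_slopes
        # (unused here).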
        max_s,
        None,
    )
    return out


__all__ = [
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
]