import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
from typing import Optional

# The IPEX attention kernels do not support sliding-window attention;
# windowed models are rejected ahead of time at model load.
SUPPORTS_WINDOWING = False


def attention(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    kv_scales: KVScales,
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
    """Variable-length (prefill) attention via ipex.llm.functional.varlen_attention."""
    if softcap is not None:
        raise NotImplementedError("softcap is not available in IPEX")

    out = torch.empty_like(query)

    # We do not need to check window_size_left here (windowing is not supported),
    # as it is already checked ahead of time at model load.
    ipex.llm.functional.varlen_attention(
        query.contiguous() if query.device.type == "xpu" else query,
        key.contiguous() if key.device.type == "xpu" else key,
        value.contiguous() if value.device.type == "xpu" else value,
        out,
        seqlen.cu_seqlen_q,  # cumulative query sequence boundaries
        seqlen.cu_seqlen_q,  # key boundaries; equal to the query boundaries at prefill
        seqlen.max_q,  # max query length in the batch
        seqlen.max_q,  # max key length; equal to max query length at prefill
        0.0,  # pdropout: no attention dropout at inference
        softmax_scale,
        False,  # zero_tensors
        causal,
        False,  # return_softmax
        None,  # RNG generator (unused without dropout)
    )

    return out
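

# Illustrative helper (not part of the original module): a minimal sketch of how
# the `cu_seqlen_q` boundaries passed to varlen_attention above are typically
# derived from per-sequence lengths. A batch with lengths [3, 5] is packed into
# one (8, num_heads, head_dim) tensor with boundaries [0, 3, 8] and max length 5.
def _cu_seqlens_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
    # Prefix sum with a leading zero: lengths [3, 5] -> [0, 3, 8].
    return torch.nn.functional.pad(
        torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0)
    )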


def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    *,
    kv_scales: KVScales,
    softcap: Optional[float] = None,
):
    """Single-query (decode) attention over the paged KV cache."""
    if softcap is not None:
        raise NotImplementedError("softcap is not available in IPEX")

    out = torch.empty_like(query)
    # Total context per sequence: new input tokens plus tokens already in the cache.
    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
        out,
        query,
        kv_cache.key,
        kv_cache.value,
        kv_head_mapping,
        softmax_scale,
        block_tables,
        input_lengths,
        BLOCK_SIZE,  # tokens per KV-cache block
        max_s,  # max context length in the batch
        None,  # alibi_slopes (unused)
    )
    return out
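

# Illustrative helper (not part of the original module): a sketch of the paged-KV
# indexing that the cached-KV kernel above relies on. With BLOCK_SIZE tokens per
# cache block, token `pos` of sequence `seq` lives in physical block
# `block_tables[seq, pos // BLOCK_SIZE]` at offset `pos % BLOCK_SIZE`.
def _physical_slot(
    block_tables: torch.Tensor, seq: int, pos: int, block_size: int = BLOCK_SIZE
) -> tuple[int, int]:
    return (int(block_tables[seq, pos // block_size]), pos % block_size)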


__all__ = [
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
]