"router/src/infer/chat_template.rs" did not exist on "709d8936f68002c2244e245607c6b88d658ebe6f"
ipex.py 2.18 KB
Newer Older
1
2
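"""IPEX attention backend for text-generation-inference.

Prefill attention runs through ``ipex.llm.functional.varlen_attention`` and
decode through ``ipex.llm.modules.PagedAttention``. Sliding-window attention
and softcapping are not supported by this backend.
"""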
import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
from typing import Optional

SUPPORTS_WINDOWING = False


def attention(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    kv_scales: KVScales,
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
    if softcap is not None:
        raise NotImplementedError("softcap is not available in IPEX")

    out = torch.empty_like(query)

    # No need to check window_size_left here (windowing is not supported);
    # it is already validated ahead of time at model load.
    ipex.llm.functional.varlen_attention(
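        # XPU kernels expect contiguous inputs; skip the extra copy on CPU.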
        query.contiguous() if query.device.type == "xpu" else query,
        key.contiguous() if key.device.type == "xpu" else key,
        value.contiguous() if value.device.type == "xpu" else value,
        out,
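        # Prefill: query and key share the same cumulative and max sequence
        # lengths, so the same seqlen tensors are passed for both.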
        seqlen.cu_seqlen_q,
        seqlen.cu_seqlen_q,
        seqlen.max_q,
        seqlen.max_q,
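        # Remaining positional arguments, following the flash-attn style
        # varlen signature: dropout p, softmax scale, zero_tensors, is_causal,
        # return_softmax, generator.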
        0.0,
        softmax_scale,
        False,
        causal,
        False,
        None,
    )

    return out


def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    *,
    kv_scales: KVScales,
    softcap: Optional[float] = None,
):
    if softcap is not None:
        raise NotImplementedError("softcap is not available in IPEX")

    out = torch.empty_like(query)
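    # Effective sequence length per request: freshly added tokens plus the
    # tokens already present in the KV cache.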
    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
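    # Decode: each request attends from a single new query token to its
    # paged KV cache.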
    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
        out,
        query,
        kv_cache.key,
        kv_cache.value,
        kv_head_mapping,
        softmax_scale,
        block_tables,
        input_lengths,
        BLOCK_SIZE,
        max_s,
        None,  # alibi slopes (not used)
    )
    return out


__all__ = [
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
]