import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE

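# Sliding-window (local) attention is not supported by this IPEX backend.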
SUPPORTS_WINDOWING = False


def attention(
    q,
    k,
    v,
    out,
    cu_seqlens,
    max_s,
    softmax_scale,
    window_size_left=-1,
):
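    """Prefill attention over packed variable-length sequences via
    ipex.llm.functional.varlen_attention. q, k and v contain the tokens of the
    whole batch concatenated together; cu_seqlens holds the cumulative
    per-sequence offsets and max_s the longest sequence length in the batch.
    The result is written into `out`."""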
    # We do not need to check window_size_left here (windowing is not supported):
    # it is already validated ahead of time, at model load.
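    # cu_seqlens and max_s are each passed twice because query and key share the
    # same packed sequence boundaries during prefill; the trailing arguments set
    # the dropout probability to 0.0, pass the softmax scale, and enable causal
    # masking (the True flag), with no attention-probability output or RNG
    # generator requested.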
    return ipex.llm.functional.varlen_attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        cu_seqlens,
        max_s,
        max_s,
        0.0,
        softmax_scale,
        False,
        True,
        False,
        None,
    )


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
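    """Scatter the freshly computed key and value tensors into the paged KV
    cache at the cache positions given by `slots`."""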
    ipex.llm.modules.PagedAttention.reshape_and_cache(
        key, value, key_cache, value_cache, slots
    )


def paged_attention(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    input_lengths: torch.Tensor,
    max_s: int,
):
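    """Decode-time attention against the paged KV cache: each sequence
    contributes a single query token, keys and values are gathered from
    `key_cache` / `value_cache` through `block_tables` (pages of BLOCK_SIZE
    tokens), and `input_lengths` gives the number of cached tokens per
    sequence. `kv_head_mapping` maps query heads to their KV heads."""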
    return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
        out,
        query,
        key_cache,
        value_cache,
        kv_head_mapping,
        softmax_scale,
        block_tables,
        input_lengths,
        BLOCK_SIZE,
        max_s,
        None,
    )