"vllm/vscode:/vscode.git/clone" did not exist on "cbdc3a13fea673a9f46538ddba108c56515efb94"
_xpu_ops.py 5.13 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING
5
6

import torch
7
from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func
8
9
10
11
12

from vllm.logger import init_logger

logger = init_logger(__name__)

13
if TYPE_CHECKING:
14

15
16
17
18
19
20
21
    def register_fake(fn):
        return lambda name: fn
else:
    try:
        from torch.library import register_fake
    except ImportError:
        from torch.library import impl_abstract as register_fake
22

23
# Only register the fake kernel when the `_xpu_C` extension exposes the op.
if hasattr(torch.ops._xpu_C, "fp8_gemm_w8a16"):

    @register_fake("_xpu_C::fp8_gemm_w8a16")
    def _fp8_gemm_w8a16_fake(
        input: torch.Tensor,
        q_weight: torch.Tensor,
        weight_scale: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Fake (shape-only) implementation of ``_xpu_C::fp8_gemm_w8a16``.

        All leading dimensions of ``input`` are collapsed into one; the result
        is an uninitialized 2-D tensor with that many rows and
        ``q_weight.size(1)`` columns, using ``input``'s dtype and device.
        ``weight_scale`` and ``bias`` do not influence the returned tensor.
        """
        hidden = input.shape[-1]
        rows = input.view(-1, hidden).shape[0]
        cols = q_weight.shape[1]
        return torch.empty((rows, cols), dtype=input.dtype, device=input.device)
36
37


38
# Only register the fake kernel when the `_xpu_C` extension exposes the op.
if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"):

    @register_fake("_xpu_C::int4_gemm_w4a16")
    def _int4_gemm_w4a16_fake(
        input: torch.Tensor,
        q_weight: torch.Tensor,
        bias: torch.Tensor | None,
        weight_scale: torch.Tensor,
        qzeros: torch.Tensor,
        group_size: int,
        group_idx: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Fake (shape-only) implementation of ``_xpu_C::int4_gemm_w4a16``.

        All leading dimensions of ``input`` are collapsed into one; the result
        is an uninitialized 2-D tensor with that many rows and
        ``q_weight.size(1)`` columns, using ``input``'s dtype and device.
        The remaining arguments do not influence the returned tensor.
        """
        hidden = input.shape[-1]
        rows = input.view(-1, hidden).shape[0]
        cols = q_weight.shape[1]
        return torch.empty((rows, cols), dtype=input.dtype, device=input.device)
54

55

56
class xpu_ops:
57
58
59
60
61
62
63
64
    @staticmethod
    def flash_attn_varlen_func(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
65
66
67
68
69
        softmax_scale: float | None = None,
        causal: bool = False,
        out: torch.Tensor | None = None,
        block_table: torch.Tensor | None = None,
        alibi_slopes: torch.Tensor | None = None,
70
71
        window_size: list[int] | None = None,
        softcap: float | None = 0.0,
72
        seqused_k: torch.Tensor | None = None,
73
        cu_seqlens_k: torch.Tensor | None = None,
74
75
        # passed in qwen vl
        dropout_p: float = 0.0,
76
        # The following parameters are not used in xpu kernel currently,
77
78
79
80
81
82
        # we keep API compatible to CUDA's.
        scheduler_metadata=None,
        fa_version: int = 2,
        q_descale=None,
        k_descale=None,
        v_descale=None,
83
        num_splits=0,
84
        return_softmax_lse: bool | None = False,
85
        s_aux: torch.Tensor | None = None,
86
    ):
87
88
89
90
91
92
93
94
95
96
97
98
        assert cu_seqlens_k is not None or seqused_k is not None, (
            "cu_seqlens_k or seqused_k must be provided"
        )
        assert cu_seqlens_k is None or seqused_k is None, (
            "cu_seqlens_k and seqused_k cannot be provided at the same time"
        )
        assert block_table is None or seqused_k is not None, (
            "when enable block_table, seqused_k is needed"
        )
        assert block_table is not None or cu_seqlens_k is not None, (
            "when block_table is disabled, cu_seqlens_k is needed"
        )
99
100
        if out is None:
            out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
101
102
103
104
105
        real_window_size: tuple[int, int]
        if window_size is None:
            real_window_size = (-1, -1)
        else:
            assert len(window_size) == 2
106
            real_window_size = (window_size[0], window_size[1])  # noqa: F841
107

108
109
        # In encode attention, v maybe not contiguous and current
        # kernel can't handle it
110
        if block_table is None:
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
            v = v.contiguous()
        return flash_attn_varlen_func(
            out=out,
            q=q.contiguous(),
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            seqused_k=seqused_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=softmax_scale,
            causal=causal,
            block_table=block_table,
            s_aux=s_aux,
            window_size=real_window_size,
            # alibi_slopes = alibi_slopes,
            # softcap=softcap,
            return_softmax_lse=return_softmax_lse,
        )
131
132
133

    @staticmethod
    def get_scheduler_metadata(
134
135
136
137
138
139
140
141
142
        batch_size,
        max_seqlen_q,
        max_seqlen_k,
        num_heads_q,
        num_heads_kv,
        headdim,
        cache_seqlens: torch.Tensor,
        qkv_dtype=torch.bfloat16,
        headdim_v=None,
143
144
145
146
        cu_seqlens_q: torch.Tensor | None = None,
        cu_seqlens_k_new: torch.Tensor | None = None,
        cache_leftpad: torch.Tensor | None = None,
        page_size: int | None = None,
147
148
149
150
151
152
153
        max_seqlen_k_new=0,
        causal=False,
        window_size=(-1, -1),  # -1 means infinite context window
        has_softcap=False,
        num_splits=0,  # Can be tuned for speed
        pack_gqa=None,  # Can be tuned for speed
        sm_margin=0,  # Can be tuned if some SMs are used for communication
154
155
    ) -> None:
        logger.warning_once(
156
            "get_scheduler_metadata is not implemented for xpu_ops, returning None."
157
        )
158
        return None