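# NOTE(review): the following lines are residue from the web page this file
# was copied from (a repository file view), not part of the module itself:
#   "vllm/model_executor/layers/attention.py" did not exist on "3e9f991d6acd7efd90f04f1f530b837a40c93442"
#   xpu.py 9.8 KB
#   Newer Older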
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
import os
6
from typing import TYPE_CHECKING, Optional
7

8
9
import torch

10
import vllm.envs as envs
11
from vllm.attention.backends.registry import AttentionBackendEnum
12
13
from vllm.logger import init_logger

14
from .interface import DeviceCapability, Platform, PlatformEnum
15

16
if TYPE_CHECKING:
    from vllm.attention.selector import AttentionSelectorConfig
    from vllm.config import VllmConfig
else:
    # ``VllmConfig`` appears unquoted in the ``check_and_update_config``
    # annotation, so the name must exist at runtime. A placeholder is bound
    # instead of importing ``vllm.config`` here (the method itself imports
    # from ``vllm.config`` lazily "to avoid circular import").
    VllmConfig = None

logger = init_logger(__name__)
class XPUPlatform(Platform):
    """Platform integration for Intel GPUs driven through ``torch.xpu``.

    Chooses the attention backends usable on XPU, adjusts the engine
    configuration for features XPU does not support, and wraps ``torch.xpu``
    device and memory queries.
    """

    _enum = PlatformEnum.XPU
    device_name: str = "xpu"
    device_type: str = "xpu"
    dispatch_key: str = "XPU"
    # Intel XPU's device key is "GPU" for Ray.
    # see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501
    ray_device_key: str = "GPU"
    dist_backend: str = "ccl"  # ccl | xccl
    device_control_env_var: str = "ZE_AFFINITY_MASK"

    @classmethod
    def import_kernels(cls) -> None:
        """Import the optional compiled kernel extensions used on XPU.

        ``vllm._C`` is deliberately not imported. ``vllm._moe_C`` is imported
        when available and silently skipped if the build does not provide it.
        """
        # Do not import vllm._C
        with contextlib.suppress(ImportError):
            import vllm._moe_C  # noqa: F401

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
    ) -> str:
        """Return the import path of the attention backend to use on XPU.

        As a side effect, the KV-cache layout is pinned to ``"NHD"``.

        Args:
            selected_backend: Backend explicitly requested, or a falsy value
                (e.g. ``None``) to fall back to the XPU default.
            attn_selector_config: Attention requirements; ``use_sparse`` and
                ``use_mla`` are consulted here.

        Returns:
            Dotted path of the backend class. Only ``TRITON_ATTN`` and
            ``FLASH_ATTN`` are accepted; ``FLASH_ATTN`` is the default.

        Raises:
            NotImplementedError: If sparse attention is requested.
            ValueError: If any other backend is explicitly requested.
        """
        from vllm.v1.attention.backends.utils import set_kv_cache_layout

        # XPU attention kernels only support the NHD layout, so force it
        # before any backend is returned. Logged once: this method can be
        # invoked repeatedly and the message never changes.
        set_kv_cache_layout("NHD")
        logger.info_once(
            "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
            "only NHD layout is supported by XPU attention kernels."
        )

        if attn_selector_config.use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on XPU.")
        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
            logger.info_once("Using Triton backend.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()
        elif selected_backend == AttentionBackendEnum.FLASH_ATTN:
            logger.info_once("Using Flash Attention backend.")
            return AttentionBackendEnum.FLASH_ATTN.get_path()
        elif selected_backend:
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}, "
                f"with use_mla: {attn_selector_config.use_mla}"
            )

        # No explicit choice: default to Flash Attention (logged once, like
        # the explicit FLASH_ATTN branch above).
        logger.info_once("Using Flash Attention backend.")
        return AttentionBackendEnum.FLASH_ATTN.get_path()

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """Return the attention backends usable for vision (ViT) attention."""
        # XPU only supports FLASH_ATTN for vision attention.
        return [
            AttentionBackendEnum.FLASH_ATTN,
        ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: Optional["AttentionBackendEnum"] = None,
    ) -> "AttentionBackendEnum":
        """Pick the vision-attention backend.

        Args:
            head_size: Attention head size (not consulted on XPU).
            dtype: Compute dtype (not consulted on XPU).
            backend: Explicitly requested backend, or ``None`` for the
                default (``FLASH_ATTN``).

        Returns:
            The requested backend if it is supported, otherwise the default.

        Raises:
            AssertionError: If ``backend`` is given but unsupported. Note that
                this check is skipped when Python runs with ``-O``.
        """
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: "
                f"{cls.get_supported_vit_attn_backends()}."
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        logger.info_once(
            f"Using backend {AttentionBackendEnum.FLASH_ATTN} for vit attention"
        )
        return AttentionBackendEnum.FLASH_ATTN

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.xpu.set_device(device)

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> DeviceCapability | None:
        """Always return ``None`` on XPU."""
        # The XPU capability format differs from CUDA's and would cause
        # unexpected failures in capability checks, so use None directly.
        return None

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name ``torch.xpu`` reports for device ``device_id``."""
        return torch.xpu.get_device_name(device_id)

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the LoRA Punica wrapper to use.

        Setting ``XPU_USE_TRITON_KERNEL=1`` in the environment selects the
        Triton-based GPU wrapper. Any other value, or leaving it unset,
        selects the XPU wrapper.
        """
        xpu_use_triton_kernel = os.getenv("XPU_USE_TRITON_KERNEL", "0") == "1"
        if not xpu_use_triton_kernel:
            return "vllm.lora.punica_wrapper.punica_xpu.PunicaWrapperXPU"
        else:
            return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of device ``device_id`` as reported by torch."""
        device_props = torch.xpu.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def inference_mode(cls):
        """Return the context manager used for inference (``torch.no_grad``)."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate and adjust ``vllm_config`` in place for XPU.

        - Defaults the KV-cache block size to 64 when unset.
        - Requires CUDA graphs to be disabled, and disables compilation when
          LoRA is enabled.
        - Selects the XPU worker and a supported distributed executor backend.
        - Disables chunked prefill for MLA models.

        Raises:
            AssertionError: If the CUDA graph mode is not ``NONE``.
        """
        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config
        # in V1(or with ipex chunked prefill) block_size is 64
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 64

        # lazy import to avoid circular import
        from vllm.config import CompilationMode, CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        if compilation_config.compile_sizes is None:
            compilation_config.compile_sizes = []

        assert compilation_config.cudagraph_mode == CUDAGraphMode.NONE, (
            "CUDA graph mode should be NONE on XPU"
        )

        if vllm_config.lora_config is not None:
            compilation_config.mode = CompilationMode.NONE

        # check and update parallel config
        parallel_config = vllm_config.parallel_config
        parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
        if vllm_config.kv_transfer_config is not None:
            vllm_config.kv_transfer_config.enable_permute_local_kv = True

        if parallel_config.distributed_executor_backend is None:
            if parallel_config.world_size > 1:
                parallel_config.distributed_executor_backend = "ray"
            else:
                parallel_config.distributed_executor_backend = "uni"
        elif parallel_config.distributed_executor_backend == "mp":
            # FIXME(kunshang):
            # spawn needs calling `if __name__ == '__main__':`
            # fork is not supported for xpu start new process.
            if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn":
                os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
                logger.warning(
                    "Please use spawn as start method if you want to use mp."
                )
        elif parallel_config.distributed_executor_backend not in (
            "ray",
            "uni",
            "external_launcher",
        ):
            logger.warning(
                "%s is not supported on XPU, fallback to ray distributed"
                " executor backend.",
                parallel_config.distributed_executor_backend,
            )
            parallel_config.distributed_executor_backend = "ray"

        if model_config and model_config.use_mla:
            # NOTE(review): the previous message also claimed prefix caching
            # was disabled, but nothing below touches it. If that is the
            # intent, disable it explicitly on cache_config.
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill to be disabled."
            )
            vllm_config.scheduler_config.enable_chunked_prefill = False
            # Without chunked prefill a whole prompt must fit in one batch,
            # so the token budget is raised to at least max_model_len.
            vllm_config.scheduler_config.max_num_batched_tokens = max(
                model_config.max_model_len,
                vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Return True: hybrid KV cache is supported on XPU."""
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Return False: static graph mode is not supported on XPU."""
        return False

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Return True: pinned host memory is available on XPU."""
        return True

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return the memory currently allocated by PyTorch on ``device``.

        The peak statistics are reset first, so ``max_memory_allocated``
        reports the current allocation rather than a historical peak. Any
        peak recorded earlier for the device is discarded as a side effect.
        """
        torch.xpu.reset_peak_memory_stats(device)
        return torch.xpu.max_memory_allocated(device)

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return the FP8 dtype used on XPU (``float8_e5m2``)."""
        return torch.float8_e5m2

    @classmethod
    def is_data_center_gpu(cls) -> bool:
        """Return True if the device name contains "data center gpu" (any case)."""
        return "data center gpu" in cls.get_device_name().lower()

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the XPU device communicator."""
        return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa

    @classmethod
    def device_count(cls) -> int:
        """Return the number of XPU devices visible to torch."""
        return torch.xpu.device_count()

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype) -> None:
        """Reject dtypes with known problems on the current device.

        Raises:
            ValueError: If ``dtype`` is bfloat16 and the device is an Intel
                Arc A770 (client GPU), which has known bfloat16 accuracy
                issues.
        """
        # client gpu a770; the device name is only queried for bfloat16
        if dtype == torch.bfloat16 and "a770" in cls.get_device_name().lower():
            raise ValueError(
                "Intel Arc A770 have bfloat16 accuracy known issue. "
                "You can use float16 instead by explicitly setting the "
                "`dtype` flag in CLI, for example: --dtype=half."
            )

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Return True on XPU; the flag's meaning is defined by ``Platform``."""
        return True

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on XPU.

        Blocks are selected along dimension 1 of both caches, and the
        gathered data is moved to ``dst_cache``'s device before the write.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from XPU to host (CPU).

        Blocks are selected along dimension 1 of both caches.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()