xpu.py 9.04 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
import os
6
from typing import TYPE_CHECKING
7

8
9
import torch

10
import vllm.envs as envs
11
from vllm.attention.backends.registry import AttentionBackendEnum
12
13
from vllm.logger import init_logger

14
from .interface import DeviceCapability, Platform, PlatformEnum
15

16
if TYPE_CHECKING:
17
    from vllm.config import VllmConfig
18
19
20
else:
    VllmConfig = None

21
logger = init_logger(__name__)
22
23
24
25


class XPUPlatform(Platform):
    _enum = PlatformEnum.XPU
26
    device_name: str = "xpu"
27
    device_type: str = "xpu"
28
    dispatch_key: str = "XPU"
29
30
31
    # Intel XPU's device key is "GPU" for Ray.
    # see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501
    ray_device_key: str = "GPU"
32
    dist_backend: str = "ccl"  # ccl | xccl
33
    device_control_env_var: str = "ZE_AFFINITY_MASK"
34

35
    @classmethod
36
37
38
39
    def import_kernels(cls) -> None:
        # Do not import vllm._C
        with contextlib.suppress(ImportError):
            import vllm._moe_C  # noqa: F401
40

41
    @classmethod
42
43
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: str | None,
        block_size: int,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        attn_type: str | None = None,
    ) -> str:
        """Resolve the import path of the attention backend to use on XPU.

        As a side effect the KV cache layout is forced to ``"NHD"`` before
        any backend is chosen.

        Args:
            selected_backend: Backend requested by the caller. A falsy value
                means "no preference" and selects Flash Attention.
            head_size: Not consulted by this implementation.
            dtype: Not consulted by this implementation.
            kv_cache_dtype: Not consulted by this implementation.
            block_size: Not consulted by this implementation.
            use_mla: Only reported in the error message for an invalid
                backend.
            has_sink: Not consulted by this implementation.
            use_sparse: Must be falsy; sparse attention is rejected.
            attn_type: Not consulted by this implementation.

        Returns:
            The value of ``get_path()`` for ``TRITON_ATTN`` when it was
            requested, otherwise for ``FLASH_ATTN``.

        Raises:
            NotImplementedError: If ``use_sparse`` is truthy.
            ValueError: If ``selected_backend`` is set to anything other than
                ``TRITON_ATTN`` or ``FLASH_ATTN``.
        """
        from vllm.v1.attention.backends.utils import set_kv_cache_layout

        set_kv_cache_layout("NHD")
        # info_once: this method can be invoked repeatedly and the notice
        # carries no per-call information.
        logger.info_once(
            "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
            "only NHD layout is supported by XPU attention kernels."
        )

        if use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on XPU.")

        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
            logger.info_once("Using Triton backend.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        if selected_backend and selected_backend != AttentionBackendEnum.FLASH_ATTN:
            # Name the rejected backend so a misconfiguration can be diagnosed
            # from the message alone.
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}: "
                f"{selected_backend}, with use_mla: {use_mla}"
            )

        # Either FLASH_ATTN was requested explicitly or nothing was selected;
        # both cases resolve to Flash Attention.
        logger.info_once("Using Flash Attention backend.")
        return AttentionBackendEnum.FLASH_ATTN.get_path()
78

79
80
81
82
83
84
85
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Make ``device`` the current XPU device via ``torch.xpu.set_device``."""
        torch.xpu.set_device(device)

86
    @classmethod
87
    def get_device_capability(
88
89
        cls,
        device_id: int = 0,
90
    ) -> DeviceCapability | None:
91
92
93
        # capacity format differs from cuda's and will cause unexpected
        # failure, so use None directly
        return None
94

95
96
    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name PyTorch reports for XPU device ``device_id``."""
        name = torch.xpu.get_device_name(device_id)
        return name
98

99
100
    @classmethod
    def get_punica_wrapper(cls) -> str:
101
102
103
104
105
        xpu_use_triton_kernel = os.getenv("XPU_USE_TRITON_KERNEL", "0") == "1"
        if not xpu_use_triton_kernel:
            return "vllm.lora.punica_wrapper.punica_xpu.PunicaWrapperXPU"
        else:
            return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
106

107
108
109
110
    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the properties of XPU ``device_id``."""
        return torch.xpu.get_device_properties(device_id).total_memory
111

112
    @classmethod
113
114
    def get_vit_attn_backend(
        cls, head_size: int, dtype: torch.dtype
    ) -> "AttentionBackendEnum":
        """Return the ViT attention backend, which is always ``FLASH_ATTN``.

        ``head_size`` and ``dtype`` are accepted but not consulted.
        """
        backend = AttentionBackendEnum.FLASH_ATTN
        return backend
117

118
119
    @classmethod
    def inference_mode(cls):
120
        return torch.no_grad()
121
122
123

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adjust ``vllm_config`` in place for running on XPU.

        Steps, in order:

        * default ``cache_config.block_size`` to 64 when it is ``None``;
        * default ``compilation_config.compile_sizes`` to an empty list and
          assert that the CUDA graph mode is ``NONE``;
        * set the compilation mode to ``CompilationMode.NONE`` when a LoRA
          config is present;
        * point ``parallel_config.worker_cls`` at the XPU worker and, if a
          KV-transfer config exists, set ``enable_permute_local_kv``;
        * choose or correct ``parallel_config.distributed_executor_backend``;
        * for MLA models, disable chunked prefill and set
          ``max_num_batched_tokens`` to the larger of ``max_model_len`` and
          the scheduler default.

        Raises:
            AssertionError: If ``cudagraph_mode`` is not ``CUDAGraphMode.NONE``.
        """
        cache_cfg = vllm_config.cache_config
        model_cfg = vllm_config.model_config

        # in V1 (or with ipex chunked prefill) block_size is 64
        if cache_cfg and cache_cfg.block_size is None:
            cache_cfg.block_size = 64

        # lazy import to avoid circular import
        from vllm.config import CompilationMode, CUDAGraphMode

        compile_cfg = vllm_config.compilation_config
        if compile_cfg.compile_sizes is None:
            compile_cfg.compile_sizes = []

        assert compile_cfg.cudagraph_mode == CUDAGraphMode.NONE, (
            "CUDA graph mode should be NONE on XPU"
        )

        if vllm_config.lora_config is not None:
            compile_cfg.mode = CompilationMode.NONE

        # check and update parallel config
        parallel_cfg = vllm_config.parallel_config
        parallel_cfg.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"

        kv_transfer_cfg = vllm_config.kv_transfer_config
        if kv_transfer_cfg is not None:
            kv_transfer_cfg.enable_permute_local_kv = True

        backend = parallel_cfg.distributed_executor_backend
        if backend is None:
            # No explicit choice: multi-worker runs use ray, single-worker
            # runs use the uni executor.
            parallel_cfg.distributed_executor_backend = (
                "ray" if parallel_cfg.world_size > 1 else "uni"
            )
        elif backend == "mp":
            # FIXME(kunshang):
            # spawn needs calling `if __name__ == '__main__':`
            # fork is not supported for xpu start new process.
            if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn":
                os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
                logger.warning(
                    "Please use spawn as start method if you want to use mp."
                )
        elif backend not in ("ray", "uni", "external_launcher"):
            logger.warning(
                "%s is not supported on XPU, fallback to ray distributed"
                " executor backend.",
                backend,
            )
            parallel_cfg.distributed_executor_backend = "ray"

        if model_cfg and model_cfg.use_mla:
            # NOTE(review): the message mentions prefix caching, but only
            # chunked prefill is changed below -- confirm the intent.
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled."
            )
            scheduler_cfg = vllm_config.scheduler_config
            scheduler_cfg.enable_chunked_prefill = False
            scheduler_cfg.max_num_batched_tokens = max(
                vllm_config.model_config.max_model_len,
                scheduler_cfg.DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )
186

187
188
189
190
    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True

191
192
193
194
    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return False

195
196
    @classmethod
    def is_pin_memory_available(cls):
197
        return True
198
199

    @classmethod
200
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Reset the XPU peak-memory statistics, then return the peak.

        ``torch.xpu.reset_peak_memory_stats`` is called first, and the value
        of ``torch.xpu.max_memory_allocated`` read right after it is returned.
        """
        torch.xpu.reset_peak_memory_stats(device)
        peak_allocated = torch.xpu.max_memory_allocated(device)
        return peak_allocated
205

206
207
208
209
    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        return torch.float8_e5m2

210
211
212
213
214
    @classmethod
    def is_data_center_gpu(cls) -> bool:
        device_name = cls.get_device_name().lower()
        return device_name.count("data center gpu") > 0

215
216
217
    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa
218
219
220

    @classmethod
    def device_count(cls) -> int:
221
        return torch.xpu.device_count()
222
223

    @classmethod
224
225
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        if dtype == torch.bfloat16:  # noqa: SIM102
226
227
228
229
230
231
            device_name = cls.get_device_name().lower()
            # client gpu a770
            if device_name.count("a770") > 0:
                raise ValueError(
                    "Intel Arc A770 have bfloat16 accuracy known issue. "
                    "You can use float16 instead by explicitly setting the "
232
233
                    "`dtype` flag in CLI, for example: --dtype=half."
                )
234
235
236
237

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on XPU."""
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from XPU to host (CPU)."""
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()