# xpu.py 9.08 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from typing import TYPE_CHECKING, Optional
6

7
8
import torch

9
import vllm.envs as envs
10
from vllm.logger import init_logger
11
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
12
13
14

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

15
if TYPE_CHECKING:
    # Real config classes are only imported for static type checkers.
    from vllm.config import ModelConfig, VllmConfig
else:
    # At runtime the names are bound to None placeholders so that module-level
    # references to them (e.g. in annotations) still resolve.
    ModelConfig = None
    VllmConfig = None

# Module-level logger shared by the XPU platform hooks below.
logger = init_logger(__name__)
22
23
24
25


class XPUPlatform(Platform):
    """Platform hooks for Intel XPU devices.

    Maps vLLM's device-agnostic platform interface onto ``torch.xpu`` and
    adjusts the engine configuration (attention backend, KV-cache layout,
    executor backend, compilation settings) for the XPU stack.
    """

    _enum = PlatformEnum.XPU
    device_name: str = "xpu"
    device_type: str = "xpu"
    dispatch_key: str = "XPU"
    # Intel XPU's device key is "GPU" for Ray.
    # see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501
    ray_device_key: str = "GPU"
    dist_backend: str = "ccl"  # ccl | xccl
    device_control_env_var: str = "ZE_AFFINITY_MASK"

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool, use_mla: bool,
                             has_sink: bool) -> str:
        """Return the import path of the attention backend class to use.

        Only the V1 engine is supported. An explicitly selected Triton or
        Flash Attention backend is honored, any other explicit selection is
        rejected, and with no selection Flash Attention is used.

        Raises:
            ValueError: if the V1 engine is disabled, or if an unsupported
                backend was explicitly selected.
        """
        # NOTE(review): the caller-supplied ``use_v1`` is overridden by the
        # environment setting here — confirm this is intended.
        use_v1 = envs.VLLM_USE_V1
        if not use_v1:
            raise ValueError("XPU backend only supports V1.")

        TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
        FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
        if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
            logger.info_once("Using Triton backend on V1 engine.")
            return TRITON_ATTN_VLLM_V1
        if selected_backend == _Backend.FLASH_ATTN:
            logger.info_once("Using Flash Attention backend on V1 engine.")
            return FLASH_ATTN_V1
        if selected_backend:
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}, "
                f"with use_v1: {use_v1} use_mla: {use_mla}")

        # No explicit selection: default to Flash Attention (same path and
        # log call as the explicit FLASH_ATTN branch above).
        logger.info_once("Using Flash Attention backend on V1 engine.")
        return FLASH_ATTN_V1

    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                    model_config: "ModelConfig") -> bool:
        """Check whether ``kv_cache_dtype`` is supported on XPU.

        Returns True only for the fp8 variants ("fp8_e4m3", "fp8_e5m2",
        "fp8"), and only when the Triton V1 attention backend has been
        explicitly selected via ``VLLM_ATTENTION_BACKEND``. Every other
        combination returns False.
        """
        triton_selected = (
            envs.is_set("VLLM_ATTENTION_BACKEND")
            and envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN_VLLM_V1")
        if triton_selected:
            return kv_cache_dtype in ("fp8_e4m3", "fp8_e5m2", "fp8")
        return False

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Make ``device`` the current XPU device for this process."""
        torch.xpu.set_device(device)

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> Optional[DeviceCapability]:
        """Always return None.

        The XPU capability format differs from CUDA's and would cause
        unexpected failures in CUDA-style capability checks.
        """
        return None

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name reported by ``torch.xpu`` for ``device_id``."""
        return torch.xpu.get_device_name(device_id)

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the XPU LoRA Punica wrapper."""
        return "vllm.lora.punica_wrapper.punica_xpu.PunicaWrapperXPU"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory reported by the device properties."""
        device_props = torch.xpu.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def inference_mode(cls):
        """Return the context manager used for inference (``no_grad``)."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Validate and adjust ``vllm_config`` in place for XPU.

        The following changes are made:

        * A default KV-cache block size of 64 is set when it is unset.
        * Empty compile sizes are normalized to an empty list.
        * CUDA graph mode must be NONE.
        * Compilation is disabled when LoRA is configured.
        * The XPU worker is selected, and the executor backend is chosen or
          corrected.
        * Chunked prefill is disabled for MLA models.
        * The KV-cache layout is forced to NHD.
        """
        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config
        # in V1 (or with ipex chunked prefill) block_size is 64
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 64

        # lazy import to avoid circular import
        from vllm.config import CompilationLevel, CUDAGraphMode
        compilation_config = vllm_config.compilation_config
        if compilation_config.compile_sizes is None:
            compilation_config.compile_sizes = []

        # NOTE(review): ``assert`` is stripped under ``python -O``; kept to
        # preserve the AssertionError callers currently see.
        assert compilation_config.cudagraph_mode == CUDAGraphMode.NONE, \
            "CUDA graph mode should be NONE on XPU"

        if vllm_config.lora_config is not None:
            compilation_config.level = CompilationLevel.NO_COMPILATION

        # check and update parallel config
        parallel_config = vllm_config.parallel_config
        parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"

        backend = parallel_config.distributed_executor_backend
        if backend is None:
            # Multi-worker runs default to Ray, single-worker runs to "uni".
            parallel_config.distributed_executor_backend = (
                "ray" if parallel_config.world_size > 1 else "uni")
        elif backend == "mp":
            # FIXME(kunshang):
            # spawn needs calling `if __name__ == '__main__':`;
            # fork is not supported for xpu start new process.
            if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn":
                os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
                logger.warning(
                    "Please use spawn as start method if you want to use mp.")
        elif backend not in ("ray", "uni", "external_launcher"):
            logger.warning(
                "%s is not supported on XPU, fallback to ray distributed"
                " executor backend.", backend)
            parallel_config.distributed_executor_backend = "ray"

        if model_config and model_config.use_mla:
            # NOTE(review): the message mentions prefix caching, but only the
            # chunked-prefill settings are changed here — confirm.
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled.")
            scheduler_config = vllm_config.scheduler_config
            scheduler_config.enable_chunked_prefill = False
            scheduler_config.chunked_prefill_enabled = False
            # Without chunked prefill a whole prompt must fit in one batch.
            scheduler_config.max_num_batched_tokens = max(
                scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS)

        from vllm.v1.attention.backends.utils import set_kv_cache_layout

        set_kv_cache_layout("NHD")
        logger.info("Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
                    "only NHD layout is supported by XPU attention kernels.")

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """XPU supports the hybrid KV-cache manager."""
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """XPU does not support static graph mode."""
        return False

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Pinned host memory is reported as available on XPU."""
        return True

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return the memory currently allocated on ``device``.

        The peak statistics are reset first, so the reading reflects the
        current allocation rather than an earlier peak. As a side effect,
        any previously recorded peak is cleared.
        """
        torch.xpu.reset_peak_memory_stats(device)
        return torch.xpu.max_memory_allocated(device)

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return the fp8 dtype used on XPU (``float8_e5m2``)."""
        return torch.float8_e5m2

    @classmethod
    def is_data_center_gpu(cls) -> bool:
        """Return True if the current device name contains 'data center gpu'."""
        return "data center gpu" in cls.get_device_name().lower()

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the XPU device communicator."""
        return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa

    @classmethod
    def device_count(cls) -> int:
        """Return the number of visible XPU devices."""
        return torch.xpu.device_count()

    @classmethod
    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
        """Reject dtypes with known accuracy problems on the current device.

        Raises:
            ValueError: if ``torch_dtype`` is bfloat16 and the device is an
                Intel Arc A770.
        """
        if torch_dtype != torch.bfloat16:
            return
        # client gpu a770
        if "a770" in cls.get_device_name().lower():
            raise ValueError(
                "Intel Arc A770 have bfloat16 accuracy known issue. "
                "You can use float16 instead by explicitly setting the "
                "`dtype` flag in CLI, for example: --dtype=half.")

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Attention is exposed as an opaque op on XPU."""
        return True

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on XPU.

        Blocks are indexed along dimension 1 of both caches.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from XPU to host (CPU).

        Blocks are indexed along dimension 1 of both caches.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()