platform.py 2.54 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
from vllm.logger import init_logger
from vllm.platforms.xpu import XPUPlatform

from vllm_omni.diffusion.attention.backends.registry import DiffusionAttentionBackendEnum
from vllm_omni.platforms.interface import OmniPlatform, OmniPlatformEnum

logger = init_logger(__name__)


class XPUOmniPlatform(OmniPlatform, XPUPlatform):
    """XPU/Intel GPU implementation of OmniPlatform.

    Inherits all XPU-specific implementations from vLLM's XPUPlatform,
    and adds Omni-specific interfaces from OmniPlatform.
    """

    _omni_enum = OmniPlatformEnum.XPU

    @classmethod
    def get_omni_ar_worker_cls(cls) -> str:
        """Return the dotted import path of the XPU autoregressive worker."""
        return "vllm_omni.platforms.xpu.worker.xpu_ar_worker.XPUARWorker"

    @classmethod
    def get_omni_generation_worker_cls(cls) -> str:
        """Return the dotted import path of the XPU generation worker."""
        return "vllm_omni.platforms.xpu.worker.xpu_generation_worker.XPUGenerationWorker"

    @classmethod
    def get_diffusion_attn_backend_cls(
        cls,
        selected_backend: str | None,
        head_size: int,
    ) -> str:
        """Resolve the diffusion attention backend class path.

        Args:
            selected_backend: Case-insensitive backend name; ``None`` selects
                the default (TORCH_SDPA).
            head_size: Attention head size. Currently unused on XPU, kept for
                interface parity with other platforms.

        Returns:
            Dotted import path of the chosen backend class.

        Raises:
            KeyError: If ``selected_backend`` does not name a member of
                ``DiffusionAttentionBackendEnum``.
        """
        if selected_backend is not None:
            backend_upper = selected_backend.upper()
            # Raises KeyError for unknown names; callers see the raw enum error.
            backend = DiffusionAttentionBackendEnum[backend_upper]
            logger.info("Using diffusion attention backend '%s'", backend_upper)
            return backend.get_path()

        logger.info("Defaulting to diffusion attention backend SDPA")
        return DiffusionAttentionBackendEnum.TORCH_SDPA.get_path()

    @classmethod
    def supports_torch_inductor(cls) -> bool:
        """XPU supports torch.compile via the Inductor backend."""
        return True

    @classmethod
    def get_default_stage_config_path(cls) -> str:
        """Return the package-relative directory holding XPU stage configs."""
        return "vllm_omni/platforms/xpu/stage_configs"

    @classmethod
    def get_torch_device(cls, local_rank: int | None = None) -> torch.device:
        """Build a ``torch.device`` for the XPU backend.

        Args:
            local_rank: Device ordinal; ``None`` yields the unindexed
                ``"xpu"`` device (current device).
        """
        if local_rank is None:
            return torch.device("xpu")
        return torch.device("xpu", local_rank)

    @classmethod
    def get_device_count(cls) -> int:
        """Return the number of visible XPU devices."""
        return torch.xpu.device_count()

    @classmethod
    def get_device_version(cls) -> str | None:
        # XPU does not have a version string like CUDA
        return None

    @classmethod
    def synchronize(cls) -> None:
        """Block until all queued work on the current XPU device completes."""
        torch.xpu.synchronize()

    @classmethod
    def get_free_memory(cls, device: torch.device | None = None) -> int:
        """Return the number of bytes of *free* memory on ``device``.

        Args:
            device: Target device; ``None`` or a device without an index
                falls back to device 0.

        Returns:
            Free device memory in bytes.
        """
        if device is None:
            device_id = 0
        else:
            device_id = device.index if device.index is not None else 0
        # BUG FIX: the previous implementation returned
        # ``get_device_properties(device_id).total_memory`` — the device's
        # total capacity, not the free amount — which overstates availability
        # on a busy device. ``torch.xpu.mem_get_info`` (PyTorch 2.6+) returns
        # ``(free, total)`` analogously to ``torch.cuda.mem_get_info``.
        free_bytes, _total_bytes = torch.xpu.mem_get_info(device_id)
        return free_bytes