xpu_executor.py 3.28 KB
Newer Older
1
from typing import List, Optional, Tuple, Union
2
3
4
5

import torch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
6
7
8
                         ModelConfig, ObservabilityConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig,
                         SpeculativeConfig)
9
10
11
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
12
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
13
14
15
16
17
18
19
from vllm.utils import make_async

# Module-level logger, created through vLLM's init_logger with this module's
# name; used below for the XPU configuration warnings.
logger = init_logger(__name__)


class XPUExecutor(GPUExecutor):
    """Executor that runs the model through an ``XPUWorker``.

    Builds on ``GPUExecutor`` but points worker creation at
    ``vllm.worker.xpu_worker.XPUWorker`` and rejects speculative decoding.
    """

    # This executor does not use Ray.
    uses_ray: bool = False

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        prompt_adapter_config: Optional[PromptAdapterConfig],
        speculative_config: Optional[SpeculativeConfig],
        observability_config: Optional[ObservabilityConfig],
    ) -> None:
        """Validate the configuration, store it, and initialize the executor.

        Args:
            model_config: Model settings; passed through
                ``_verify_and_get_model_config`` before being stored.
            cache_config: Stored as ``self.cache_config``.
            parallel_config: Stored as ``self.parallel_config``.
            scheduler_config: Stored as ``self.scheduler_config``.
            device_config: Must have ``device_type == "xpu"``.
            load_config: Stored as ``self.load_config``.
            lora_config: Stored as ``self.lora_config``.
            prompt_adapter_config: Stored as ``self.prompt_adapter_config``.
            speculative_config: Must be falsy; speculative decoding is not
                supported on this backend.
            observability_config: Stored as ``self.observability_config``.

        Raises:
            AssertionError: If the device type is not ``"xpu"`` or a
                speculative config is supplied.
        """
        assert device_config.device_type == "xpu"
        assert (not speculative_config
                ), "Speculative decoding not yet supported for XPU backend"

        # The helper adjusts dtype / eager-mode settings on the config object
        # itself and hands the same object back.
        model_config = _verify_and_get_model_config(model_config)

        self.model_config = model_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.lora_config = lora_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.prompt_adapter_config = prompt_adapter_config
        # Always None: the assertion above rejected any truthy value.
        self.speculative_config = None
        self.observability_config = observability_config

        # Instantiate the worker and load the model onto the device.
        self._init_executor()

    def _get_worker_module_and_class(self) -> Tuple[str, str]:
        """Return the module path and class name of the XPU worker.

        Raises:
            NotImplementedError: If ``self.speculative_config`` is set.
        """
        if self.speculative_config is not None:
            raise NotImplementedError(
                "XPU does not support speculative decoding")
        return ("vllm.worker.xpu_worker", "XPUWorker")

    def execute_model(
        self, execute_model_req: ExecuteModelRequest
    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
        """Forward the request to the driver worker and return its result."""
        return self.driver_worker.execute_model(execute_model_req)


class XPUExecutorAsync(XPUExecutor, ExecutorAsyncBase):
    """Async variant of ``XPUExecutor``."""

    async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        """Await the driver worker's ``execute_model`` and return its result.

        The worker method is wrapped with ``make_async`` and called with the
        request as a keyword argument.
        """
        async_execute = make_async(self.driver_worker.execute_model)
        return await async_execute(execute_model_req=execute_model_req)


def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
    """Adjust *config* in place for the XPU backend and return it.

    A bfloat16 dtype is replaced by float16, and eager mode is switched on
    when it is not already enforced. Each adjustment is logged as a warning.
    The returned object is the same one that was passed in.
    """
    if config.dtype == torch.bfloat16:
        bf16_notice = (
            "bfloat16 is not fully supported on XPU, casting to float16.")
        logger.warning(bf16_notice)
        config.dtype = torch.float16

    # Nothing further to change when eager mode is already enforced.
    if config.enforce_eager:
        return config

    eager_notice = ("CUDA graph is not supported on XPU, fallback to the eager "
                    "mode.")
    logger.warning(eager_notice)
    config.enforce_eager = True
    return config