# xpu_executor.py
from typing import Callable, List, Optional, Tuple, Type, Union

from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase

# Logger for this module.
logger = init_logger(__name__)


class XPUExecutor(GPUExecutor):
    """Executor for the ``"xpu"`` device type.

    Reuses :class:`GPUExecutor` initialization and only swaps in the XPU
    worker (``vllm.worker.xpu_worker.XPUWorker``). Speculative decoding is
    rejected.
    """

    # Not a Ray-based executor.
    uses_ray: bool = False

    def _init_executor(self) -> None:
        """Validate the XPU configuration, then run GPUExecutor's setup.

        Raises:
            AssertionError: if the configured device type is not ``"xpu"``
                or a speculative decoding config is present.
        """
        # NOTE(review): these checks use ``assert`` and are skipped under
        # ``python -O``; the speculative case is still rejected later by
        # ``_get_worker_module_and_class``.
        assert self.device_config.device_type == "xpu"
        assert self.speculative_config is None, (
            "Speculative decoding not yet supported for XPU backend")

        # Explicit base-class call (rather than super()) kept as in the
        # original.
        GPUExecutor._init_executor(self)

    def _get_worker_module_and_class(
            self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
        """Return the worker module path, class name and optional factory.

        Returns:
            ``(module_name, class_name, class_fn)``; for XPU this is always
            ``("vllm.worker.xpu_worker", "XPUWorker", None)``.

        Raises:
            NotImplementedError: if a speculative decoding config is set.
        """
        if self.speculative_config is not None:
            raise NotImplementedError(
                "XPU does not support speculative decoding")
        worker_module_name = "vllm.worker.xpu_worker"
        worker_class_name = "XPUWorker"
        worker_class_fn = None
        return (worker_module_name, worker_class_name, worker_class_fn)

    def execute_model(
        self, execute_model_req: ExecuteModelRequest
    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
        """Forward ``execute_model_req`` to the driver worker.

        Returns:
            Whatever ``self.driver_worker.execute_model`` returns, unchanged.
        """
        output = self.driver_worker.execute_model(execute_model_req)
        return output


class XPUExecutorAsync(XPUExecutor, ExecutorAsyncBase):
    """Async variant of :class:`XPUExecutor`."""

    async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
        """Awaitable counterpart of :meth:`XPUExecutor.execute_model`.

        The synchronous ``driver_worker.execute_model`` is wrapped with
        ``vllm.utils.make_async`` and awaited. The return annotation matches
        the synchronous ``execute_model``, since both return the driver
        worker's output unchanged.
        """
        run_model = make_async(self.driver_worker.execute_model)
        output = await run_model(execute_model_req=execute_model_req)
        return output