# openvino_executor.py (7.73 KB)
# (code-viewer page residue: "Newer Older")
1
2
3
4
5
6
7
8
9
10
11
from typing import List, Set, Tuple

import openvino as ov
import openvino.properties.hint as hints
import torch

import vllm.envs as envs
from vllm.config import CacheConfig, ModelConfig
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
12
from vllm.model_executor.layers.sampler import SamplerOutput
13
from vllm.platforms import current_platform
14
from vllm.sequence import ExecuteModelRequest
15
16
from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
                        get_open_port, make_async)
17
from vllm.worker.worker_base import WorkerWrapperBase
18
19
20
21
22
23

logger = init_logger(__name__)


class OpenVINOExecutor(ExecutorBase):
    """Executor that runs a model through OpenVINO in the current process.

    A single driver worker (rank 0) is created without Ray, and every
    executor call is forwarded to it directly.
    """

    uses_ray: bool = False

    def _init_executor(self) -> None:
        """Validate the configuration and create the driver worker.

        Raises:
            AssertionError: if the device type is not ``"openvino"``, a LoRA
                config is set, or the current platform is neither an
                OpenVINO CPU nor an OpenVINO GPU.
        """
        assert self.device_config.device_type == "openvino"
        assert self.lora_config is None, "OpenVINO backend doesn't support LoRA"
        assert current_platform.is_openvino_cpu() or \
            current_platform.is_openvino_gpu(), \
            "OpenVINO backend supports only CPU and GPU devices"

        self.ov_core = ov.Core()
        # Both helpers adjust the configs in place for the OpenVINO backend
        # (dtype / eager mode, KV cache dtype / block size / cache space).
        self.model_config = _verify_and_get_model_config(self.model_config)
        self.cache_config = _verify_and_get_cache_config(
            self.ov_core, self.cache_config)

        # Instantiate the driver worker and load the model.
        self._init_worker()

    def _init_worker(self) -> None:
        """Create the single driver worker (rank 0) and load the model."""
        wrapper = WorkerWrapperBase(vllm_config=self.vllm_config)

        # Init method built from this host's IP and an open port.
        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
        self.driver_worker = wrapper.init_worker(
            ov_core=self.ov_core,
            vllm_config=self.vllm_config,
            local_rank=0,
            rank=0,
            distributed_init_method=distributed_init_method,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=True,
        )
        self.driver_worker.init_device()
        self.driver_worker.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker.
        """
        return self.driver_worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache by invoking the underlying worker."""
        # NOTE: We log here to avoid multiple logs when number of workers is
        # greater than one. We could log in the engine, but not all executors
        # have GPUs.
        # NOTE: In case of a CPU device, `cpu block` for OpenVINO backend
        # is located on CPU memory but is referred as `gpu block`.
        # Because we want to reuse the existing block management procedure.
        device_blocks = num_gpu_blocks
        swap_blocks = num_cpu_blocks
        logger.info("OpenVINO %s: # device blocks: %d; # swap blocks: %d",
                    envs.VLLM_OPENVINO_DEVICE, device_blocks, swap_blocks)
        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Run one model step on the driver worker and return its output."""
        output = self.driver_worker.execute_model(execute_model_req)
        return output

    # The LoRA methods below delegate to the worker. ``_init_executor``
    # asserts that no LoRA config is set for this backend.
    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.driver_worker.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.driver_worker.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        return self.driver_worker.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.driver_worker.list_loras()

    # Prompt adapters (soft prompts) are not supported: every method raises.
    def add_prompt_adapter(self, prompt_adapter_request) -> bool:
        raise NotImplementedError(
            "Soft prompt is currently not supported by the OPENVINO backend.")

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        raise NotImplementedError(
            "Soft prompt is currently not supported by the OPENVINO backend.")

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        raise NotImplementedError(
            "Soft prompt is currently not supported by the OPENVINO backend.")

    def list_prompt_adapters(self) -> Set[int]:
        raise NotImplementedError(
            "Soft prompt is currently not supported by the OPENVINO backend.")

    def check_health(self) -> None:
        # OpenVINOExecutor will always be healthy as long as
        # it's running.
        return


class OpenVINOExecutorAsync(OpenVINOExecutor, ExecutorAsyncBase):
    """Asyncio-facing variant of the in-process OpenVINO executor."""

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Await the driver worker's ``execute_model`` call.

        The synchronous worker method is wrapped with
        ``vllm.utils.make_async`` and the resulting coroutine is awaited.
        """
        async_execute = make_async(self.driver_worker.execute_model)
        return await async_execute(execute_model_req=execute_model_req)

    async def check_health_async(self) -> None:
        """Always succeeds: the executor is healthy while it is running."""
        return None


def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
    """Adjust ``config`` in place to settings the OpenVINO backend supports.

    ``dtype`` is forced to ``torch.float32`` and ``enforce_eager`` to True.
    A warning is logged for each setting that had to be changed.

    Args:
        config: Model config to check and update.

    Returns:
        The same ``config`` object, mutated.
    """
    if config.dtype != torch.float32:
        # Lazy %-args instead of an f-string: same text, formatted only
        # when the record is actually emitted.
        logger.warning(
            "Only float32 dtype is supported on OpenVINO, casting from %s.",
            config.dtype)
        config.dtype = torch.float32
    if not config.enforce_eager:
        logger.warning(
            "CUDA graph is not supported on OpenVINO backend, fallback to the "
            "eager mode.")
        config.enforce_eager = True
    return config


149
150
def _verify_and_get_cache_config(ov_core: ov.Core,
                                 config: CacheConfig) -> CacheConfig:
    """Adjust ``config`` in place for the OpenVINO backend.

    Three fields are set:

    * ``cache_dtype``: ``ov.Type.u8`` on CPU when
      ``VLLM_OPENVINO_CPU_KV_CACHE_PRECISION`` is ``"u8"``. Otherwise it is
      ``ov.Type.bf16`` on CPU when the device's inference-precision hint is
      bf16, and ``ov.Type.f16`` in every other case (always on GPU).
    * ``block_size``: 32 on CPU, 16 otherwise (GPU).
    * ``openvino_kvcache_space_bytes``: ``VLLM_OPENVINO_KVCACHE_SPACE`` GiB,
      or 4 GiB on CPU when that variable is 0.

    Args:
        ov_core: OpenVINO core used to query the device's inference
            precision hint.
        config: Cache config to update.

    Returns:
        The same ``config`` object, mutated.

    Raises:
        RuntimeError: if ``VLLM_OPENVINO_KVCACHE_SPACE`` is negative.
    """
    is_cpu = current_platform.is_openvino_cpu()

    # --- KV cache element type ---------------------------------------------
    if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
        if is_cpu:
            logger.info("KV cache type is overridden to u8 via "
                        "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
            config.cache_dtype = ov.Type.u8
        else:
            # The u8 override only applies to CPU; GPU falls back to f16.
            logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is "
                        "ignored for GPU, f16 data type will be used.")
            config.cache_dtype = ov.Type.f16
    elif is_cpu:
        # Use bf16 for the cache when the device infers in bf16, else f16.
        inference_precision = ov_core.get_property(
            envs.VLLM_OPENVINO_DEVICE, hints.inference_precision)
        config.cache_dtype = (ov.Type.bf16
                              if inference_precision == ov.Type.bf16 else
                              ov.Type.f16)
    else:
        config.cache_dtype = ov.Type.f16

    # --- Block size: fixed per device --------------------------------------
    optimal_block_size, device_name = (32, "CPU") if is_cpu else (16, "GPU")
    if config.block_size != optimal_block_size:
        logger.info(
            "OpenVINO %s optimal block size is %d, overriding currently "
            "set %s", device_name, optimal_block_size, config.block_size)
        config.block_size = optimal_block_size

    # --- KV cache space ----------------------------------------------------
    kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE
    if kv_cache_space < 0:
        raise RuntimeError(
            "Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE"
            f" {kv_cache_space}, expect a non-negative integer value.")
    if kv_cache_space == 0 and is_cpu:
        # On CPU, 0 is treated as "not set" and replaced by a 4 GiB default;
        # on GPU a 0 value is passed through unchanged.
        config.openvino_kvcache_space_bytes = 4 * GiB_bytes  # type: ignore
        logger.warning(
            "Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "
            "for OpenVINO backend is not set, using 4 by default.")
    else:
        config.openvino_kvcache_space_bytes = (  # type: ignore
            kv_cache_space * GiB_bytes)

    return config