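# NOTE(review): the lines below are residue from a scraped repository web page
# header; they are kept as comments so they are not parsed as code.
# "vllm/entrypoints/openai/completion/protocol.py" did not exist on "092bb73b8a36ccdb6d6bbac897ba3aa79f660e36"
# gpu_executor.py 4.56 KB
# Newer Older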
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                        make_async)
from vllm.worker.worker_base import WorkerWrapperBase

# Module-level logger; GPUExecutor.initialize_cache uses it to report the
# number of GPU and CPU KV-cache blocks.
logger = init_logger(__name__)


class GPUExecutor(ExecutorBase):
    """Executor that drives exactly one worker on a single GPU.

    Every operation below is forwarded to the single ``driver_worker``
    created in ``_init_executor``.
    """

    def _init_executor(self) -> None:
        """Initialize the worker and load the model.

        Raises:
            AssertionError: if ``parallel_config.world_size`` is not 1.
        """
        assert self.parallel_config.world_size == 1, (
            "GPUExecutor only supports single GPU.")

        self.driver_worker = self._create_worker()
        self.driver_worker.init_device()
        self.driver_worker.load_model()

    def _get_worker_kwargs(
            self,
            local_rank: int = 0,
            rank: int = 0,
            distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
        """Return worker init args for a given rank.

        Args:
            local_rank: Rank of the worker on this node.
            rank: Global rank of the worker; rank 0 is flagged as the driver.
            distributed_init_method: Init method string. When ``None``, one
                is built from ``get_ip()`` and ``get_open_port()``.

        Returns:
            Keyword arguments to pass to ``WorkerWrapperBase.init_worker``.
        """
        if distributed_init_method is None:
            distributed_init_method = get_distributed_init_method(
                get_ip(), get_open_port())
        return dict(
            model_config=self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            speculative_config=self.speculative_config,
            is_driver_worker=rank == 0,
        )

    def _create_worker(self,
                       local_rank: int = 0,
                       rank: int = 0,
                       distributed_init_method: Optional[str] = None):
        """Build and initialize a worker for the given rank.

        Without a speculative config, this uses ``vllm.worker.worker.Worker``.
        With one, it uses the ``create_spec_worker`` factory from
        ``vllm.spec_decode.spec_decode_worker``. The wrapper is initialized
        with the kwargs from ``_get_worker_kwargs``, and its ``worker`` is
        returned.
        """
        if self.speculative_config is None:
            worker_module_name = "vllm.worker.worker"
            worker_class_name = "Worker"
        else:
            worker_module_name = "vllm.spec_decode.spec_decode_worker"
            worker_class_name = "create_spec_worker"

        wrapper = WorkerWrapperBase(
            worker_module_name=worker_module_name,
            worker_class_name=worker_class_name,
        )
        wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
                                                      distributed_init_method))
        return wrapper.worker

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker.

        Returns:
            The worker's pair of block counts. This is presumably
            ``(num_gpu_blocks, num_cpu_blocks)``, matching
            ``initialize_cache``; confirm against the worker implementation.
        """
        return self.driver_worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache by invoking the underlying worker.

        Args:
            num_gpu_blocks: Number of KV-cache blocks to allocate on the GPU.
            num_cpu_blocks: Number of KV-cache blocks to allocate on the CPU.
        """
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log in the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
                    num_cpu_blocks)

        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def execute_model(
        self, execute_model_req: ExecuteModelRequest
    ) -> List[Union[SamplerOutput, PoolerOutput]]:
        """Run one model step on the driver worker and return its outputs."""
        return self.driver_worker.execute_model(execute_model_req)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register a LoRA adapter with the worker.

        Raises:
            AssertionError: if ``lora_request.lora_int_id`` is not positive.
        """
        # NOTE(review): `assert` is skipped under `python -O`; raise a real
        # exception instead if this check must hold in optimized runs.
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter with ``lora_id`` from the worker.

        Raises:
            AssertionError: if ``lora_id`` is not positive.
        """
        # NOTE(review): same `python -O` caveat as in add_lora.
        assert lora_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Return the set of LoRA ids currently known to the worker."""
        return self.driver_worker.list_loras()

    def check_health(self) -> None:
        """Health check; this is a no-op for the single-GPU executor."""
        # GPUExecutor will always be healthy as long as
        # it's running.
        return


class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
    """Async variant of ``GPUExecutor`` that exposes an awaitable model step."""

    async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[Union[SamplerOutput, PoolerOutput]]:
        """Run one model step on the driver worker and await its outputs.

        The synchronous ``driver_worker.execute_model`` is wrapped with
        ``make_async`` before being awaited. ``make_async`` presumably
        off-loads the call to an executor thread so the event loop is not
        blocked; confirm this in ``vllm.utils``.
        """
        run_step = make_async(self.driver_worker.execute_model)
        output = await run_step(execute_model_req=execute_model_req)
        return output