gpu_executor.py 5.72 KB
Newer Older
1
from typing import Any, Dict, List, Optional, Set, Tuple
2
3
4

from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
5
from vllm.lora.request import LoRARequest
6
from vllm.sequence import ExecuteModelRequest, SamplerOutput
7
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
8
                        make_async)
9
from vllm.worker.worker_base import WorkerWrapperBase
10
11
12
13
14
15

logger = init_logger(__name__)


class GPUExecutor(ExecutorBase):
    """Executor that drives a single worker held in ``self.driver_worker``.

    Both initialization paths assert ``parallel_config.world_size == 1``.
    Apart from ``check_health``, every public method simply forwards to the
    driver worker.
    """

    def _init_executor(self) -> None:
        """Initialize the worker and load the model.

        If speculative decoding is enabled, we instead create the speculative
        worker.
        """
        if self.speculative_config is None:
            self._init_non_spec_worker()
        else:
            self._init_spec_worker()

    def _get_worker_kwargs(
            self,
            local_rank: int = 0,
            rank: int = 0,
            distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
        """Return worker init args for a given rank.

        Args:
            local_rank: Forwarded unchanged as the ``local_rank`` kwarg.
            rank: Forwarded as ``rank``; the worker is flagged as the driver
                worker when it equals 0.
            distributed_init_method: Forwarded as-is when given. When None,
                one is built from ``get_ip()`` and ``get_open_port()``.

        Returns:
            A fresh dict of keyword arguments; callers may mutate it.
        """
        if distributed_init_method is None:
            distributed_init_method = get_distributed_init_method(
                get_ip(), get_open_port())
        return dict(
            model_config=self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            is_driver_worker=rank == 0,
        )

    def _create_worker(self,
                       local_rank: int = 0,
                       rank: int = 0,
                       distributed_init_method: Optional[str] = None):
        """Create a ``vllm.worker.worker.Worker`` through WorkerWrapperBase.

        The wrapper is initialized with the kwargs produced by
        ``_get_worker_kwargs`` for the given rank, and the wrapped worker
        (``wrapper.worker``) is returned.
        """
        wrapper = WorkerWrapperBase(
            worker_module_name="vllm.worker.worker",
            worker_class_name="Worker",
        )
        wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
                                                      distributed_init_method))
        return wrapper.worker

    def _init_non_spec_worker(self):
        """Create the driver worker, initialize its device and load the model.
        """
        assert self.parallel_config.world_size == 1, (
            "GPUExecutor only supports single GPU.")

        self.driver_worker = self._create_worker()
        self.driver_worker.init_device()
        self.driver_worker.load_model()

    def _init_spec_worker(self):
        """Initialize a SpecDecodeWorker, using a draft model for proposals.
        """
        assert self.speculative_config is not None
        # Check the single-GPU precondition before building any worker, the
        # same way _init_non_spec_worker does, so an unsupported
        # configuration fails before any work is done.
        assert self.parallel_config.world_size == 1, (
            "GPUExecutor only supports single GPU.")

        from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker

        target_worker = self._create_worker()

        draft_worker_kwargs = self._get_worker_kwargs()
        # Override draft-model specific worker args.
        draft_worker_kwargs.update(
            model_config=self.speculative_config.draft_model_config,
            parallel_config=self.speculative_config.draft_parallel_config,
            ngram_prompt_lookup_max=self.speculative_config.
            ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.speculative_config.
            ngram_prompt_lookup_min,
            # TODO allow draft-model specific load config. Until then the
            # draft worker keeps the load_config from _get_worker_kwargs().
        )

        spec_decode_worker = SpecDecodeWorker.create_worker(
            scorer_worker=target_worker,
            draft_worker_kwargs=draft_worker_kwargs,
        )

        self.driver_worker = spec_decode_worker

        # Load model handled in spec decode worker.
        self.driver_worker.init_device()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker.
        """
        return self.driver_worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache by invoking the underlying worker.
        """
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log in the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
                    num_cpu_blocks)

        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Run the request on the driver worker and return its output."""
        return self.driver_worker.execute_model(execute_model_req)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward a LoRA add request to the driver worker.

        Asserts that ``lora_request.lora_int_id`` is positive.
        """
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward a LoRA removal to the driver worker.

        Asserts that ``lora_id`` is positive.
        """
        assert lora_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Return the driver worker's ``list_loras()`` result."""
        return self.driver_worker.list_loras()

    def check_health(self) -> None:
        """No-op health check."""
        # GPUExecutor will always be healthy as long as
        # it's running.
        return


class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
    """GPUExecutor variant that adds an awaitable ``execute_model_async``."""

    async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        """Await the driver worker's ``execute_model`` for this request.

        The synchronous ``driver_worker.execute_model`` is wrapped with
        ``make_async`` so that its result can be awaited; the request is
        passed by keyword.
        """
        run_async = make_async(self.driver_worker.execute_model)
        return await run_async(execute_model_req=execute_model_req)