from typing import Dict, List, Set, Tuple

from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                        make_async)

# Module-level logger used by the executors below.
logger = init_logger(__name__)


class GPUExecutor(ExecutorBase):
    """Executor that drives a single worker on one GPU.

    The worker is constructed directly in this process as rank 0 / local
    rank 0 and acts as the driver worker. When a speculative config is set,
    that worker is a speculative-decoding worker wrapping a target and a
    draft worker.
    """

    def _init_executor(self) -> None:
        """Initialize the worker and load the model.

        If speculative decoding is enabled, we instead create the speculative
        worker.
        """
        if self.speculative_config is None:
            self._init_non_spec_worker()
        else:
            self._init_spec_worker()

    def _get_worker_kwargs(self, distributed_init_method: str) -> Dict:
        """Return the constructor keyword arguments shared by every worker
        this executor builds.

        Every worker created here runs as rank 0 on local device 0 and is the
        driver worker, since GPUExecutor only supports a single GPU.

        Args:
            distributed_init_method: Init method string produced by
                ``get_distributed_init_method``.

        Returns:
            A fresh dict on each call, so callers may override entries
            (e.g. the draft model's configs) without affecting others.
        """
        return dict(
            model_config=self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=0,
            rank=0,
            distributed_init_method=distributed_init_method,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            is_driver_worker=True,
        )

    def _init_non_spec_worker(self) -> None:
        """Create the single driver Worker, initialize its device and load
        the model."""
        # Lazy import the Worker to avoid importing torch.cuda/xformers
        # before CUDA_VISIBLE_DEVICES is set in the Worker
        from vllm.worker.worker import Worker

        assert self.parallel_config.world_size == 1, (
            "GPUExecutor only supports single GPU.")

        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
        self.driver_worker = Worker(
            **self._get_worker_kwargs(distributed_init_method))
        self.driver_worker.init_device()
        self.driver_worker.load_model()

    def _init_spec_worker(self) -> None:
        """Initialize a SpecDecodeWorker, using a draft model for proposals.
        """
        assert self.speculative_config is not None
        # Validate the topology before constructing any worker so an
        # unsupported configuration fails before any setup work is done.
        assert self.parallel_config.world_size == 1, (
            "GPUExecutor only supports single GPU.")

        from vllm.spec_decode.multi_step_worker import MultiStepWorker
        from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
        from vllm.worker.worker import Worker

        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())

        # The target (scorer) worker uses the executor's own configs.
        target_worker = Worker(
            **self._get_worker_kwargs(distributed_init_method))

        # The draft (proposer) worker differs only in its model and parallel
        # configs; everything else is shared with the target worker.
        # TODO allow draft-model specific load config.
        draft_worker_kwargs = self._get_worker_kwargs(distributed_init_method)
        draft_worker_kwargs.update(
            model_config=self.speculative_config.draft_model_config,
            parallel_config=self.speculative_config.draft_parallel_config,
        )
        draft_worker = MultiStepWorker(**draft_worker_kwargs)

        spec_decode_worker = SpecDecodeWorker.from_workers(
            proposer_worker=draft_worker, scorer_worker=target_worker)

        self.driver_worker = spec_decode_worker

        # Load model handled in spec decode worker.
        self.driver_worker.init_device()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker.

        Returns:
            The driver worker's result unchanged.
        """
        return self.driver_worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache by invoking the underlying worker.

        Args:
            num_gpu_blocks: Number of KV cache blocks to allocate on the GPU.
            num_cpu_blocks: Number of KV cache blocks to allocate on the CPU.
        """
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log in the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        # Lazy %-style arguments: formatting only happens if INFO is enabled.
        logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
                    num_cpu_blocks)

        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        num_lookahead_slots: int,
    ) -> List[SamplerOutput]:
        """Run one model step on the driver worker.

        All arguments are forwarded unchanged, as keyword arguments, to the
        driver worker's ``execute_model``; its result is returned as-is.
        """
        return self.driver_worker.execute_model(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            num_lookahead_slots=num_lookahead_slots,
        )

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add a LoRA adapter via the driver worker.

        Raises:
            AssertionError: if ``lora_request.lora_int_id`` is not positive.
        """
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter via the driver worker.

        Raises:
            AssertionError: if ``lora_id`` is not positive.
        """
        assert lora_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the driver worker."""
        return self.driver_worker.list_loras()

    def check_health(self) -> None:
        """No-op health check; see comment below."""
        # GPUExecutor will always be healthy as long as
        # it's running.
        return


class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
    """Async variant of GPUExecutor using the same single driver worker."""

    async def execute_model_async(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        num_lookahead_slots: int = 0,
    ) -> SamplerOutput:
        """Run one model step on the driver worker and await the result.

        The worker's ``execute_model`` is wrapped with ``make_async`` so the
        call can be awaited. Arguments are forwarded as keyword arguments,
        mirroring the synchronous ``execute_model``.

        Args:
            num_lookahead_slots: Forwarded to the worker exactly as the sync
                path does. It defaults to 0 so existing callers that pass
                only the four block/metadata arguments keep working.
        """
        # NOTE(review): the sync execute_model annotates this same worker call
        # as returning List[SamplerOutput]; this annotation likely should
        # match, but must stay in sync with ExecutorAsyncBase — confirm there.
        output = await make_async(self.driver_worker.execute_model)(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            num_lookahead_slots=num_lookahead_slots,
        )
        return output