protocol.py 10.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
import asyncio
from abc import ABC, abstractmethod
5
from typing import AsyncGenerator, List, Mapping, Optional
6

7
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
8
9
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
10
from vllm.inputs.data import PromptType, TokensPrompt
11
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
12
from vllm.inputs.preprocess import InputPreprocessor
13
from vllm.logger import init_logger
14
from vllm.lora.request import LoRARequest
15
from vllm.model_executor.layers.sampler import SamplerOutput
16
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
17
18
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
19
from vllm.sampling_params import BeamSearchParams, SamplingParams
20
from vllm.transformers_utils.tokenizer import AnyTokenizer
21
from vllm.utils import Device, collect_from_async_generator, random_uuid
22

23
logger = init_logger(__name__)
24

25
26

class EngineClient(ABC):
    """Protocol class for Clients to Engine"""

    @property
    @abstractmethod
    def is_running(self) -> bool:
        ...

    @property
    @abstractmethod
    def is_stopped(self) -> bool:
        ...

    @property
    @abstractmethod
    def errored(self) -> bool:
        ...

    @property
    @abstractmethod
    def dead_error(self) -> BaseException:
        ...

    @abstractmethod
    def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request."""
        ...

    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search for ``prompt`` on top of :meth:`generate`.

        Every step issues one single-token ``generate`` request per live
        beam (``max_tokens=1``, ``logprobs=2 * beam_width``). Each beam is
        expanded by every returned candidate token, and the ``beam_width``
        best expansions (ranked by ``create_sort_beams_key_function``) are
        kept. A candidate equal to the tokenizer's EOS token is moved to the
        completed pool instead of being extended, unless
        ``params.ignore_eos`` is set.

        Args:
            prompt: Prompt to search from. Explicit encoder/decoder prompts
                are not supported.
            request_id: Id reported on the returned ``RequestOutput``.
            params: Beam width, max tokens, temperature, length penalty and
                EOS handling options.

        Yields:
            Exactly one finished ``RequestOutput`` holding up to
            ``beam_width`` completions, best first.

        Raises:
            NotImplementedError: If ``prompt`` is an explicit
                encoder/decoder prompt.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        preprocessor = await self.get_input_preprocessor()
        tokenizer_group = preprocessor.get_tokenizer_group()
        tokenizer = await tokenizer_group.get_lora_tokenizer_async()
        eos_token_id = tokenizer.eos_token_id

        if is_explicit_encoder_decoder_prompt(prompt):
            raise NotImplementedError(
                "Beam search does not support explicit encoder/decoder "
                "prompts.")
        processed_inputs = preprocessor._prompt_to_llm_inputs(prompt)

        prompt_token_ids = processed_inputs["prompt_token_ids"]
        prompt_text = processed_inputs.get("prompt")
        multi_modal_data = processed_inputs.get("multi_modal_data")
        mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")

        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(
            eos_token_id, length_penalty)

        # Each step only asks for one token. Requesting 2 * beam_width
        # logprobs presumably leaves enough non-EOS candidates to refill
        # beam_width beams after EOS candidates are diverted.
        beam_search_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(tokens=prompt_token_ids,
                               cum_logprob=0,
                               logprobs=[],
                               multi_modal_data=multi_modal_data,
                               mm_processor_kwargs=mm_processor_kwargs)
        ]
        completed = []

        for _ in range(max_tokens):
            if not all_beams:
                # Every beam finished or produced no logprobs; further
                # steps would only issue empty batches.
                break

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens,
                             multi_modal_data=beam.multi_modal_data,
                             mm_processor_kwargs=beam.mm_processor_kwargs)
                for beam in all_beams
            ]

            # A distinct local name keeps the caller's ``request_id``
            # intact; it is the id reported on the final output below.
            batch_request_id = f"beam_search-{random_uuid()}"
            tasks = [
                asyncio.create_task(
                    collect_from_async_generator(
                        self.generate(individual_prompt, beam_search_params,
                                      f"{batch_request_id}-{i}")))
                for i, individual_prompt in enumerate(prompts_batch)
            ]

            output = await asyncio.gather(*tasks)
            # Only the first collected output of each single-token request
            # is used.
            output = [x[0] for x in output]

            new_beams = []
            # One request was issued per beam, in order, so results pair
            # up with beams positionally.
            for current_beam, result in zip(all_beams, output):
                if result.outputs[0].logprobs is None:
                    continue
                logprobs = result.outputs[0].logprobs[0]
                for token_id, logprob_obj in logprobs.items():
                    cum_logprob = (current_beam.cum_logprob +
                                   logprob_obj.logprob)
                    if token_id == eos_token_id and not ignore_eos:
                        completed.append(
                            BeamSearchSequence(
                                tokens=(current_beam.tokens + [token_id]
                                        if include_stop_str_in_output else
                                        current_beam.tokens),
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=cum_logprob,
                                finish_reason="stop",
                                stop_reason=eos_token_id))
                    else:
                        new_beams.append(
                            BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=cum_logprob,
                                multi_modal_data=current_beam.
                                multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs))

            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
            all_beams = sorted_beams[:beam_width]

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if (not ignore_eos and beam.tokens
                    and beam.tokens[-1] == eos_token_id):
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        beam_search_output = RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(text=beam.text,
                                 cumulative_logprob=beam.cum_logprob,
                                 token_ids=beam.tokens[tokenized_length:],
                                 index=i,
                                 logprobs=beam.logprobs,
                                 finish_reason=beam.finish_reason if
                                 beam.finish_reason is not None else "length",
                                 stop_reason=beam.stop_reason)
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None)

        yield beam_search_output

    @abstractmethod
    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling model."""
        ...

    @abstractmethod
    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request.
        """
        ...

    @abstractmethod
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Get the input processor of the vLLM engine."""
        ...

    @abstractmethod
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Get the appropriate tokenizer for the request"""
        ...

    @abstractmethod
    async def is_tracing_enabled(self) -> bool:
        ...

    @abstractmethod
    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[List[SamplerOutput]] = None,
    ) -> None:
        ...

    @abstractmethod
    async def check_health(self) -> None:
        """Raise if unhealthy"""
        ...

    @abstractmethod
    async def start_profile(self) -> None:
        """Start profiling the engine"""
        ...

    @abstractmethod
    async def stop_profile(self) -> None:
        """Stop profiling the engine"""
        ...

    @abstractmethod
    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        """Reset the prefix cache"""
        ...

    @abstractmethod
    async def sleep(self, level: int = 1) -> None:
        """Sleep the engine"""
        ...

    @abstractmethod
    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake up the engine"""
        ...

    @abstractmethod
    async def is_sleeping(self) -> bool:
        """Check whether the engine is sleeping"""
        ...

    @abstractmethod
    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Load a new LoRA adapter into the engine for future requests."""
        ...