# SPDX-License-Identifier: Apache-2.0

import asyncio
from abc import ABC, abstractmethod
from typing import AsyncGenerator, List, Mapping, Optional

from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import collect_from_async_generator, random_uuid

# Module-level logger for this file.
logger = init_logger(__name__)


class EngineClient(ABC):
    """Protocol class for Clients to Engine.

    Abstract base for engine clients: subclasses implement every
    ``@abstractmethod`` below. ``beam_search`` is the only concrete method
    and is built entirely on top of ``generate`` and
    ``get_input_preprocessor``.
    """

    @property
    @abstractmethod
    def is_running(self) -> bool:
        ...

    @property
    @abstractmethod
    def is_stopped(self) -> bool:
        ...

    @property
    @abstractmethod
    def errored(self) -> bool:
        ...

    @property
    @abstractmethod
    def dead_error(self) -> BaseException:
        ...

    @abstractmethod
    def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request."""
        ...

    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search on top of :meth:`generate` and yield one result.

        Each step extends every live beam by one token. It issues one
        single-token ``generate`` sub-request per beam, concurrently, and
        asks each for ``2 * beam_width`` logprobs. A candidate that is the
        EOS token (unless ``ignore_eos``) moves the beam to the completed
        set. All other candidates are ranked with the key returned by
        ``create_sort_beams_key_function``, and the top ``beam_width``
        survive to the next step.

        Args:
            prompt: The prompt to search from. Explicit encoder/decoder
                prompts are not supported.
            request_id: Id of the caller's request. It is reported as the
                ``request_id`` of the yielded ``RequestOutput``.
            params: Beam search parameters (beam width, max tokens,
                temperature, length penalty, EOS handling).

        Yields:
            A single finished ``RequestOutput`` with up to ``beam_width``
            completions, in descending order of the beam sort key.

        Raises:
            NotImplementedError: If ``prompt`` is an explicit
                encoder/decoder prompt.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        preprocessor = await self.get_input_preprocessor()
        tokenizer_group = preprocessor.get_tokenizer_group()
        tokenizer = await tokenizer_group.get_lora_tokenizer_async()
        # Looked up once; it is compared against every candidate token.
        eos_token_id = tokenizer.eos_token_id

        if is_explicit_encoder_decoder_prompt(prompt):
            raise NotImplementedError
        processed_inputs = preprocessor._prompt_to_llm_inputs(
            prompt,
            request_id=request_id,
        )

        prompt_token_ids = processed_inputs["prompt_token_ids"]
        prompt_text = processed_inputs.get("prompt")
        multi_modal_data = processed_inputs.get("multi_modal_data")
        mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")

        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(
            eos_token_id, length_penalty)

        # One new token per beam per step. 2 * beam_width candidates are
        # requested, presumably so enough non-EOS candidates remain to
        # refill beam_width live beams.
        beam_search_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(tokens=prompt_token_ids,
                               cum_logprob=0,
                               logprobs=[],
                               multi_modal_data=multi_modal_data,
                               mm_processor_kwargs=mm_processor_kwargs)
        ]
        completed = []

        for _ in range(max_tokens):
            # Every beam has either finished on EOS or been pruned; the
            # remaining steps would submit no requests at all.
            if not all_beams:
                break

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens,
                             multi_modal_data=beam.multi_modal_data,
                             mm_processor_kwargs=beam.mm_processor_kwargs)
                for beam in all_beams
            ]

            # Separate name for the per-step id: the caller's ``request_id``
            # must not be overwritten, since it is reported on the final
            # output below.
            request_id_batch = f"beam_search-{random_uuid()}"
            tasks = []
            for i, individual_prompt in enumerate(prompts_batch):
                request_id_item = f"{request_id_batch}-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.generate(individual_prompt, beam_search_params,
                                      request_id_item)))
                tasks.append(task)

            output = await asyncio.gather(*tasks)

            # Keep the first RequestOutput collected from each sub-request.
            output = [x[0] for x in output]

            new_beams = []
            for current_beam, result in zip(all_beams, output):
                step_logprobs = result.outputs[0].logprobs
                if step_logprobs is None:
                    continue
                logprobs = step_logprobs[0]
                for token_id, logprob_obj in logprobs.items():
                    cum_logprob = (current_beam.cum_logprob +
                                   logprob_obj.logprob)
                    if token_id == eos_token_id and not ignore_eos:
                        # The beam ends here; the EOS token is kept in
                        # ``tokens`` only if the caller asked for it.
                        tokens = (current_beam.tokens + [token_id]
                                  if include_stop_str_in_output else
                                  current_beam.tokens)
                        completed.append(
                            BeamSearchSequence(
                                tokens=tokens,
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=cum_logprob,
                                finish_reason="stop",
                                stop_reason=eos_token_id))
                    else:
                        new_beams.append(
                            BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=cum_logprob,
                                multi_modal_data=current_beam.
                                multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs))

            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
            all_beams = sorted_beams[:beam_width]

        # Beams still alive after max_tokens steps compete with the ones
        # that already finished on EOS.
        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            generated = beam.tokens[tokenized_length:]
            # Skip a trailing EOS token in the text. With ignore_eos unset,
            # it can only be there when include_stop_str_in_output kept it.
            if generated and generated[-1] == eos_token_id and not ignore_eos:
                generated = generated[:-1]
            beam.text = tokenizer.decode(generated)

        beam_search_output = RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(text=beam.text,
                                 cumulative_logprob=beam.cum_logprob,
                                 token_ids=beam.tokens[tokenized_length:],
                                 index=i,
                                 logprobs=beam.logprobs,
                                 finish_reason=beam.finish_reason if
                                 beam.finish_reason is not None else "length",
                                 stop_reason=beam.stop_reason)
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None)

        yield beam_search_output

    @abstractmethod
    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling model."""
        ...

    @abstractmethod
    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request.
        """
        ...

    @abstractmethod
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Get the input preprocessor of the vLLM engine."""
        ...

    @abstractmethod
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Get the appropriate tokenizer for the request"""
        ...

    @abstractmethod
    async def is_tracing_enabled(self) -> bool:
        ...

    @abstractmethod
    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[List[SamplerOutput]] = None,
    ) -> None:
        ...

    @abstractmethod
    async def check_health(self) -> None:
        """Raise if unhealthy"""
        ...

    @abstractmethod
    async def start_profile(self) -> None:
        """Start profiling the engine"""
        ...

    @abstractmethod
    async def stop_profile(self) -> None:
        """Stop profiling the engine"""
        ...

    @abstractmethod
    async def reset_prefix_cache(self) -> None:
        """Reset the prefix cache"""
        ...

    @abstractmethod
    async def sleep(self, level: int = 1) -> None:
        """Sleep the engine"""
        ...

    @abstractmethod
    async def wake_up(self) -> None:
        """Wake up the engine"""
        ...

    @abstractmethod
    async def is_sleeping(self) -> bool:
        """Check whether the engine is sleeping"""
        ...

    @abstractmethod
    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Load a new LoRA adapter into the engine for future requests."""
        ...