# protocol.py (9.39 KB)
import asyncio
from abc import ABC, abstractmethod
3
from typing import AsyncGenerator, List, Mapping, Optional
4

5
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
6
7
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
8
from vllm.inputs.data import PromptType, TokensPrompt
9
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
10
from vllm.inputs.preprocess import InputPreprocessor
11
from vllm.logger import init_logger
12
from vllm.lora.request import LoRARequest
13
from vllm.model_executor.layers.sampler import SamplerOutput
14
15
from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
                          RequestOutput)
16
17
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
18
from vllm.sampling_params import BeamSearchParams, SamplingParams
19
from vllm.transformers_utils.tokenizer import AnyTokenizer
20
from vllm.utils import collect_from_async_generator, random_uuid
21

22
# Module-level logger. EngineClient below does not itself log anything.
logger = init_logger(__name__)


class EngineClient(ABC):
    """Protocol class for Clients to Engine.

    Concrete clients implement every abstract member below. ``beam_search``
    is provided here and is built only on the abstract ``generate`` and
    ``get_tokenizer`` methods.
    """

    # Engine-state properties. Only the return types are fixed here; the
    # exact meaning (presumably running / stopped / errored state and the
    # error that took the engine down) is defined by each concrete client --
    # confirm against the implementations.

    @property
    @abstractmethod
    def is_running(self) -> bool:
        ...

    @property
    @abstractmethod
    def is_stopped(self) -> bool:
        ...

    @property
    @abstractmethod
    def errored(self) -> bool:
        ...

    @property
    @abstractmethod
    def dead_error(self) -> BaseException:
        ...

    @abstractmethod
    def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request."""
        ...

    async def beam_search(
        self,
        prompt: PromptType,
        model_config: ModelConfig,
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search for ``prompt`` on top of :meth:`generate`.

        Each step sends one single-token ``generate`` request per live beam,
        asking for ``2 * beam_width`` logprobs. Every beam is expanded with
        each returned candidate token, and the ``beam_width`` best
        continuations are kept, ranked by the key from
        ``create_sort_beams_key_function``. A candidate equal to the
        tokenizer's EOS token (unless ``params.ignore_eos``) is moved to a
        completed pool instead of being extended. After at most
        ``params.max_tokens`` steps, the best ``beam_width`` beams from the
        completed pool plus the live beams are decoded and returned.

        Args:
            prompt: The prompt to search from. Explicit encoder/decoder
                prompts are not supported.
            model_config: Model configuration used to build the
                ``InputPreprocessor`` that tokenizes ``prompt``.
            request_id: The caller's request ID. It is passed to the
                preprocessor and reported on the yielded ``RequestOutput``.
                The per-step ``generate`` calls use freshly generated
                internal IDs.
            params: Beam-search settings (width, max tokens, EOS handling,
                temperature, length penalty, whether to keep the stop token).

        Yields:
            Exactly one finished ``RequestOutput`` whose ``outputs`` hold the
            selected beams, best first.

        Raises:
            NotImplementedError: If ``prompt`` is an explicit encoder/decoder
                prompt.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        tokenizer = await self.get_tokenizer()
        input_preprocessor = InputPreprocessor(model_config, tokenizer)

        if is_explicit_encoder_decoder_prompt(prompt):
            raise NotImplementedError(
                "Beam search does not support explicit encoder/decoder "
                "prompts.")
        processed_inputs = input_preprocessor._prompt_to_llm_inputs(
            prompt,
            request_id=request_id,
        )

        prompt_token_ids = processed_inputs["prompt_token_ids"]
        prompt_text = processed_inputs.get("prompt")
        multi_modal_data = processed_inputs.get("multi_modal_data")
        mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")

        # Generated tokens start after the prompt; used to slice them out.
        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id, length_penalty)

        # One new token per beam per step. Asking for more candidates than
        # beam_width presumably leaves enough live continuations after EOS
        # candidates are set aside.
        beam_search_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
        )

        all_beams = [
            BeamSearchSequence(tokens=prompt_token_ids,
                               cum_logprob=0,
                               logprobs=[],
                               multi_modal_data=multi_modal_data,
                               mm_processor_kwargs=mm_processor_kwargs)
        ]
        completed = []

        for _ in range(max_tokens):
            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens,
                             multi_modal_data=beam.multi_modal_data,
                             mm_processor_kwargs=beam.mm_processor_kwargs)
                for beam in all_beams
            ]

            # Keep the per-step ID separate from ``request_id``. Previously
            # the caller's ID was overwritten here and the final
            # RequestOutput reported an internal beam-search ID instead.
            step_request_id = f"beam_search-{random_uuid()}"
            tasks = [
                asyncio.create_task(
                    collect_from_async_generator(
                        self.generate(individual_prompt, beam_search_params,
                                      f"{step_request_id}-{i}")))
                for i, individual_prompt in enumerate(prompts_batch)
            ]

            # Only the first collected output of each request is used
            # (assumes collect_from_async_generator returns a list).
            output = [x[0] for x in await asyncio.gather(*tasks)]

            new_beams = []
            for current_beam, result in zip(all_beams, output):
                step_logprobs = result.outputs[0].logprobs
                if step_logprobs is None:
                    continue
                # Candidates for the single token generated this step.
                logprobs = step_logprobs[0]
                for token_id, logprob_obj in logprobs.items():
                    cum_logprob = current_beam.cum_logprob + logprob_obj.logprob
                    if token_id == tokenizer.eos_token_id and not ignore_eos:
                        completed.append(
                            BeamSearchSequence(
                                tokens=(current_beam.tokens + [token_id]
                                        if include_stop_str_in_output else
                                        current_beam.tokens),
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=cum_logprob,
                                finish_reason="stop",
                                stop_reason=tokenizer.eos_token_id))
                    else:
                        new_beams.append(
                            BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=cum_logprob,
                                multi_modal_data=current_beam.
                                multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs))

            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
            all_beams = sorted_beams[:beam_width]
            if not all_beams:
                # Nothing left to extend; later steps would be no-ops.
                break

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens[-1] == tokenizer.eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        beam_search_output = RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(text=beam.text,
                                 cumulative_logprob=beam.cum_logprob,
                                 token_ids=beam.tokens[tokenized_length:],
                                 index=i,
                                 logprobs=beam.logprobs,
                                 finish_reason=beam.finish_reason if
                                 beam.finish_reason is not None else "length",
                                 stop_reason=beam.stop_reason)
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None)

        yield beam_search_output

    @abstractmethod
    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model."""
        ...

    @abstractmethod
    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request.
        """
        ...

    @abstractmethod
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Get the appropriate tokenizer for the request"""
        ...

    @abstractmethod
    async def is_tracing_enabled(self) -> bool:
        ...

    @abstractmethod
    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[List[SamplerOutput]] = None,
    ) -> None:
        ...

    @abstractmethod
    async def check_health(self) -> None:
        """Raise if unhealthy"""
        ...

    @abstractmethod
    async def start_profile(self) -> None:
        """Start profiling the engine"""
        ...

    @abstractmethod
    async def stop_profile(self) -> None:
        """Stop profiling the engine"""
        ...