"docs/features/quantization/int4.md" did not exist on "d0bc2f810b7a34247154b078c2429bf62519e9ca"
protocol.py 9.04 KB
Newer Older
1
2
3
import asyncio
from abc import ABC, abstractmethod
from typing import AsyncGenerator, List, Mapping, Optional, Union
4

5
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
6
7
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
8
from vllm.inputs.data import PromptType, TokensPrompt
9
from vllm.inputs.preprocess import InputPreprocessor
10
from vllm.logger import init_logger
11
from vllm.lora.request import LoRARequest
12
from vllm.model_executor.layers.sampler import SamplerOutput
13
14
from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
                          RequestOutput)
15
16
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
17
from vllm.sampling_params import BeamSearchParams, SamplingParams
18
from vllm.transformers_utils.tokenizer import AnyTokenizer
19
from vllm.utils import collect_from_async_generator, random_uuid
20

21
logger = init_logger(__name__)
22

23
24

class EngineClient(ABC):
    """Protocol class for Clients to Engine.

    Abstract base class declaring the interface a client must implement to
    talk to an engine. Every member is abstract except ``beam_search``, which
    is implemented here on top of ``get_tokenizer`` and ``generate``.
    """

    @property
    @abstractmethod
    def is_running(self) -> bool:
        """Whether the engine is running, as reported by the implementation."""
        ...

    @property
    @abstractmethod
    def is_stopped(self) -> bool:
        """Whether the engine is stopped, as reported by the implementation."""
        ...

    @property
    @abstractmethod
    def errored(self) -> bool:
        """Whether the engine has errored, as reported by the implementation."""
        ...

    @property
    @abstractmethod
    def dead_error(self) -> BaseException:
        """Exception object exposed by the implementation for a dead engine."""
        # NOTE(review): exact semantics are defined by concrete clients, which
        # are not visible in this file -- confirm against them.
        ...

    @abstractmethod
    def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request."""
        ...

    async def beam_search(
        self,
        prompt: Union[PromptType, List[int]],
        model_config: ModelConfig,
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search on top of ``generate`` and yield the final result.

        Each decoding step issues one single-token ``generate`` call per live
        beam, requesting ``2 * beam_width`` logprobs, and expands every beam
        with each returned candidate token. Candidates equal to the
        tokenizer's EOS token are moved to the completed set (unless
        ``params.ignore_eos`` is set); the rest are ranked and truncated to
        ``beam_width``.

        Args:
            prompt: The prompt to decode from.
            model_config: Model configuration used to build the input
                preprocessor.
            request_id: Caller-supplied id; it is used for prompt
                preprocessing and is the id reported on the yielded output.
            params: Beam search settings (beam width, max tokens, EOS
                handling, temperature, length penalty,
                ``include_stop_str_in_output``).

        Yields:
            A single ``RequestOutput`` with ``finished=True`` holding up to
            ``beam_width`` best beams as ``CompletionOutput`` entries.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        tokenizer = await self.get_tokenizer()
        input_preprocessor = InputPreprocessor(model_config, tokenizer)

        (prompt_text, prompt_token_ids, multi_modal_data,
         mm_processor_kwargs) = input_preprocessor._extract_prompt_components(
             prompt,
             request_id=request_id,
         )
        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id, length_penalty)

        # One token per call; the returned logprobs are the expansion
        # candidates for each beam.
        beam_search_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(tokens=prompt_token_ids,
                               cum_logprob=0,
                               logprobs=[],
                               multi_modal_data=multi_modal_data,
                               mm_processor_kwargs=mm_processor_kwargs)
        ]
        completed = []

        for _ in range(max_tokens):
            if not all_beams:
                # Every beam has finished; further steps would be no-ops.
                break

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens,
                             multi_modal_data=beam.multi_modal_data,
                             mm_processor_kwargs=beam.mm_processor_kwargs)
                for beam in all_beams
            ]

            # Use a separate name for the per-step id so the caller's
            # `request_id` is still intact when the final output is built.
            batch_request_id = f"beam_search-{random_uuid()}"
            tasks = [
                asyncio.create_task(
                    collect_from_async_generator(
                        self.generate(individual_prompt, beam_search_params,
                                      f"{batch_request_id}-{i}")))
                for i, individual_prompt in enumerate(prompts_batch)
            ]

            output = await asyncio.gather(*tasks)

            # Keep only the first item collected from each generate() stream.
            output = [x[0] for x in output]

            new_beams = []
            for current_beam, result in zip(all_beams, output):
                if result.outputs[0].logprobs is None:
                    continue

                logprobs = result.outputs[0].logprobs[0]
                for token_id, logprob_obj in logprobs.items():
                    if token_id == tokenizer.eos_token_id and not ignore_eos:
                        finished_tokens = (
                            current_beam.tokens + [token_id]
                            if include_stop_str_in_output else
                            current_beam.tokens)
                        completed.append(
                            BeamSearchSequence(
                                tokens=finished_tokens,
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob,
                                finish_reason="stop",
                                stop_reason=tokenizer.eos_token_id))
                    else:
                        new_beams.append(
                            BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.
                                multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs))

            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
            all_beams = sorted_beams[:beam_width]

        # Beams still live after the loop compete with the completed ones.
        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if (beam.tokens[-1] == tokenizer.eos_token_id and not ignore_eos):
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        beam_search_output = RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(text=beam.text,
                                 cumulative_logprob=beam.cum_logprob,
                                 token_ids=beam.tokens[tokenized_length:],
                                 index=i,
                                 logprobs=beam.logprobs,
                                 # Beams with no finish_reason set are
                                 # reported as "length".
                                 finish_reason=beam.finish_reason if
                                 beam.finish_reason is not None else "length",
                                 stop_reason=beam.stop_reason)
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None)

        yield beam_search_output

    @abstractmethod
    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model."""
        ...

    @abstractmethod
    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request.
        """

    @abstractmethod
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Get the appropriate tokenizer for the request"""
        ...

    @abstractmethod
    async def is_tracing_enabled(self) -> bool:
        """Return whether tracing is enabled, as reported by the implementation."""
        ...

    @abstractmethod
    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[List[SamplerOutput]] = None,
    ) -> None:
        """Log engine stats, optionally given scheduler and model outputs."""
        # NOTE(review): how the optional arguments are used is up to concrete
        # clients, which are not visible in this file.
        ...

    @abstractmethod
    async def check_health(self) -> None:
        """Raise if unhealthy"""
        ...

    @abstractmethod
    async def start_profile(self) -> None:
        """Start profiling the engine"""
        ...

    @abstractmethod
    async def stop_profile(self) -> None:
        """Stop profiling the engine"""
        ...