protocol.py 7.45 KB
Newer Older
1
2
3
import asyncio
from abc import ABC, abstractmethod
from typing import AsyncGenerator, List, Mapping, Optional, Union
4

5
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
6
7
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
8
9
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.logger import init_logger
10
from vllm.lora.request import LoRARequest
11
from vllm.model_executor.layers.sampler import SamplerOutput
12
13
from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
                          RequestOutput)
14
15
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
16
from vllm.sampling_params import BeamSearchParams, SamplingParams
17
from vllm.transformers_utils.tokenizer import AnyTokenizer
18
from vllm.utils import collect_from_async_generator, random_uuid
19

20
# Module-level logger, named after this module via init_logger(__name__).
logger = init_logger(__name__)
21

22
23

class EngineClient(ABC):
    """Protocol class for Clients to Engine.

    Concrete engine clients implement the abstract members below.
    ``beam_search`` is provided here and is built only on top of the
    abstract ``get_tokenizer`` and ``generate`` methods.
    """

    @property
    @abstractmethod
    def is_running(self) -> bool:
        """Whether the engine is running (semantics defined by subclasses)."""
        ...

    @property
    @abstractmethod
    def is_stopped(self) -> bool:
        """Whether the engine is stopped (semantics defined by subclasses)."""
        ...

    @property
    @abstractmethod
    def errored(self) -> bool:
        """Whether the engine has errored (semantics defined by subclasses)."""
        ...

    @property
    @abstractmethod
    def dead_error(self) -> BaseException:
        """Exception describing the engine failure.

        NOTE(review): presumably only meaningful when ``errored`` is True;
        confirm against the concrete implementations.
        """
        ...

    @abstractmethod
    def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request."""
        ...

    async def beam_search(
        self,
        prompt: Union[str, List[int]],
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search by repeatedly calling ``generate`` for one token.

        Each step submits one ``generate`` request per live beam, with
        ``max_tokens=1`` and ``logprobs=2 * beam_width``. Every returned
        top-logprob token extends its beam. A candidate that ends in the
        tokenizer's EOS token (unless ``params.ignore_eos``) moves to the
        completed pool. The remaining candidates are ranked with
        ``create_sort_beams_key_function``, and the best ``beam_width`` are
        kept. After ``params.max_tokens`` steps, or once no live beams
        remain, completed and live beams are ranked together.

        Args:
            prompt: Prompt text, or pre-tokenized prompt ids. Text is encoded
                with the tokenizer from ``get_tokenizer(lora_request=None)``.
            request_id: Id reported on the yielded ``RequestOutput``.
            params: Beam width, max tokens, EOS handling, temperature and
                length penalty.

        Yields:
            A single finished ``RequestOutput``. Its ``outputs`` hold up to
            ``beam_width`` beams, best first. Decoded text omits a trailing
            EOS token unless ``ignore_eos`` is set.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty

        tokenizer = await self.get_tokenizer(lora_request=None)
        if isinstance(prompt, str):
            tokenized_prompt = tokenizer.encode(prompt)
            prompt_text = prompt
        else:
            tokenized_prompt = prompt
            prompt_text = None
        tokenized_length = len(tokenized_prompt)

        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id, length_penalty)

        # One new token per beam per step. The top 2 * beam_width logprobs
        # of that token become the candidate expansions. The over-request is
        # presumably so that enough non-EOS candidates survive; confirm.
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)

        all_beams = [
            BeamSearchSequence(tokens=tokenized_prompt,
                               logprobs=[],
                               cum_logprob=0)
        ]
        completed = []

        for _ in range(max_tokens):
            # If every live beam ended with EOS on an earlier step, further
            # iterations would submit nothing and change nothing.
            if not all_beams:
                break

            # Keep the per-step id in its own variable so the caller's
            # ``request_id`` is preserved for the final RequestOutput.
            batch_request_id = f"beam_search-{random_uuid()}"
            tasks = [
                asyncio.create_task(
                    collect_from_async_generator(
                        self.generate(
                            TokensPrompt(prompt_token_ids=beam.tokens),
                            beam_search_params,
                            f"{batch_request_id}-{i}")))
                for i, beam in enumerate(all_beams)
            ]
            # One collected output list per beam; use the first of each.
            outputs = [x[0] for x in await asyncio.gather(*tasks)]

            new_beams = []
            for current_beam, result in zip(all_beams, outputs):
                if result.outputs[0].logprobs is None:
                    continue
                logprobs = result.outputs[0].logprobs[0]
                for token_id, logprob_obj in logprobs.items():
                    new_beam = BeamSearchSequence(
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs],
                        cum_logprob=current_beam.cum_logprob +
                        logprob_obj.logprob)

                    if token_id == tokenizer.eos_token_id and not ignore_eos:
                        completed.append(new_beam)
                    else:
                        new_beams.append(new_beam)

            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
            all_beams = sorted_beams[:beam_width]

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if (beam.tokens[-1] == tokenizer.eos_token_id and not ignore_eos):
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        beam_search_output = RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                ) for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=tokenized_prompt,
            prompt_logprobs=None)

        yield beam_search_output

    @abstractmethod
    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model."""
        ...

    @abstractmethod
    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request.
        """

    @abstractmethod
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Get the appropriate tokenizer for the request"""
        ...

    @abstractmethod
    async def is_tracing_enabled(self) -> bool:
        """Whether request tracing is enabled (defined by subclasses)."""
        ...

    @abstractmethod
    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[List[SamplerOutput]] = None,
    ) -> None:
        """Log engine statistics (behavior defined by subclasses)."""
        ...

    @abstractmethod
    async def check_health(self) -> None:
        """Raise if unhealthy"""
        ...

    @abstractmethod
    async def start_profile(self) -> None:
        """Start profiling the engine"""
        ...

    @abstractmethod
    async def stop_profile(self) -> None:
        """Stop profiling the engine"""
        ...