"vscode:/vscode.git/clone" did not exist on "82a006beebf03c4f7bd600ab68f15b3325feb8e4"
protocol.py 6.93 KB
Newer Older
1
2
3
import asyncio
from abc import ABC, abstractmethod
from typing import AsyncGenerator, List, Mapping, Optional, Union
4

5
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
6
7
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
8
9
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.logger import init_logger
10
from vllm.lora.request import LoRARequest
11
from vllm.model_executor.layers.sampler import SamplerOutput
12
13
from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
                          RequestOutput)
14
15
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
16
from vllm.sampling_params import BeamSearchParams, SamplingParams
17
from vllm.transformers_utils.tokenizer import AnyTokenizer
18
from vllm.utils import collect_from_async_generator, random_uuid
19

20
# Module-level logger, obtained from vllm.logger.init_logger and keyed by
# this module's name.
logger = init_logger(__name__)
21

22
23

class EngineClient(ABC):
    """Protocol class for Clients to Engine.

    Declares the abstract surface that concrete engine clients implement
    (state properties, request submission, configuration access, health and
    profiling hooks) and provides one concrete method, ``beam_search``,
    built on top of ``generate`` and ``get_tokenizer``.
    """

    @property
    @abstractmethod
    def is_running(self) -> bool:
        """Abstract flag; presumably True while the engine is running."""
        ...

    @property
    @abstractmethod
    def is_stopped(self) -> bool:
        """Abstract flag; presumably True once the engine has stopped."""
        ...

    @property
    @abstractmethod
    def errored(self) -> bool:
        """Abstract flag; presumably True if the engine hit an error."""
        ...

    @property
    @abstractmethod
    def dead_error(self) -> BaseException:
        """Abstract property returning an exception instance.

        NOTE(review): presumably the error associated with a dead engine;
        the exact contract is set by implementations -- confirm there.
        """
        ...

    @abstractmethod
    def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request."""
        # Declared as a plain ``def`` that returns an async generator;
        # ``beam_search`` consumes it via ``collect_from_async_generator``.
        ...

    async def beam_search(
        self,
        prompt: Union[PromptType, List[int]],
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search on top of ``generate`` and yield one final result.

        Each step asks ``generate`` for a single next token per live beam
        (``max_tokens=1``) with ``2 * beam_width`` logprobs, extends every
        live beam with each returned candidate token, and keeps the
        ``beam_width`` best-ranked extensions. A beam that produces the
        tokenizer's EOS token is set aside as completed unless
        ``params.ignore_eos`` is set.

        Args:
            prompt: Either a list of token ids, used as-is, or a prompt
                that is passed to ``tokenizer.encode``.
            request_id: Caller-supplied id, reported on the final
                ``RequestOutput``. The per-step sub-requests sent to
                ``generate`` use their own internally generated ids.
            params: Supplies ``beam_width``, ``max_tokens``, ``ignore_eos``,
                ``temperature`` and ``length_penalty``.

        Yields:
            A single finished ``RequestOutput`` whose ``outputs`` are the
            top ``beam_width`` beams, best first. Each output's ``text``
            and ``token_ids`` cover only the generated suffix; the prompt
            ids are reported separately as ``prompt_token_ids``.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty

        tokenizer = await self.get_tokenizer(lora_request=None)
        if isinstance(prompt, list):
            prompt_token_ids = prompt
        else:
            prompt_token_ids = tokenizer.encode(prompt)
        prompt_len = len(prompt_token_ids)
        eos_token_id = tokenizer.eos_token_id

        sort_beams_key = create_sort_beams_key_function(
            eos_token_id, length_penalty)

        # One token per step; ask for more candidates than beams so that
        # pruning has alternatives to choose from.
        step_params = SamplingParams(logprobs=2 * beam_width,
                                     max_tokens=1,
                                     temperature=temperature)
        all_beams = [
            BeamSearchSequence(tokens=prompt_token_ids, cum_logprob=0)
        ]
        completed = []

        for _ in range(max_tokens):
            if not all_beams:
                # Every beam has completed; further steps would be no-ops.
                break

            # Use a separate name for the per-step id so the caller's
            # ``request_id`` is still intact for the final RequestOutput.
            step_request_id = f"beam_search-{random_uuid()}"
            tasks = [
                asyncio.create_task(
                    collect_from_async_generator(
                        self.generate(
                            TokensPrompt(prompt_token_ids=beam.tokens),
                            step_params,
                            f"{step_request_id}-{i}",
                        ))) for i, beam in enumerate(all_beams)
            ]
            collected = await asyncio.gather(*tasks)

            # Only the first item yielded by each sub-request is used.
            step_results = [items[0] for items in collected]

            new_beams = []
            for current_beam, result in zip(all_beams, step_results):
                step_logprobs = result.outputs[0].logprobs
                if step_logprobs is None:
                    # No candidates returned: this beam is not extended.
                    continue

                for token_id, logprob_obj in step_logprobs[0].items():
                    new_beam = BeamSearchSequence(
                        tokens=current_beam.tokens + [token_id],
                        cum_logprob=current_beam.cum_logprob +
                        logprob_obj.logprob)

                    if token_id == eos_token_id and not ignore_eos:
                        completed.append(new_beam)
                    else:
                        new_beams.append(new_beam)

            new_beams.sort(key=sort_beams_key, reverse=True)
            all_beams = new_beams[:beam_width]

        # Beams still live when the step budget runs out compete with the
        # completed ones for the final ranking.
        completed.extend(all_beams)
        completed.sort(key=sort_beams_key, reverse=True)
        best_beams = completed[:beam_width]

        for beam in best_beams:
            beam.text = tokenizer.decode(beam.tokens[prompt_len:])

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt,
            outputs=[
                CompletionOutput(
                    text=beam.text,
                    cumulative_logprob=beam.cum_logprob,
                    # Generated tokens only, matching ``text``.
                    token_ids=beam.tokens[prompt_len:],
                    index=i,
                    # NOTE(review): a scalar cumulative logprob is passed
                    # as ``logprobs``; left unchanged -- confirm the type
                    # CompletionOutput expects for this field.
                    logprobs=beam.cum_logprob,
                ) for i, beam in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None)

    @abstractmethod
    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model."""
        ...

    @abstractmethod
    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request.
        """

    @abstractmethod
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Get the appropriate tokenizer for the request"""
        ...

    @abstractmethod
    async def is_tracing_enabled(self) -> bool:
        """Abstract; presumably reports whether tracing is enabled."""
        ...

    @abstractmethod
    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[List[SamplerOutput]] = None,
    ) -> None:
        """Abstract hook, presumably for logging engine statistics.

        Both arguments are optional and default to ``None``.
        """
        ...

    @abstractmethod
    async def check_health(self) -> None:
        """Raise if unhealthy"""
        ...

    @abstractmethod
    async def start_profile(self) -> None:
        """Start profiling the engine"""
        ...

    @abstractmethod
    async def stop_profile(self) -> None:
        """Stop profiling the engine"""
        ...