"vscode:/vscode.git/clone" did not exist on "d697581a7c28668d00b2284477e20a2cd774ea6e"
parallel_sampling.py 13.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
# SPDX-License-Identifier: Apache-2.0

from copy import copy
from typing import (AsyncGenerator, Dict, List, Mapping, Optional, Protocol,
                    Tuple, Union)

from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.utils import merge_async_iterators


class AsyncGenerateMethodType(Protocol):
    """Call signature of an asynchronous generate-style callable.

    Any callable accepting these arguments and returning an async
    generator of `RequestOutput` satisfies this protocol. It is the
    type of the `generate` argument taken by
    `generate_parallel_sampling_async`, which invokes it once per
    child request using keyword arguments.
    """

    def __call__(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        ...


class SyncAddRequestMethodType(Protocol):
    """Call signature of a synchronous add-request callable.

    Any callable accepting these arguments and returning `None`
    satisfies this protocol. It is the type of the `add_request`
    argument taken by
    `SyncParallelSamplingManager.add_request_parallel_sampling`, which
    invokes it once per child request using keyword arguments.
    """

    def __call__(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...


class ParallelSamplingRequest:
    """State and output processing for one parallel sampling request.

    Holds the parent request ID and the parent sampling params,
    derives the request ID and sampling params of each child
    request, and converts child request outputs into parent request
    outputs.

    When the output kind is `FINAL_ONLY` (streaming disabled),
    `self.request_output` accumulates the child completions until
    all `n` of them have arrived.
    """

    request_id: str
    sampling_params: SamplingParams
    cached_child_sampling_params: Optional[SamplingParams]
    request_output: Optional[RequestOutput]
    num_finished_completions: int

    def __init__(self, request_id: str,
                 sampling_params: SamplingParams) -> None:
        # Parent-level identity and configuration
        self.request_id = request_id
        self.sampling_params = sampling_params
        # Populated lazily; only used when the parent has no seed
        self.cached_child_sampling_params = None
        # Aggregated parent output (non-streaming mode only)
        self.request_output = None
        # Number of child completions counted as finished so far
        self.num_finished_completions = 0

    @property
    def n(self) -> int:
        """Number of completions requested by the parent."""
        return self.sampling_params.n

    @property
    def output_kind(self) -> RequestOutputKind:
        """Output kind requested by the parent."""
        return self.sampling_params.output_kind

    def _get_child_sampling_params(
        self,
        index: int,
    ) -> SamplingParams:
        """Return the sampling params for the child at `index`.

        Child params are a shallow copy of the parent params with
        `n` forced to 1. Without a parent seed, one copy is built
        and shared by all children. With a parent seed, every call
        builds a fresh copy whose seed is `seed + index`.

        Args:
          index: index within `n` child requests

        Returns:
          Child `sampling_params` instance.
        """
        parent_seed = self.sampling_params.seed
        shared_params = self.cached_child_sampling_params
        if shared_params:
            # A shared copy already exists; hand it out again
            return shared_params

        child_params = copy(self.sampling_params)
        child_params.n = 1
        if parent_seed is not None:
            # Seeded parent: per-child seed, so nothing is cached
            child_params.seed = parent_seed + index
            return child_params

        # Unseeded parent: remember the copy for the other children
        self.cached_child_sampling_params = child_params
        return child_params

    def _add_output(
        self,
        child_req_output: RequestOutput,
        index: int,
    ) -> None:
        """Fold one child request output into the parent output.

        Used only on the non-streaming path
        (`output_kind == FINAL_ONLY`). The child's first completion
        is stamped with `index`; the first child output received
        becomes the parent output and takes on the parent request
        ID.

        Args:
          child_req_output: a single request output from a parallel
                            sampling child request.
          index: index within `n` child requests
        """
        self.num_finished_completions += 1
        completion = child_req_output.outputs[0]
        completion.index = index

        if self.request_output is not None:
            # Later arrivals only contribute their completion;
            # ordering by index is restored when finalizing
            self.request_output.outputs.append(completion)
            return

        # First arrival: adopt it as the parent output under the
        # parent request ID; metrics are not supported for
        # parallel sampling
        child_req_output.request_id = self.request_id
        child_req_output.metrics = None
        self.request_output = child_req_output

    def _get_final_request_output(self) -> RequestOutput:
        """Mark the aggregated output finished and return it.

        The completions are re-ordered by their `index` before the
        output is returned.
        """
        assert self.request_output is not None
        final_output = self.request_output
        final_output.finished = True
        ordered = list(final_output.outputs)
        ordered.sort(key=lambda completion: completion.index)
        final_output.outputs = ordered
        return final_output

    def get_child_info(self, index: int) -> Tuple[str, SamplingParams]:
        """Return the request ID and sampling params of a child.

        The child request ID is the parent request ID prefixed
        with `"{index}_"`.

        Args:
          index: index within `n` child requests.

        Returns:
          (request ID, sampling_params) tuple
        """
        child_request_id = f"{index}_{self.request_id}"
        child_params = self._get_child_sampling_params(index)
        return child_request_id, child_params

    def process_output(
        self,
        child_req_output: RequestOutput,
        index: int,
    ) -> Optional[RequestOutput]:
        """Turn a child request output into a parent request output.

        With `output_kind == FINAL_ONLY` the child output is
        aggregated, and the combined parent output is returned only
        once `n` completions have been collected; until then the
        result is `None`.

        With any other output kind the child output is relabelled
        with the parent request ID and the child index, and returned
        immediately. Its `finished` flag is rewritten to report
        whether all `n` children have finished.

        Args:
          child_req_output: a single child request output
          index: index within `n` child requests

        Returns:
          `None`, unless a processed request output is ready to
          send back to the caller.
        """
        if self.output_kind == RequestOutputKind.FINAL_ONLY:
            # Non-streaming: collect, and release only when complete
            self._add_output(child_req_output, index)
            if self.num_finished_completions != self.n:
                return None
            return self._get_final_request_output()

        # Streaming: forward the child output under the parent's ID
        child_req_output.request_id = self.request_id
        child_req_output.outputs[0].index = index
        if child_req_output.finished:
            # The parent is finished only when every child is
            self.num_finished_completions += 1
            child_req_output.finished = (
                self.num_finished_completions == self.n)
        return child_req_output

    async def wrap_child_async_generator(
        self,
        child_gen: AsyncGenerator[RequestOutput, None],
        index: int,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Adapt one child's output generator to parent outputs.

        Every output produced by `child_gen` goes through
        `process_output`; results that are falsy (such as `None`
        while aggregation is incomplete) are dropped.

        Args:
          child_gen: generator for child request outputs.
          index: index within the `n` child requests

        Returns:
          Yields zero or more request outputs to return
          to the caller.
        """
        async for child_output in child_gen:
            parent_output = self.process_output(child_output, index)
            if parent_output:
                yield parent_output


class SyncParallelSamplingManager:
    """Bookkeeping for parallel sampling in a synchronous engine loop.

    Tracks every in-flight parallel sampling parent request together
    with its child requests, fans a parent out into `n` child
    requests, and maps child request outputs back to parent request
    outputs on each `step`.
    """

    def __init__(self):
        # Parent req ID -> parent request manager
        self.parent_reqs: Dict[str, ParallelSamplingRequest] = {}
        # Child req ID -> (child req index, parent req ID)
        self.child_reqs: Dict[str, Tuple[int, str]] = {}

    def _register_parent_request(self, req: ParallelSamplingRequest) -> None:
        """Register parallel sampling parent request."""
        self.parent_reqs[req.request_id] = req

    def _register_child_request(self, req_id: str, child_req_id: str,
                                index: int) -> None:
        """Register parallel sampling child request with parent.

        Args:
          req_id: parent request ID
          child_req_id: child request ID
          index: child request index within `n` child requests
        """
        self.child_reqs[child_req_id] = (index, req_id)

    def get_num_unfinished_requests(self, num_core_reqs: int) -> int:
        """Get the number of unfinished requests, correcting for parallel
           sampling.

        Every tracked child request is subtracted from the engine core
        count and every tracked parent request is added once.

        Args:
          num_core_reqs: The number of unfinished requests in the engine core.

        Returns:
          Number of unfinished requests, where each parallel sampling req
          counts as 1
        """
        return num_core_reqs + len(self.parent_reqs) - len(self.child_reqs)

    def add_request_parallel_sampling(
        self,
        add_request: SyncAddRequestMethodType,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Add sync parallel sampling request.

        Registers the parent request, then registers and submits one
        child request per requested completion via `add_request`.

        Args:
          add_request: callable used to submit each child request
          request_id: parent request ID
          prompt: prompt forwarded unchanged to every child request
          params: parent params; `params.n` child requests are created
          arrival_time: forwarded unchanged to every child request
          lora_request: forwarded unchanged to every child request
          trace_headers: forwarded unchanged to every child request
          prompt_adapter_request: forwarded unchanged to every child
                                  request
          priority: forwarded unchanged to every child request
        """
        req = ParallelSamplingRequest(request_id, params)
        self._register_parent_request(req)
        # Add n child requests with unique request IDs & random seeds and n=1
        for idx in range(req.n):
            child_req_id, child_params = req.get_child_info(idx)
            self._register_child_request(request_id, child_req_id, idx)
            add_request(request_id=child_req_id,
                        prompt=prompt,
                        params=child_params,
                        arrival_time=arrival_time,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        prompt_adapter_request=prompt_adapter_request,
                        priority=priority)  # type: ignore

    def step(
        self,
        outputs: List[RequestOutput],
    ) -> List[RequestOutput]:
        """Build parallel sampling request outputs.

        Extract child request outputs, route them through their
        parent request, and return whatever parent outputs are ready.

        Bookkeeping is released only on completion: a child entry is
        dropped once that child's output is finished, and a parent
        entry is dropped once the parent output it returns is
        finished. This keeps the mapping intact when a child produces
        several outputs across steps (output kinds other than
        `FINAL_ONLY`), instead of discarding it after the first one.

        Do not modify `n=1` requests.

        Args:
          outputs: step request outputs. Mix of child request
                   outputs & `n=1` request outputs.

        Return:
          List of parallel sampling parent request outputs &
          unmodified `n=1` request outputs passed-thru from input.
        """
        if not (self.parent_reqs and outputs):
            # Return unmodified
            return outputs
        agg_outputs = []
        for output in outputs:
            req_id = output.request_id
            child_req_entry = self.child_reqs.get(req_id)
            if child_req_entry is None:
                # Not a parallel sampling request output
                agg_outputs.append(output)
                continue
            # For each parallel sampling child request output:
            index, parent_req_id = child_req_entry
            req = self.parent_reqs[parent_req_id]
            # Read the child's own completion state first:
            # process_output may overwrite `finished` with the
            # parent's completion state.
            child_finished = output.finished
            # Update parallel sampling request
            out = req.process_output(output, index)
            if out is not None:
                agg_outputs.append(out)
                if out.finished:
                    # Parent request is complete; cleanup parent
                    # request bookkeeping.
                    del self.parent_reqs[parent_req_id]
            if child_finished:
                # Cleanup child request bookkeeping.
                del self.child_reqs[req_id]
        return agg_outputs


async def generate_parallel_sampling_async(
    generate: AsyncGenerateMethodType,
    prompt: PromptType,
    sampling_params: SamplingParams,
    request_id: str,
    lora_request: Optional[LoRARequest] = None,
    trace_headers: Optional[Mapping[str, str]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
    """Generate completions for async parallel sampling requests.

    Creates one child request per requested completion by calling
    `generate`, wraps each child generator so its outputs are
    processed by the parent request, and yields the outputs of the
    merged child generators.

    Args:
      generate: callable that starts a child request and returns
                its output generator
      prompt: prompt forwarded unchanged to every child request
      sampling_params: parent sampling params
      request_id: parent request ID
      lora_request: forwarded unchanged to every child request
      trace_headers: forwarded unchanged to every child request
      prompt_adapter_request: forwarded unchanged to every child
                              request
      priority: forwarded unchanged to every child request

    Returns:
      Yields the request outputs produced by the wrapped child
      generators.
    """
    parent = ParallelSamplingRequest(request_id, sampling_params)

    def _start_child(child_index: int):
        # Launch one child request and wrap its output generator
        child_id, child_sampling_params = parent.get_child_info(child_index)
        raw_gen = generate(
            prompt=prompt,
            sampling_params=child_sampling_params,
            request_id=child_id,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )  # type: ignore
        return parent.wrap_child_async_generator(raw_gen, child_index)

    # One wrapped generator per child request, in index order
    child_streams: List[AsyncGenerator[RequestOutput, None]] = [
        _start_child(child_index) for child_index in range(parent.n)
    ]

    # Merge generators; the first tuple element is discarded
    async for _, request_output in merge_async_iterators(*child_streams):
        yield request_output