# SPDX-License-Identifier: Apache-2.0

from copy import copy
from typing import Callable, Optional, Union

from vllm.outputs import CompletionOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import IterationStats


class ParentRequest:
    """Info, state & processing for parallel sampling request.

    Store parent request ID and sampling params.
    Facilitate generating child request sampling params.
    """

    request_id: str
    sampling_params: SamplingParams

    # To track the completion of child requests
    child_requests: set[str]

    # To aggregate child completions when not streaming
    output_aggregator: Optional[RequestOutput]

    # To find the max number of generated tokens across all children
    max_num_generation_tokens: int

    # To efficiently obtain child sampling params
    cached_child_sampling_params: Optional[SamplingParams]

    def __init__(self, request_id: str,
                 sampling_params: SamplingParams) -> None:
        """Initialize parent request state.

        Args:
          request_id: parent request ID.
          sampling_params: parent request sampling params.
        """
        self.request_id = request_id
        self.sampling_params = sampling_params

        self.child_requests = set()
        self.output_aggregator = None
        self.max_num_generation_tokens = 0
        self.cached_child_sampling_params = None

    @classmethod
    def from_params(
        cls,
        request_id: str,
        params: Union[SamplingParams, PoolingParams],
    ) -> Optional['ParentRequest']:
        """Build a parent request only when parallel sampling applies.

        Args:
          request_id: parent request ID.
          params: sampling or pooling params of the request.

        Returns:
          `None` if `params` is not a `SamplingParams` instance or if
          `params.n == 1`; otherwise a new `ParentRequest`.
        """
        if not isinstance(params, SamplingParams) or params.n == 1:
            return None
        return cls(request_id, params)

    def _get_child_sampling_params(
        self,
        index: int,
    ) -> SamplingParams:
        """Efficiently obtain child `sampling_params`

        If `sampling_params.seed` is not `None` then
        each child request requires a unique clone of
        parent `sampling_params` with a unique seed.

        Args:
          index: index within `n` child requests

        Returns:
          Child `sampling_params` instance.
        """
        seed = self.sampling_params.seed
        # Explicit `None` check: the cache is "unset" only when it is `None`,
        # regardless of how the cached object evaluates in a boolean context.
        if self.cached_child_sampling_params is not None:
            # Reuse child sampling_params data structure
            return self.cached_child_sampling_params
        # Build child sampling_params
        child_sampling_params = copy(self.sampling_params)
        child_sampling_params.n = 1
        if seed is None:
            # Cache child sampling_params for later reuse
            self.cached_child_sampling_params = child_sampling_params
        else:
            # Each child gets a clone with a unique seed
            child_sampling_params.seed = seed + index
        return child_sampling_params

    def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
        """Get child request ID and sampling params.

        The child request ID is also recorded in `child_requests`.

        Args:
          index: index within `n` child requests.

        Returns:
          (request ID, sampling_params) tuple
        """
        child_req_id = f"{index}_{self.request_id}"
        self.child_requests.add(child_req_id)
        return (child_req_id, self._get_child_sampling_params(index))

    def finish_child_request(self, req_id: str) -> None:
        """Stop tracking a child request.

        Args:
          req_id: child request ID to drop from `child_requests`.

        Raises:
          KeyError: if `req_id` is not currently in `child_requests`.
        """
        self.child_requests.remove(req_id)

    @property
    def n(self) -> int:
        """Value of `n` from the parent `sampling_params`."""
        return self.sampling_params.n

    def make_request_output(
        self,
        final_only: bool,
        completion_output: CompletionOutput,
        new_request_output: Callable[[str], RequestOutput],
    ) -> Optional[RequestOutput]:
        """Add a child completion and produce the parent output if ready.

        Args:
          final_only: if true, hold back the output until it contains
            `n` completions.
          completion_output: completion to append to the parent output.
          new_request_output: called with the parent request ID to build
            a new `RequestOutput` when none is being aggregated.

        Returns:
          `None` while still aggregating; otherwise the `RequestOutput`
          with its completions sorted by `index`.
        """
        # Use an existing RequestOutput if we're aggregating
        request_output = self.output_aggregator

        # Make new RequestOutput otherwise
        if request_output is None:
            request_output = new_request_output(self.request_id)

        # Add a new completion
        request_output.outputs.append(completion_output)

        # If not streaming, aggregate until all child requests complete
        if final_only and len(request_output.outputs) != self.n:
            self.output_aggregator = request_output
            return None

        # We're done aggregating
        self.output_aggregator = None

        # Parent completion output list must be sorted by index
        request_output.outputs = sorted(request_output.outputs,
                                        key=lambda x: x.index)
        return request_output

    def observe_num_generation_tokens(self, num_generation_tokens: int) -> int:
        """Update the running maximum of generated tokens.

        Args:
          num_generation_tokens: token count observed for one child.

        Returns:
          The largest count observed so far, including this one.
        """
        self.max_num_generation_tokens = max(num_generation_tokens,
                                             self.max_num_generation_tokens)
        return self.max_num_generation_tokens

    @staticmethod
    def observe_finished_request(parent_req: Optional['ParentRequest'],
                                 iteration_stats: IterationStats,
                                 num_generation_tokens: int) -> None:
        """Record stats for a finished request.

        With a parent request, the token count is folded into the parent's
        running maximum, and stats are recorded only once `child_requests`
        is empty. Without a parent request, the given count is recorded
        directly with `n` taken as 1.

        Args:
          parent_req: parent request, or `None` if there is none.
          iteration_stats: stats object whose
            `max_num_generation_tokens_iter` and `n_params_iter` lists
            are appended to.
          num_generation_tokens: token count of the finished request.
        """
        n_param = parent_req.n if parent_req is not None else 1

        if parent_req is not None:
            num_generation_tokens = parent_req.observe_num_generation_tokens(
                num_generation_tokens)

        # Child requests finished, we can now record to iteration stats
        if parent_req is None or not parent_req.child_requests:
            iteration_stats.max_num_generation_tokens_iter.append(
                num_generation_tokens)
            iteration_stats.n_params_iter.append(n_param)