# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from copy import copy
from typing import Optional

from vllm.outputs import CompletionOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.metrics.stats import IterationStats


class ParentRequest:
    """Info, state & processing for parallel sampling request.

    Store parent request ID and sampling params.
    Facilitate generating child request sampling params.

    Each child is registered through `get_child_info(index)` with an index
    within ``range(n)``. The parent then tracks which children are still
    running. When the client asked for ``RequestOutputKind.FINAL_ONLY``
    output, it also aggregates their outputs. It records iteration stats
    once no tracked children remain.
    """

    request_id: str
    sampling_params: SamplingParams

    # To track the completion of child requests
    child_requests: set[str]

    # To aggregate child completions when not streaming. For FINAL_ONLY
    # requests this starts as ``[None] * n`` and each slot is filled at
    # ``CompletionOutput.index``; otherwise it stays empty.
    output_aggregator: list[CompletionOutput]

    # To find the max number of generated tokens across all children
    max_num_generation_tokens: int

    # To efficiently obtain child sampling params (only populated when the
    # parent has no seed, because then every child can share one instance)
    cached_child_sampling_params: Optional[SamplingParams]

    def __init__(self, request_id: str,
                 sampling_params: SamplingParams) -> None:
        self.request_id = request_id
        self.sampling_params = sampling_params

        self.child_requests = set()
        # Only the FINAL_ONLY path in `get_outputs` writes to the
        # aggregator, so streaming requests get an empty list.
        self.output_aggregator = [None] * sampling_params.n if (
            sampling_params.output_kind
            == RequestOutputKind.FINAL_ONLY) else []
        self.max_num_generation_tokens = 0
        self.cached_child_sampling_params = None

    def _get_child_sampling_params(
        self,
        index: int,
    ) -> SamplingParams:
        """Efficiently obtain child `sampling_params`

        If `sampling_params.seed` is not `None` then
        each child request requires a unique clone of
        parent `sampling_params` with a unique seed
        (``parent_seed + index``). Otherwise a single clone with
        ``n = 1`` is built on first use, cached, and shared by all
        children.

        Clones are shallow copies (`copy.copy`): besides `n` and `seed`,
        any mutable attributes are shared with the parent's params.

        Args:
          index: index within `n` child requests

        Returns:
          Child `sampling_params` instance.
        """
        cached = self.cached_child_sampling_params
        if cached is not None:
            # Reuse child sampling_params data structure
            return cached

        # Build child sampling_params
        child_sampling_params = copy(self.sampling_params)
        child_sampling_params.n = 1

        seed = self.sampling_params.seed
        if seed is None:
            # Cache child sampling_params for later reuse
            self.cached_child_sampling_params = child_sampling_params
        else:
            # Each child gets a clone with a unique seed
            child_sampling_params.seed = seed + index
        return child_sampling_params

    def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
        """Get child request ID and sampling params.

        Registering the child here is what lets `get_outputs` know the
        parent is not finished until this child reports completion.

        Args:
          index: index within `n` child requests.

        Returns:
          (request ID, sampling_params) tuple, where the request ID is
          ``f"{index}_{parent_request_id}"``.
        """
        child_req_id = f"{index}_{self.request_id}"
        self.child_requests.add(child_req_id)
        return child_req_id, self._get_child_sampling_params(index)

    @property
    def n(self) -> int:
        """Number of parallel samples requested by the parent."""
        return self.sampling_params.n

    def get_outputs(
        self,
        child_request_id: str,
        completion_output: CompletionOutput,
    ) -> tuple[str, list[CompletionOutput], bool]:
        """Fold one child's output into the parent's state.

        Args:
          child_request_id: ID previously returned by `get_child_info`.
          completion_output: latest output produced by that child.

        Returns:
          ``(parent_request_id, outputs, finished)``:
          - When streaming, `outputs` is ``[completion_output]``.
          - For FINAL_ONLY, `outputs` is ``[]`` while any child is still
            tracked. Once none remain, it is the aggregator list with each
            child's output at its ``CompletionOutput.index``.
          - `finished` is True once no tracked children remain.

        Raises:
          KeyError: if a finished output arrives for a child that is not
            tracked (e.g. the same child reported finished twice).
        """
        if completion_output.finished():
            self.child_requests.remove(child_request_id)

        if self.sampling_params.output_kind != RequestOutputKind.FINAL_ONLY:
            # If streaming, just return the current output.
            outputs = [completion_output]
        else:
            # If not streaming, aggregate the n final outputs.
            self.output_aggregator[completion_output.index] = completion_output
            outputs = [] if self.child_requests else self.output_aggregator

        finished = not self.child_requests
        return self.request_id, outputs, finished

    def observe_num_generation_tokens(self, num_generation_tokens: int) -> int:
        """Update and return the max generated-token count across children."""
        self.max_num_generation_tokens = max(num_generation_tokens,
                                             self.max_num_generation_tokens)
        return self.max_num_generation_tokens

    @staticmethod
    def observe_finished_request(parent_req: Optional['ParentRequest'],
                                 iteration_stats: IterationStats,
                                 num_generation_tokens: int) -> None:
        """Record generation-token and `n` stats for a finished request.

        With ``parent_req is None`` the stats are appended immediately
        with ``n == 1``. Otherwise the parent's running max of generated
        tokens is updated. The stats (that max, plus the parent's `n`) are
        appended only when the parent no longer tracks any child.

        NOTE(review): this reads `parent_req.child_requests`, so it
        presumably must run after `get_outputs` has removed the finishing
        child; confirm the call order at the caller.
        """
        if parent_req is None:
            n_param = 1
        else:
            n_param = parent_req.n
            num_generation_tokens = parent_req.observe_num_generation_tokens(
                num_generation_tokens)
            if parent_req.child_requests:
                # Some children are still tracked; defer recording.
                return

        # Child requests finished, we can now record to iteration stats
        iteration_stats.max_num_generation_tokens_iter.append(
            num_generation_tokens)
        iteration_stats.n_params_iter.append(n_param)