"vllm/model_executor/models/opt.py" did not exist on "667ba3995c013df060657a4cdf3931176c6c5259"
parallel_sampling.py 5.35 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

from copy import copy
5
from typing import cast
6

7
8
from vllm.outputs import CompletionOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
9
from vllm.v1.engine import EngineCoreRequest
10
from vllm.v1.metrics.stats import IterationStats
11
12


13
class ParentRequest:
    """Info, state & processing for parallel sampling request.

    Store parent request ID and sampling params.
    Facilitate generating child request sampling params.

    Child requests are registered through ``get_child_info(index)`` (IDs of
    the form ``f"{index}_{request_id}"``). Their outputs are routed back
    through ``get_outputs``, which tracks unfinished children. When the
    parent uses ``RequestOutputKind.FINAL_ONLY``, ``get_outputs`` also
    aggregates the children's outputs.
    """

    request_id: str
    external_req_id: str
    sampling_params: SamplingParams

    # IDs of child requests registered via get_child_info() that have not
    # yet reported a finished output.
    child_requests: set[str]

    # FINAL_ONLY mode: one slot per child, indexed by CompletionOutput.index,
    # filled as outputs arrive. Empty list in every other output mode.
    output_aggregator: list[CompletionOutput]

    # Largest num_generation_tokens observed across all children so far.
    max_num_generation_tokens: int

    # Child SamplingParams shared by all children when no seed is set.
    cached_child_sampling_params: SamplingParams | None

    def __init__(self, request: EngineCoreRequest) -> None:
        """Initialize parent-request state from an engine core request.

        Args:
          request: the parent request; ``request.external_req_id`` must not
            be ``None``.
        """
        assert request.external_req_id is not None
        sampling_params = request.params
        self.request_id = request.request_id
        self.external_req_id = request.external_req_id
        self.sampling_params = sampling_params

        self.child_requests = set()
        # Pre-size the aggregator only when the n outputs are returned
        # together; the None placeholders are overwritten as children report.
        self.output_aggregator = (
            [cast(CompletionOutput, None)] * sampling_params.n
            if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY
            else []
        )
        self.max_num_generation_tokens = 0
        self.cached_child_sampling_params = None

    def _get_child_sampling_params(
        self,
        index: int,
    ) -> SamplingParams:
        """Efficiently obtain child `sampling_params`

        If `sampling_params.seed` is not `None` then
        each child request requires a unique clone of
        parent `sampling_params` with a unique seed (``seed + index``).
        Otherwise a single shallow copy with ``n = 1`` is built once,
        cached, and returned for every child.

        Args:
          index: index within `n` child requests

        Returns:
          Child `sampling_params` instance.
        """
        # Explicit None check: reuse of the cached object must not depend
        # on how SamplingParams evaluates in a boolean context.
        if self.cached_child_sampling_params is not None:
            return self.cached_child_sampling_params

        child_sampling_params = copy(self.sampling_params)
        child_sampling_params.n = 1
        seed = self.sampling_params.seed
        if seed is None:
            # No per-child customization needed: cache for later reuse.
            self.cached_child_sampling_params = child_sampling_params
        else:
            # Each child gets its own clone with a unique seed.
            child_sampling_params.seed = seed + index
        return child_sampling_params

    def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
        """Get child request ID and sampling params.

        Also registers the child ID as unfinished.

        Args:
          index: index within `n` child requests.

        Returns:
          (request ID, sampling_params) tuple
        """
        child_req_id = f"{index}_{self.request_id}"
        self.child_requests.add(child_req_id)
        return child_req_id, self._get_child_sampling_params(index)

    @property
    def n(self) -> int:
        """Number of child requests requested by the parent (``params.n``)."""
        return self.sampling_params.n

    def get_outputs(
        self,
        child_request_id: str,
        completion_output: CompletionOutput,
    ) -> tuple[list[CompletionOutput], bool]:
        """Process one child output and decide what to emit for the parent.

        A finished output removes the child from the unfinished set. A
        finished output for a child that is no longer tracked is treated as
        already handled and is ignored: nothing is emitted and the
        aggregator is left untouched.

        Args:
          child_request_id: ID of the child that produced the output.
          completion_output: the child's latest output.

        Returns:
          (outputs, finished) tuple. ``outputs`` is the (possibly empty)
          list of completion outputs to emit now. ``finished`` is True when
          no registered child remains unfinished.
        """
        already_finished_and_returned = False
        if completion_output.finished():
            if child_request_id in self.child_requests:
                self.child_requests.remove(child_request_id)
            else:
                # child request ID is not available in child_requests
                # which means the request had finished in previous
                # batch step and returned to the client earlier
                already_finished_and_returned = True

        outputs: list[CompletionOutput]
        if already_finished_and_returned:
            # DO NOT output a finished and already returned child request to
            # the client again. In FINAL_ONLY mode this also prevents
            # re-emitting the whole aggregate and mutating an aggregator list
            # that may already have been handed to the caller.
            outputs = []
        elif self.sampling_params.output_kind != RequestOutputKind.FINAL_ONLY:
            # If streaming, just return the current output.
            outputs = [completion_output]
        else:
            # If not streaming, aggregate the n final outputs and emit them
            # together once every child has finished.
            self.output_aggregator[completion_output.index] = completion_output
            outputs = [] if self.child_requests else self.output_aggregator

        finished = not self.child_requests
        return outputs, finished

    def observe_num_generation_tokens(self, num_generation_tokens: int):
        """Fold a child's generated-token count into the running maximum.

        Args:
          num_generation_tokens: tokens generated by one child request.

        Returns:
          The maximum token count observed across children so far.
        """
        self.max_num_generation_tokens = max(
            num_generation_tokens, self.max_num_generation_tokens
        )
        return self.max_num_generation_tokens

    @staticmethod
    def observe_finished_request(
        parent_req: "ParentRequest | None",
        iteration_stats: IterationStats,
        num_generation_tokens: int,
    ):
        """Record generation stats for a finished request.

        For a request without a parent, the stats are appended immediately
        with ``n = 1``. For a child request, the parent's running maximum is
        updated first. The stats (max tokens, parent ``n``) are then
        appended only once no registered child of the parent is unfinished.

        Args:
          parent_req: the parent of the finished request, or ``None``.
          iteration_stats: stats object whose lists are appended to.
          num_generation_tokens: tokens generated by the finished request.
        """
        if parent_req is None:
            n_param = 1
        else:
            n_param = parent_req.n
            num_generation_tokens = parent_req.observe_num_generation_tokens(
                num_generation_tokens
            )
            if parent_req.child_requests:
                # Sibling children still running; record on the last one.
                return

        # Child requests finished, we can now record to iteration stats
        iteration_stats.max_num_generation_tokens_iter.append(num_generation_tokens)
        iteration_stats.n_params_iter.append(n_param)