multi_step_worker.py 8.89 KB
Newer Older
1
import copy
2
import weakref
3
from typing import Dict, List, Tuple
4
5
6

import torch

7
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
8
                           SequenceGroupMetadata)
9
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
10
11
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeProposer)
12
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
13
from vllm.spec_decode.top1_proposer import Top1Proposer
14
from vllm.worker.worker import Worker
15
16


17
class MultiStepWorker(Worker, ProposerWorkerBase):
    """The MultiStepWorker is equivalent to a Worker except that it allows
    multiple forward passes in a single call, assuming the scheduler has
    allocated enough space to store the additional KV. This reduces overhead
    by invoking the scheduler less.

    The MultiStepWorker does not support cache swap operations, or beam search.
    Cache swap operations do not require large modifications. On the other hand,
    beam search requires memory allocations during sequence forks and thus
    requires more thought for MultiStepWorker support.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to the parent constructor.

        The proposer is only declared here; it is created in init_device.
        """
        super().__init__(*args, **kwargs)

        # Lazy initialization list.
        self._proposer: SpeculativeProposer

    def init_device(self) -> None:
        """Initialize the device via the parent class, then create the
        Top1Proposer used by get_spec_proposals.
        """
        super().init_device()

        self._proposer = Top1Proposer(
            # A weak proxy is passed so that the proposer does not hold a
            # strong reference back to this worker.
            weakref.proxy(self),  # type: ignore[arg-type]
            self.device,
            self.vocab_size,
            max_proposal_len=self.max_model_len,
        )

    def set_include_gpu_probs_tensor(self) -> None:
        """Set include_gpu_probs_tensor to True on the model's sampler."""
        # Need include_gpu_probs_tensor for multi_step_worker
        self.model_runner.model.sampler.include_gpu_probs_tensor = True

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass sample_len times. Returns the list of
        sampler outputs, one per model forward pass, along with an indicator
        of whether the torch tensors in the sampler output need to be
        transposed in the later sampler_output_to_torch logic.

        For multi step worker, this indicator shall be True.

        Args:
            execute_model_req: The request to execute. Its
                seq_group_metadata_list is shallow-copied before use so that
                tokens appended here are not visible to the caller.
            sample_len: The number of forward passes to run.

        Returns:
            A tuple of (sampler outputs, True).

        Raises:
            NotImplementedError: If the request contains cache operations or
                a sequence group that does not have exactly one sequence.
        """
        self._raise_if_unsupported(execute_model_req)

        # Shallow copy input data so modifications (such as appending tokens)
        # do not cause side-effects.
        copied_seq_group_metadata_list = self._shallow_copy_inputs(
            execute_model_req.seq_group_metadata_list)
        copied_execute_model_req = execute_model_req.clone(
            copied_seq_group_metadata_list)

        # Run model sample_len times.
        model_outputs: List[SamplerOutput] = []
        if isinstance(self.model_runner, TP1DraftModelRunner):
            # The draft model runner is handed the number of steps and is
            # invoked only once.
            copied_execute_model_req.num_steps = sample_len
            model_outputs = self.execute_model(
                execute_model_req=copied_execute_model_req)
        else:
            # TODO: Remove this branch once DraftModelRunner supports TP>1.
            for _ in range(sample_len):
                step_outputs: List[SamplerOutput] = super().execute_model(
                    execute_model_req=copied_execute_model_req)
                assert (len(step_outputs) == 1
                        ), "composing multistep workers not supported"
                model_output = step_outputs[0]

                # Feed the sampled tokens back into the copied sequences so
                # that the next iteration continues from them.
                self._append_new_tokens(model_output,
                                        copied_seq_group_metadata_list)
                model_outputs.append(model_output)

        return model_outputs, True

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number of
        speculative tokens per sequence is determined by max_proposal_len.

        The work is delegated to the proposer created in init_device.
        """
        return self._proposer.get_spec_proposals(execute_model_req)

    @staticmethod
    def _append_new_tokens(
            model_output: SamplerOutput,
            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Given model output from a single run, append the tokens to the
        sequences. This is normally done outside of the worker, but it is
        required if the worker is to perform multiple forward passes.

        Args:
            model_output: The sampler output of one forward pass; iterated in
                lockstep with seq_group_metadata_list.
            seq_group_metadata_list: The sequence groups to update in place.
        """
        for seq_group_metadata, sequence_group_outputs in zip(
                seq_group_metadata_list, model_output):
            seq_group_metadata.is_prompt = False

            for seq_output in sequence_group_outputs.samples:
                # NOTE: Beam search is not supported, so we can assume that
                # parent_seq_id == seq_id.
                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]

                token_id = seq_output.output_token
                token_logprob = seq_output.logprobs[token_id]

                seq.append_token_id(token_id, token_logprob.logprob)
                seq.update_num_computed_tokens(1)

    @staticmethod
    def _shallow_copy_inputs(
        seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> List[SequenceGroupMetadata]:
        """Copy input data structures to remove side-effects when input data
        structures are shared with other modules.

        Helpful when the vLLM scheduler runs in the same process as the worker.
        The alternative is deep-copying (or other form of deep copy); this has
        performance downsides.

        Returns:
            A new list holding a shallow copy of every SequenceGroupMetadata,
            each with a fresh seq_data dict whose SequenceData entries are
            shallow copies with their own copy of output_token_ids.
        """

        # Shallow-copy the list of SequenceGroupMetadata. This allows us to
        # append tokens and change is_prompt without external side-effects.
        new_seq_group_metadata_list: List[SequenceGroupMetadata] = []

        for old_seq_group_metadata in seq_group_metadata_list:
            # We must shallow-copy seq_group_metadata as is_prompt could change.
            seq_group_metadata = copy.copy(old_seq_group_metadata)
            new_seq_group_metadata_list.append(seq_group_metadata)

            # We must shallow-copy seq_data as we will append token ids
            new_seq_data: Dict[int, SequenceData] = {}
            for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
                new_seq_data[seq_id] = copy.copy(old_seq_data)
                new_seq_data[
                    seq_id].output_token_ids = old_seq_data.output_token_ids[:]

            seq_group_metadata.seq_data = new_seq_data

        return new_seq_group_metadata_list

    def _assert_enough_kv_space(
            self, seq_group_metadata_list: List[SequenceGroupMetadata],
            num_steps: int) -> None:
        """Assert there are enough physical blocks per sequence to store the
        current KV plus additional KV from num_steps tokens.

        NOTE(review): no method of this class calls this helper.

        Raises:
            ValueError: If a sequence needs more KV slots than its block table
                provides.
        """
        assert self.model_runner.block_size is not None
        for seq_group_metadata in seq_group_metadata_list:
            # Only one seq_id is guaranteed because there is no beam search.
            seq_id = list(seq_group_metadata.seq_data.keys())[0]
            seq = seq_group_metadata.seq_data[seq_id]

            # After num_steps, the seq len will be the current seq len
            # plus one token per step.
            final_seq_len = seq.get_len() + num_steps

            # We will have final_seq_len - 1 KV because vLLM saves KV for a
            # token in the iteration after the token was generated.
            required_num_kv_slots = final_seq_len - 1

            # The allocated number of kv slots is the number of allocated blocks
            # times the number of slots of block.
            number_physical_blocks = len(
                seq_group_metadata.block_tables[seq_id])
            allocated_kv_slots = (number_physical_blocks *
                                  self.model_runner.block_size)

            if required_num_kv_slots > allocated_kv_slots:
                request_id = seq_group_metadata.request_id
                raise ValueError(
                    "The worker attempted to run "
                    f"{num_steps} times but found insufficient KV space for "
                    f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
                    f"{required_num_kv_slots=}).")

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """MultiStepWorker does not yet implement support for cache swap
        operations or beam search.

        Raises:
            NotImplementedError: If the request carries any blocks to swap in,
                swap out or copy, or if any sequence group does not contain
                exactly one sequence.
        """
        if any([
                execute_model_req.blocks_to_swap_in,
                execute_model_req.blocks_to_swap_out,
                execute_model_req.blocks_to_copy
        ]):
            raise NotImplementedError(
                "MultiStepWorker does not support cache operations")

        # More than one sequence in a group is treated as beam search.
        if any(
                len(seq_group_metadata.seq_data) != 1
                for seq_group_metadata in
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "MultiStepWorker does not support beam search.")