# "vscode:/vscode.git/clone" did not exist on
# "9e5ec696fb40ba65f6e0f618c3192c1ac158378b"
# multi_step_worker.py 17.6 KB
# Newer Older
1
import copy
2
import weakref
3
from typing import Dict, List, Set, Tuple
4
5
6

import torch

7
from vllm.model_executor.layers.sampler import SamplerOutput
8
from vllm.platforms import current_platform
9
10
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
                           SequenceGroupMetadata)
11
12
13
14

if current_platform.is_cuda_alike():
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

15
16
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeProposer)
17
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
18
from vllm.spec_decode.top1_proposer import Top1Proposer
19
from vllm.worker.worker_base import DelegateWorkerBase
20
21


22
class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
    """The MultiStepWorker is equivalent to a Worker except that it allows
    multiple forward passes in a single call, assuming the scheduler has
    allocated enough space to store the additional KV. This reduces overhead
    by invoking the scheduler less.

    The MultiStepWorker does not support cache swap operations, or beam search.
    Cache swap operations do not require large modifications. On the other hand,
    beam search requires memory allocations during sequence forks and thus
    requires more thought for MultiStepWorker support.
    """

    def __init__(self, *args, **kwargs):
        DelegateWorkerBase.__init__(self, *args, **kwargs)
        # Lazy initialization: only declared here; the proposer is assigned in
        # ``init_device`` after the wrapped worker's device is initialized.
        self._proposer: SpeculativeProposer

    def init_device(self) -> None:
        """Initialize the wrapped worker's device, then build the proposer.

        The proposer receives a weak proxy to this worker rather than
        ``self`` so that, if it keeps a reference to the worker, no strong
        reference cycle (worker -> proposer -> worker) is created.
        """
        self.worker.init_device()
        self._proposer = Top1Proposer(
            weakref.proxy(self),  # type: ignore[arg-type]
            self.device,
            self.vocab_size,
            max_proposal_len=self.max_model_len,
        )

    def set_include_gpu_probs_tensor(self) -> None:
        """Make the model's sampler also return GPU probability tensors."""
        # Need include_gpu_probs_tensor for MultiStepWorker
        self.model_runner.model.sampler.include_gpu_probs_tensor = True

    def set_should_modify_greedy_probs_inplace(self) -> None:
        """Set the sampler's ``should_modify_greedy_probs_inplace`` flag."""
        self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
            True)

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass ``sample_len`` times.

        Args:
            execute_model_req: The batch to propose tokens for.
            sample_len: Number of forward passes (proposal tokens) to run.
            seq_ids_with_bonus_token_in_last_step: IDs of sequences whose
                last output token is a bonus token from the previous step.

        Returns:
            A tuple of the list of sampler outputs, one per model forward pass
            and filtered back to the original (un-expanded) batch, and an
            indicator of whether the torch tensors in the sampler output need
            to be transposed in the later sampler_output_to_torch logic. For
            the multi step worker this indicator is always True.

        Raises:
            NotImplementedError: If the request uses cache swap/copy
                operations or contains a sequence group with more than one
                sequence (beam search).
        """
        self._raise_if_unsupported(execute_model_req)

        # Expand the batch for sequences with a bonus token.
        # Perform a forward pass on the expanded batch and filter the
        # response to retain only the original sequences' responses.
        # NOTE: despite its name, ``indices_of_seq_with_bonus_tokens`` holds
        # the positions of the *original* sequence groups inside the expanded
        # batch (see ``_expand_execute_model_request``).
        expanded_request, indices_of_seq_with_bonus_tokens = \
            self._expand_execute_model_request(
                execute_model_req, seq_ids_with_bonus_token_in_last_step)

        # Run model sample_len times.
        model_outputs: List[SamplerOutput] = []
        if current_platform.is_cuda_alike() and isinstance(
                self.model_runner, TP1DraftModelRunner
        ) and self.model_runner.supports_gpu_multi_step(expanded_request):
            # Here we run the draft_model_runner with multi-step prepare
            # on the GPU directly
            expanded_request.num_steps = sample_len
            self.model_runner.set_indices_of_seq_with_bonus_tokens(
                indices_of_seq_with_bonus_tokens)
            model_outputs = self.execute_model(
                execute_model_req=expanded_request)
        else:
            # Here we run multi-step directly, with every step prepared
            # on the CPU.
            # TODO: Remove this branch once DraftModelRunner supports TP>1
            # and other restrictions that are part of DraftModelRunner's
            # supports_gpu_multi_step(..)
            for _ in range(sample_len):
                step_outputs: List[SamplerOutput] = self.worker.execute_model(
                    execute_model_req=expanded_request)
                assert (len(step_outputs) == 1
                        ), "composing multistep workers not supported"
                model_output = step_outputs[0]

                # Feed this step's tokens back into the expanded batch so the
                # next forward pass continues from them.
                self._append_new_tokens(
                    model_output, expanded_request.seq_group_metadata_list,
                    indices_of_seq_with_bonus_tokens)
                model_outputs.append(model_output)

        # Move indices to device to avoid stream sync when indexing the
        # sampler's tensors. dtype is pinned so an empty index list yields an
        # integer (not float32) tensor.
        retained_indices = torch.tensor(indices_of_seq_with_bonus_tokens,
                                        dtype=torch.long,
                                        device=self.device)
        filtered_model_outputs = self._filter_model_output(
            model_outputs, retained_indices)
        return filtered_model_outputs, True

    @staticmethod
    def _expand_execute_model_request(
        execute_model_req: ExecuteModelRequest,
        seq_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[ExecuteModelRequest, List[int]]:
        """
        Expands the execute model request based on sequences with bonus
        tokens.

        For each sequence group containing a sequence with a bonus token, a
        copy of that group without the last (bonus) token is inserted
        immediately before a shallow copy of the original group. Groups
        without bonus tokens are only shallow-copied. If the request carries
        ``HiddenStates``, they are expanded via ``expand_with_bonus_tokens``.

        Args:
            execute_model_req (ExecuteModelRequest): The original execute
                model request. Its sequence group metadata is copied, not
                modified.
                NOTE(review): ``previous_hidden_states`` is expanded on the
                cloned request; whether that object is shared with the input
                request depends on ``ExecuteModelRequest.clone`` — confirm.
            seq_with_bonus_token_in_last_step (Set[int]): Set of sequence IDs
                that contain bonus tokens.

        Returns:
            Tuple[ExecuteModelRequest, List[int]]: The updated execute model
            request with expanded sequences and a list of indices corresponding
            to the original sequence groups within the expanded list.
        """
        updated_seq_group_metadata_list: List[SequenceGroupMetadata] = []
        updated_execute_model_req = execute_model_req.clone(
            updated_seq_group_metadata_list)
        indices_of_original_sequence_groups: List[int] = []
        for seq_group in execute_model_req.seq_group_metadata_list:
            # Identify sequences with bonus tokens in the sequence group.
            seq_group_has_bonus_tokens = any(
                seq_id in seq_with_bonus_token_in_last_step
                for seq_id in seq_group.seq_data)
            if seq_group_has_bonus_tokens:
                # Create new sequences without the last bonus token. These new
                # sequences have the same sequence id as the original sequence.
                # We create a new sequence group and add them there.
                updated_seq_group_metadata_list.append(
                    MultiStepWorker._copy_seq_metadata_excluding_last_token(
                        seq_group, seq_with_bonus_token_in_last_step))
            # Add the original sequence group.
            updated_seq_group_metadata_list.append(
                MultiStepWorker._shallow_copy_seq_group_metadata(seq_group))
            # Record the index of the original sequence group.
            indices_of_original_sequence_groups.append(
                len(updated_seq_group_metadata_list) - 1)

        updated_execute_model_req.seq_group_metadata_list = \
            updated_seq_group_metadata_list

        if isinstance(updated_execute_model_req.previous_hidden_states,
                      HiddenStates):
            updated_execute_model_req.previous_hidden_states \
                .expand_with_bonus_tokens(seq_with_bonus_token_in_last_step)

        return updated_execute_model_req, indices_of_original_sequence_groups

    @staticmethod
    def _filter_model_output(
            expanded_batch_outputs: List[SamplerOutput],
            output_indices_to_retain: torch.Tensor) -> List[SamplerOutput]:
        """
        Filters the model output to include only the specified sequence
        outputs. This method contracts the expanded batch output from the
        model to retain the outputs of only those sequences indicated by the
        provided indices.

        Args:
            expanded_batch_outputs (List[SamplerOutput]): The expanded output
                batch from the model, one entry per forward pass.
            output_indices_to_retain (torch.Tensor): Indices of the model
                outputs to retain. A plain sequence of ints is also accepted.

        Returns:
            List[SamplerOutput]: A list containing the filtered model
            outputs for the specified indices.
        """
        # The per-group ``outputs`` are Python lists. Indexing them with the
        # elements of a (possibly device-resident) tensor calls ``__index__``
        # per element, i.e. one device->host sync per index per step.
        # Materialize the indices on the host once, and only if needed.
        needs_host_indices = any(
            len(batch_output.outputs) > 0
            for batch_output in expanded_batch_outputs)
        host_indices: List[int] = []
        if needs_host_indices:
            if isinstance(output_indices_to_retain, torch.Tensor):
                host_indices = output_indices_to_retain.tolist()
            else:
                host_indices = list(output_indices_to_retain)

        return [
            SamplerOutput(
                outputs=[
                    expanded_batch_output.outputs[i] for i in host_indices
                ] if len(expanded_batch_output.outputs) > 0 else [],
                sampled_token_probs=(
                    expanded_batch_output.
                    sampled_token_probs[output_indices_to_retain]
                    if expanded_batch_output.sampled_token_probs is not None
                    else None),
                logprobs=(
                    expanded_batch_output.logprobs[output_indices_to_retain]
                    if expanded_batch_output.logprobs is not None else None),
                sampled_token_ids=(expanded_batch_output.
                                   sampled_token_ids[output_indices_to_retain]
                                   if expanded_batch_output.sampled_token_ids
                                   is not None else None))
            for expanded_batch_output in expanded_batch_outputs
        ]

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number of
        speculative tokens per sequence is determined by max_proposal_len.

        Delegates to the proposer built in ``init_device``.
        """
        return self._proposer.get_spec_proposals(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    @staticmethod
    def _append_new_tokens(
            model_output: SamplerOutput,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            indices_of_seq_with_bonus_tokens: List[int]) -> None:
        """Given model output from a single run, append the tokens to the
        sequences. This is normally done outside of the worker, but it is
        required if the worker is to perform multiple forward passes.

        Args:
            model_output: Output of one forward pass over the expanded batch,
                iterated in lockstep with ``seq_group_metadata_list``.
            seq_group_metadata_list: The expanded batch; mutated in place
                (``is_prompt`` cleared, one token appended per sequence).
            indices_of_seq_with_bonus_tokens: Positions of the original
                sequence groups in the expanded batch (ascending). Any other
                position holds a bonus-token-less copy that directly precedes
                its original.
        """
        # ``count`` points at the next original group at or after ``index``.
        count = 0
        for index, (seq_group_metadata, sequence_group_outputs) in enumerate(
                zip(seq_group_metadata_list, model_output)):
            seq_group_metadata.is_prompt = False

            for seq_output in sequence_group_outputs.samples:
                # NOTE: Beam search is not supported, so we can assume that
                # parent_seq_id == seq_id.
                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]

                token_id = seq_output.output_token
                token_logprob = seq_output.logprobs[token_id]

                # Determine the actual token ID to be generated,
                # considering bonus tokens: a bonus-token-less copy appends
                # the last output token of its original (the next group),
                # so the copy stays one token behind the original.
                if index != indices_of_seq_with_bonus_tokens[count]:
                    bonus_seq_metadata = seq_group_metadata_list[
                        indices_of_seq_with_bonus_tokens[count]]
                    _, bonus_token_seq_data = next(
                        iter(bonus_seq_metadata.seq_data.items()))
                    token_id = bonus_token_seq_data.output_token_ids[-1]
                else:
                    count += 1

                seq.append_token_id(token_id, token_logprob.logprob)
                seq.update_num_computed_tokens(1)

    @staticmethod
    def _shallow_copy_seq_group_metadata(
        seq_group_metadata: SequenceGroupMetadata,
    ) -> SequenceGroupMetadata:
        """Copy input data structures to remove side-effects when input data
        structures are shared with other modules.

        Helpful when the vLLM scheduler runs in the same process as the worker.
        A deep copy would also avoid side effects but has performance
        downsides.
        """
        # Shallow-copy the SequenceGroupMetadata. This allows us to
        # append tokens and change is_prompt without external side-effects.
        # We must shallow-copy seq_group_metadata as is_prompt could change.
        new_seq_group_metadata = copy.copy(seq_group_metadata)

        # We must shallow-copy seq_data as we will append token ids
        new_seq_data: Dict[int, SequenceData] = {}
        for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
            new_seq_data[seq_id] = copy.copy(old_seq_data)
            new_seq_data[seq_id].output_token_ids = \
                old_seq_data.output_token_ids[:]

        new_seq_group_metadata.seq_data = new_seq_data
        return new_seq_group_metadata

    @staticmethod
    def _copy_seq_metadata_excluding_last_token(
        seq_group_metadata: SequenceGroupMetadata,
        seq_ids_to_copy: Set[int],
    ) -> SequenceGroupMetadata:
        """
        Creates a shallow copy of the given SequenceGroupMetadata, retaining
        only the sequence IDs specified in seq_ids_to_copy. For each of these
        sequence IDs, all output_token_ids except the last one are copied.
        Sequence IDs not in seq_ids_to_copy are excluded from the copy.

        Parameters:
        seq_group_metadata (SequenceGroupMetadata): The original sequence
            group metadata.
        seq_ids_to_copy (Set[int]): The set of sequence IDs to include in the
            copy.

        Returns:
        SequenceGroupMetadata: A shallow copy of the sequence group metadata
            with the specified modifications.
        """
        # Shallow-copy the SequenceGroupMetadata.
        new_seq_group_metadata = copy.copy(seq_group_metadata)
        # Shallow-copy seq_data and modify the output_token_ids.
        new_seq_data: Dict[int, SequenceData] = {}
        for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
            if seq_id in seq_ids_to_copy:
                new_seq_data[seq_id] = copy.copy(old_seq_data)
                # Copy all the output token ids except the last.
                # Also reduce num_computed_tokens by 1 since we are not
                # including the last output token.
                # NOTE: num_computed_tokens is not directly used by the
                # speculative decoding workers, as it is only relevant for
                # chunked prefill, which is disabled for speculative decoding.
                # However, to maintain consistency in num_computed_tokens,
                # we update it here.
                new_seq_data[seq_id].output_token_ids = \
                    old_seq_data.output_token_ids[:-1]
                new_seq_data[seq_id].update_num_computed_tokens(-1)
        new_seq_group_metadata.seq_data = new_seq_data
        return new_seq_group_metadata

    def _assert_enough_kv_space(
            self, seq_group_metadata_list: List[SequenceGroupMetadata],
            num_steps: int) -> None:
        """Assert there are enough physical blocks per sequence to store the
        current KV plus additional KV from num_steps tokens.

        Raises:
            ValueError: If a sequence's allocated KV slots are insufficient.
        """
        assert self.model_runner.block_size is not None
        for seq_group_metadata in seq_group_metadata_list:
            # Only one seq_id is guaranteed because there is no beam search.
            seq_id = list(seq_group_metadata.seq_data.keys())[0]
            seq = seq_group_metadata.seq_data[seq_id]

            # After num_steps, the seq len will be the current seq len
            # plus one token per step.
            final_seq_len = seq.get_len() + num_steps

            # We will have final_seq_len - 1 KV because vLLM saves KV for a
            # token in the iteration after the token was generated.
            required_num_kv_slots = final_seq_len - 1

            # The allocated number of kv slots is the number of allocated blocks
            # times the number of slots of block.
            number_physical_blocks = len(
                seq_group_metadata.block_tables[seq_id])
            allocated_kv_slots = (number_physical_blocks *
                                  self.model_runner.block_size)

            if required_num_kv_slots > allocated_kv_slots:
                request_id = seq_group_metadata.request_id
                raise ValueError(
                    "The worker attempted to run "
                    f"{num_steps} times but found insufficient KV space for "
                    f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
                    f"{required_num_kv_slots=}).")

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """MultiStepWorker does not yet implement support for cache swap
        operations or beam search.

        Raises:
            NotImplementedError: If any of blocks_to_swap_in,
                blocks_to_swap_out or blocks_to_copy is non-empty, or if any
                sequence group has other than exactly one sequence.
        """
        if any([
                execute_model_req.blocks_to_swap_in,
                execute_model_req.blocks_to_swap_out,
                execute_model_req.blocks_to_copy
        ]):
            raise NotImplementedError(
                "MultiStepWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "MultiStepWorker does not support beam search.")