# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import List, Optional

import torch

8
from vllm.forward_context import set_forward_context
9
from vllm.model_executor.layers.sampler import SamplerOutput
10
11

try:
    try:
        from vllm.attention.backends.flash_attn import FlashAttentionMetadata
    except (ModuleNotFoundError, ImportError):
        # vllm_flash_attn is not installed, try the ROCm FA metadata
        from vllm.attention.backends.rocm_flash_attn import (
            ROCmFlashAttentionMetadata as FlashAttentionMetadata)
except (ModuleNotFoundError, ImportError) as err:
    raise RuntimeError(
        "Draft model speculative decoding currently only supports "
        "CUDA and ROCm flash attention backend.") from err
22

23
from vllm.logger import init_logger
24
from vllm.multimodal import MultiModalKwargs
25
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
26
27
28
from vllm.worker.model_runner_base import (ModelRunnerBase,
                                           ModelRunnerInputBase,
                                           ModelRunnerWrapperBase)
29
30
31

logger = init_logger(__name__)

# Debug switch: when True, each GPU-advanced draft step logs (at debug level)
# the freshly built input tensors before the step runs.
debug_advance_input: bool = False

# Kill switch for the GPU advance step of the draft model runner. Flip it to
# False while debugging; supports_gpu_multi_step() then always reports False.
allow_gpu_advance_step: bool = True


class TP1DraftModelRunner(ModelRunnerWrapperBase):
    """Specialized model runner for the speculative decoding draft model.

    Since the draft model always executes k forward passes consecutively to
    generate k speculative tokens in a single speculative decoding step,
    we can get rid of most CPU-GPU synchronization and data transfer
    overheads by keeping model input and output tensors on the GPU all the
    time.

    TODOs:
    1. Currently supports only flash-attn, add support for other attn_backends.
    2. Support TP > 1 (this requires some designs because we do not expect
       any broadcasting inside execute_model).
    """

    def __init__(self, model_runner: ModelRunnerBase):
        super().__init__(model_runner)

        # Batch row indices of the sequences that carry a bonus token. Set via
        # set_indices_of_seq_with_bonus_tokens(); while not None, decode steps
        # in execute_model() rewrite the sampled token of every other row.
        self.indices_of_seq_with_bonus_tokens = None

    def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                                  num_queries):
        """Validate that sampling_metadata can be reused for the next step.

        Despite its name, this method does not modify ``sampling_metadata``.
        It only asserts the decode-only invariants that make reusing it
        across GPU-advanced steps valid:
        - there are no prompts;
        - exactly one token is selected per query;
        - query ``i`` samples exactly index ``i``.

        Args:
            sampling_metadata: Sampling metadata of the current step.
            num_seqs: Currently unused; kept for signature stability.
            num_queries: Number of queries in the batch.
        """
        assert sampling_metadata.num_prompts == 0
        assert len(sampling_metadata.seq_groups) == num_queries
        assert sampling_metadata.selected_token_indices.shape == (
            num_queries, )
        # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501

        # Verify that all sequences are decodes. The length assertion above
        # guarantees this walks exactly num_queries groups.
        for i, seq_group in enumerate(sampling_metadata.seq_groups):
            assert seq_group.is_prompt is False  # No prompt
            assert seq_group.prompt_logprob_indices == []  # No prompt
            assert seq_group.sample_indices == [i]  # Simple

    def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
                          last_output: SamplerOutput) -> ModelRunnerInputBase:
        """Build the input for the next decode step without leaving the GPU.

        Feeds the tokens sampled in ``last_output`` to
        ``attn_metadata.advance_step`` and wraps the result in a new model
        input. That new input reuses the same token/position tensors, the
        attention metadata and the sampling metadata.

        Args:
            model_input: Decode-only input of the step that just ran.
            last_output: Sampler output of that step; its
                ``sampled_token_ids`` must be populated.

        Returns:
            The model input for the next step.
        """
        # Currently, we expect "decode mode" only
        assert not model_input.is_prompt

        num_seqs = len(model_input.seq_lens)
        num_queries = len(model_input.query_lens)

        # Tokens sampled by the previous step (kept on the GPU).
        sampled_token_ids = last_output.sampled_token_ids
        assert sampled_token_ids is not None

        # Advance the attention metadata. The new input below reuses
        # model_input's input_tokens / input_positions tensors, so
        # advance_step is expected to update them in place.
        attn_metadata = model_input.attn_metadata
        assert isinstance(attn_metadata, FlashAttentionMetadata)
        attn_metadata.advance_step(model_input, sampled_token_ids,
                                   self.block_size, num_seqs, num_queries)

        # The sampling metadata is reused unchanged; only validate it.
        sampling_metadata = model_input.sampling_metadata
        self._update_sampling_metadata(sampling_metadata, num_seqs,
                                       num_queries)

        new_model_input = self._model_input_cls(
            input_tokens=model_input.input_tokens,
            input_positions=model_input.input_positions,
            attn_metadata=attn_metadata,
            seq_lens=attn_metadata.seq_lens,
            query_lens=model_input.query_lens,
            lora_mapping=model_input.lora_mapping,
            lora_requests=model_input.lora_requests,
            multi_modal_kwargs=model_input.multi_modal_kwargs,
            sampling_metadata=model_input.sampling_metadata,
            is_prompt=False,
        )

        # Ensure we skip CPU samples
        assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
        # We can reuse sampling tensors since every decode iteration is the same
        new_model_input.sampling_metadata.reuse_sampling_tensors = True

        if debug_advance_input:
            logger.debug("NEW INPUT: ")
            logger.debug("  input_tokens = %s", new_model_input.input_tokens)
            logger.debug("  input_positions = %s",
                         new_model_input.input_positions)
            # seq_lens / query_lens are sequences (len() is taken above), so
            # they must be formatted with %s; %d makes logging report a
            # formatting error instead of emitting the record.
            logger.debug("  seq_lens = %s", new_model_input.seq_lens)
            logger.debug("  query_lens = %s", new_model_input.query_lens)
            logger.debug("  attn_metadata:")
            logger.debug("    seq_lens_tensor: %s",
                         attn_metadata.seq_lens_tensor)
            logger.debug("    slot_mapping: %s", attn_metadata.slot_mapping)
            logger.debug("    block_tables: %s", attn_metadata.block_tables)

        return new_model_input

    def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
        """Determines if draft_model_runner GPU multi-step can be used.

        Currently required conditions are:
            0. The module-level allow_gpu_advance_step flag is True
            1. Only decodes
            2. Only flash-attn
            3. No LORA
            4. No prompt_adapter_config
        """
        if not allow_gpu_advance_step:
            return False

        # We allow multi-step GPU only in decode mode
        if any(seq_group.is_prompt
               for seq_group in execute_model_req.seq_group_metadata_list):
            return False

        # TODO: Add support for other attn backends
        if self.attn_backend.get_name() not in ("FLASH_ATTN", ):
            return False

        # TODO: Add support for LORA
        if self.lora_config:
            return False

        # TODO: Add soft-tuning prompt adapter support
        return not self.prompt_adapter_config

    def set_indices_of_seq_with_bonus_tokens(self,
                                             indices_of_seq_with_bonus_tokens):
        """Record which batch rows hold a bonus token.

        Passing None disables the sampled-token rewrite that execute_model
        performs on decode batches.
        """
        self.indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens

    def _overwrite_non_bonus_sampled_tokens(
            self, output: SamplerOutput,
            model_input: ModelRunnerInputBase) -> None:
        """Rewrite the sampled tokens of rows not listed as bonus-token rows.

        Walks the rows of ``output.sampled_token_ids`` alongside
        ``self.indices_of_seq_with_bonus_tokens`` (presumably sorted
        ascending; TODO confirm with the caller). Each row that is not the
        next listed index gets its sampled token replaced by
        ``model_input.input_tokens[bonus_seq_idx]``, where ``bonus_seq_idx``
        is that next listed index. Listed rows keep the token they sampled.

        If rows remain after every listed index has been consumed, the
        lookup indexes past the end of the list. The caller presumably
        guarantees that the final row is listed.
        """
        assert output.sampled_token_ids is not None
        # output.sampled_token_ids should be of shape (num_seqs, 1)
        num_seqs, num_tokens_per_seq = output.sampled_token_ids.shape
        assert num_tokens_per_seq == 1

        bonus_indices = self.indices_of_seq_with_bonus_tokens
        count = 0
        for i in range(num_seqs):
            bonus_seq_idx = bonus_indices[count]
            if i == bonus_seq_idx:
                count += 1
                continue
            # The following might cause a cpu->gpu sync
            # However, the performance impact is negligible as we
            # benchmarked on H100.
            output.sampled_token_ids[i, :] = (
                model_input.input_tokens[bonus_seq_idx])

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelRunnerInputBase,
        kv_caches: List[torch.Tensor],
        previous_hidden_states: Optional[torch.Tensor] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
        **kwargs,
    ) -> Optional[List[SamplerOutput]]:
        """Executes num_steps forward passes with advancement of input tensors
        on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.

        Optimizations used:
            1. Input tensors are updated on the GPU directly
            2. Skips GPU=>CPU serialization of sampler outputs (we don't need
                them since we do batch expansion later that uses GPU outputs)
            3. Reuses sampling tensors (since we run only decodes and they have
                a repeating sampling logic)

        Args:
            model_input: Input of the first step.
            kv_caches: Not read directly by this method.
            previous_hidden_states: If given, forwarded to the model as
                ``previous_hidden_states``. It is padded to the graph batch
                size when a CUDA graph is used. After the first step, the
                previous step's model output is forwarded instead.
            intermediate_tensors: Forwarded to the model on every step.
            num_steps: Number of forward passes. 1 selects the fallback path,
                whose inputs were prepared on the CPU.
            **kwargs: ``spec_step_idx`` is honoured for models whose config
                has ``num_nextn_predict_layers`` (DeepSeek MTP).

        Returns:
            One SamplerOutput per step, or an empty list on non-driver
            workers.

        Raises:
            ValueError: if num_steps > 1 is requested on a non-driver worker,
                with LoRA / prompt adapters / inputs_embeds /
                multi_modal_kwargs, or together with a prefill batch.
        """
        # When num_steps == 1, we execute the fallback here for the GPU
        # advance_step, which runs prepare_inputs on CPU and for each spec
        # iteration invokes this function only once
        # (Look at multi-step-worker code)
        is_fallback = num_steps == 1
        if not is_fallback:
            # Since we do not broadcast data inside execute_model anymore,
            # we need to figure out the best way to support TP > 1 in this
            # case, because we will at least need to broadcast the sampled
            # tokens to all workers.
            if not self.is_driver_worker:
                raise ValueError("TP1DraftModelRunner only supports TP=1.")

            # Sanity
            if self.lora_config is not None:
                raise ValueError("TP1DraftModelRunner has no support for LORA")
            if self.prompt_adapter_config is not None:
                raise ValueError("TP1DraftModelRunner has no support for "
                                 "prompt_adapter_config")
            if model_input.inputs_embeds is not None:
                raise ValueError("TP1DraftModelRunner has no support for "
                                 "inputs_embeds")
            if model_input.multi_modal_kwargs:
                raise ValueError(
                    "TP1DraftModelRunner has no support for multi_modal_kwargs"
                )
        else:
            if self.lora_config:
                assert model_input.lora_requests is not None
                assert model_input.lora_mapping is not None
                self.set_active_loras(model_input.lora_requests,
                                      model_input.lora_mapping)

            if self.prompt_adapter_config:
                assert model_input.prompt_adapter_requests is not None
                assert model_input.prompt_adapter_mapping is not None
                self.set_active_prompt_adapters(
                    model_input.prompt_adapter_requests,
                    model_input.prompt_adapter_mapping)

            self.attn_state.begin_forward(model_input)

        # Detect exec mode
        assert model_input.attn_metadata is not None
        use_cuda_graph = False
        if model_input.attn_metadata.num_prefills > 0:
            # In this case, execute_model(..) was called directly
            if num_steps > 1:
                raise ValueError(
                    "execute_model(..) of draft_model_runner can be called "
                    "directly only with a single-step prefill")
        else:
            # We can skip CPU samples for spec token generation.
            # (We do allow CPU samples for num_steps == 1 to support the
            # fallback case, where supports_gpu_multi_step(..) does not pass)
            model_input.sampling_metadata.skip_sampler_cpu_output = (
                not is_fallback)

            # Attn attr defines if we use cuda graphs
            use_cuda_graph = model_input.attn_metadata.use_cuda_graph

        # Get model
        if use_cuda_graph:
            # Captured graphs are looked up by (batch size, whether
            # inputs_embeds are fed instead of token ids).
            uses_inputs_embeds = model_input.inputs_embeds is not None
            graph_batch_size = (model_input.inputs_embeds.shape[0]
                                if uses_inputs_embeds else
                                model_input.input_tokens.shape[0])
            model_executable = self.graph_runners[model_input.virtual_engine][
                (graph_batch_size, uses_inputs_embeds)]

            if previous_hidden_states is not None:
                # The graph runs on graph_batch_size rows; pad the hidden
                # states with uninitialized rows to match.
                pad_rows = graph_batch_size - previous_hidden_states.shape[0]
                padding = torch.empty(
                    (pad_rows, *previous_hidden_states.shape[1:]),
                    dtype=previous_hidden_states.dtype,
                    device=previous_hidden_states.device)
                hidden_states = torch.cat([previous_hidden_states, padding])
            else:
                hidden_states = None
        else:
            model_executable = self.model
            hidden_states = previous_hidden_states

        outputs: List[SamplerOutput] = []
        for step in range(num_steps):
            multi_modal_kwargs = model_input.multi_modal_kwargs or {}

            # On step 0, hidden_states is the (possibly padded)
            # previous_hidden_states. On later steps it is the previous
            # step's model output.
            model_execute_kwargs = {}
            if previous_hidden_states is not None:
                model_execute_kwargs["previous_hidden_states"] = hidden_states

            compute_logits_kwargs = {}
            if hasattr(self.model.config, "num_nextn_predict_layers"):
                # for DeepSeek MTP only to use the corresponding layer for
                # each step
                spec_step_idx = kwargs.get("spec_step_idx", step)
                model_execute_kwargs["spec_step_idx"] = spec_step_idx
                compute_logits_kwargs["spec_step_idx"] = spec_step_idx

            # Run model
            with set_forward_context(model_input.attn_metadata,
                                     self.vllm_config):
                # NOTE(review): inputs_embeds is always passed as None, even
                # on the fallback path, where model_input.inputs_embeds may
                # be set and the (..., True) graph was selected above.
                # Confirm this is intentional.
                hidden_states = model_executable(
                    input_ids=model_input.input_tokens,
                    inputs_embeds=None,
                    positions=model_input.input_positions,
                    intermediate_tensors=intermediate_tensors,
                    **MultiModalKwargs.as_kwargs(
                        multi_modal_kwargs,
                        device=self.device,
                    ),
                    **model_execute_kwargs,
                )

            # Compute the logits.
            logits = self.model.compute_logits(hidden_states,
                                               model_input.sampling_metadata,
                                               **compute_logits_kwargs)
            if not self.is_driver_worker:
                return []

            # Sample the next token.
            output = self.model_runner.sampler(
                logits=logits,
                sampling_metadata=model_input.sampling_metadata,
            )
            outputs.append(output)

            if self.return_hidden_states and is_fallback:
                if use_cuda_graph:
                    # Trim to the number of selected tokens; the graph may
                    # have run on extra padding rows.
                    sampling_metadata = model_input.sampling_metadata
                    indices = sampling_metadata.selected_token_indices
                    output.hidden_states = hidden_states[:len(indices)]
                else:
                    output.hidden_states = hidden_states

            if (model_input.attn_metadata.num_prefills == 0
                    and self.indices_of_seq_with_bonus_tokens is not None):
                self._overwrite_non_bonus_sampled_tokens(output, model_input)

            # Prepare inputs for the next step
            if step != num_steps - 1:
                model_input = self._gpu_advance_step(model_input, outputs[-1])

        return outputs