# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import vllm.envs as envs
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
                         get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import (DPMetadata, get_warming_up,
                                  set_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils import round_up
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.mla.common import (MLACommonDecodeMetadata,
                                                   MLACommonMetadata)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel

logger = init_logger(__name__)

# Sentinel slot id written into a slot mapping for tokens whose KV entries
# must not be stored; EagleProposer.propose uses it for draft positions that
# run past max_model_len.
PADDING_SLOT_ID = -1

class EagleProposer:
    """Proposes draft tokens with an EAGLE / EAGLE-3 / DeepSeek-MTP head.

    The proposer owns the draft model (set up by ``load_model``) and a set of
    persistent input buffers (``input_ids``, ``positions``, ``hidden_states``)
    so that draft forward passes can run through CUDA graphs.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method

        # Model runner whose attention-metadata builders and persistent
        # buffers are reused by the "deepseek_mtp" path and by dummy_run().
        self.runner = runner

        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.block_size = vllm_config.cache_config.block_size
        self.num_speculative_tokens = (
            self.speculative_config.num_speculative_tokens)
        self.max_num_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)
        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()

        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                               == CompilationLevel.PIECEWISE and
                               not self.vllm_config.model_config.enforce_eager)
        self.use_full_cuda_graph = (
            self.use_cuda_graph
            and vllm_config.compilation_config.full_cuda_graph)
        # Reversed capture sizes; _fits_in_cudagraph() compares against the
        # last element as the largest captured size.
        self.cudagraph_batch_sizes = list(
            reversed(
                self.vllm_config.compilation_config.cudagraph_capture_sizes))

        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=device)
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=device)

        # attention metadata captured in full cudagraph mode
        self.attn_metadata_cudagraph = None

        # We need +1 here because the arange is used to set query_start_loc,
        # which has one more element than batch_size.
        self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
                                   1,
                                   device=device,
                                   dtype=torch.int32)

        parallel_config = vllm_config.parallel_config
        self.dp_size = parallel_config.data_parallel_size
        self.enable_expert_parallel = parallel_config.enable_expert_parallel
        self.enable_dp_attention = parallel_config.enable_dp_attention
        self.attn_tp_size = parallel_config.tensor_parallel_size

        # Expert parallel combined with both DP > 1 and TP > 1; dummy_run()
        # then rounds its token count up to a multiple of attn_tp_size.
        self.ep_sp = bool(self.enable_expert_parallel and self.dp_size > 1
                          and self.attn_tp_size > 1)

    def _fits_in_cudagraph(self, num_tokens: int) -> bool:
        """Return True if ``num_tokens`` is within the largest capture size.

        An empty capture-size list means no CUDA graph can be used; the
        previous inline ``self.cudagraph_batch_sizes[-1]`` check raised
        IndexError in that case.
        """
        return (bool(self.cudagraph_batch_sizes)
                and num_tokens <= self.cudagraph_batch_sizes[-1])

    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [num_tokens]
        target_slot_mapping: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        # [batch_size + 1] starting with 0
        cu_num_tokens: torch.Tensor,
        # [batch_size, max_num_blocks_per_req]
        block_table: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        decoding: bool = False,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Generate ``num_speculative_tokens`` draft tokens per request.

        The first draft step runs over all target tokens, with the input ids
        shifted by one and each request's last slot replaced by its sampled
        next token. Every following step feeds one token per request.

        Args:
            sampling_metadata: currently unused; draft tokens are argmax.
            decoding: enables copying metadata into the captured full-CUDA-
                graph metadata for the first step, and is forwarded as
                ``spec_layer_decoding`` on the "deepseek_mtp" path.

        Returns:
            Draft token ids of shape ``[batch_size, num_speculative_tokens]``.
            When ``envs.VLLM_REJECT_SAMPLE_OPT`` is set, the result is the
            tuple ``(draft_token_ids, draft_probs)``, where ``draft_probs``
            has shape ``[batch_size, num_speculative_tokens, vocab]`` (float32
            softmax of the draft logits).

        Raises:
            ValueError: if ``self.method`` is not a supported method.
        """
        # Read the flag once so every branch below sees a consistent value.
        reject_sample_opt = envs.VLLM_REJECT_SAMPLE_OPT

        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]
        last_token_indices = cu_num_tokens[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states)
            assert target_hidden_states.shape[-1] == self.hidden_size

        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        # FA requires seq_len to have dtype int32.
        seq_lens = (target_positions[last_token_indices] + 1).int()

        if self.method in ["eagle", "eagle3"]:
            # FIXME(woosuk): The below two ops cause synchronization. Optimize.
            max_seq_len = seq_lens.max().item()
            max_num_tokens = (cu_num_tokens[1:] -
                              cu_num_tokens[:-1]).max().item()
            attn_metadata = FlashAttentionMetadata(
                num_actual_tokens=num_tokens,
                max_query_len=max_num_tokens,
                query_start_loc=cu_num_tokens,
                max_seq_len=max_seq_len,
                seq_lens=seq_lens,
                block_table=block_table,
                slot_mapping=target_slot_mapping,
                # TODO(woosuk): Support cascade attention.
                use_cascade=False,
                common_prefix_len=0,
                cu_prefix_query_lens=None,
                prefix_kv_lens=None,
                suffix_kv_lens=None,
            )
        elif self.method == "deepseek_mtp":
            query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
            max_query_len = query_lens.max().item()

            common_attn_metadata = CommonAttentionMetadata(
                query_start_loc=cu_num_tokens,
                seq_lens=seq_lens,
                num_reqs=batch_size,
                num_actual_tokens=num_tokens,
                max_query_len=max_query_len,
                slot_mapping=target_slot_mapping,
                spec_layer_decoding=decoding,
            )

            assert self.runner is not None

            # FIXME: need to consider multiple kv_cache_groups
            attn_metadata = self.runner.attn_metadata_builders[0].build(
                common_prefix_len=0,
                common_attn_metadata=common_attn_metadata,
            )
        else:
            raise ValueError(f"Unsupported method: {self.method}")

        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata. Every entry
        # aliases the same object, so later in-place updates to attn_metadata
        # are seen by every layer.
        per_layer_attn_metadata = {
            layer_name: attn_metadata
            for layer_name in self.attn_layer_names
        }

        if self.use_cuda_graph and self._fits_in_cudagraph(num_tokens):
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
        else:
            num_input_tokens = num_tokens

        if self.enable_dp_attention:
            num_input_tokens = round_up(num_input_tokens, self.attn_tp_size)
        # NOTE(review): get_dp_padding() is not applied to this first draft
        # pass; only the round-up to attn_tp_size above is. Confirm intended.

        # copy inputs to buffer for cudagraph
        self.positions[:num_tokens] = target_positions
        self.hidden_states[:num_tokens] = target_hidden_states

        if (decoding and self.use_full_cuda_graph
                and self._fits_in_cudagraph(num_tokens)):
            assert self.attn_metadata_cudagraph is not None
            graph_md = self.attn_metadata_cudagraph
            if self.method in ["eagle", "eagle3"]:
                graph_md.seq_lens[:batch_size] = attn_metadata.seq_lens
                graph_md.slot_mapping[:num_tokens] = (
                    attn_metadata.slot_mapping)
                graph_md.query_start_loc[:batch_size + 1] = (
                    attn_metadata.query_start_loc)
                graph_md.block_table[:batch_size] = attn_metadata.block_table
            elif self.method == "deepseek_mtp":
                graph_md.num_actual_tokens = attn_metadata.num_actual_tokens
                graph_md.query_start_loc[:batch_size + 1] = (
                    attn_metadata.query_start_loc)
                graph_md.slot_mapping[:num_tokens] = (
                    attn_metadata.slot_mapping)
                graph_md.num_decodes = attn_metadata.num_decodes
                graph_md.num_decode_tokens = attn_metadata.num_decode_tokens
                graph_md.num_prefills = attn_metadata.num_prefills

                if attn_metadata.decode is not None:
                    num_decode_tokens = attn_metadata.num_decode_tokens
                    graph_md.decode.block_table[:num_decode_tokens] = (
                        attn_metadata.decode.block_table)
                    graph_md.decode.seq_lens[:num_decode_tokens] = (
                        attn_metadata.decode.seq_lens)

        with set_forward_context(per_layer_attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_input_tokens):
            ret_hidden_states = self.model(
                self.input_ids[:num_input_tokens],
                self.positions[:num_input_tokens],
                self.hidden_states[:num_input_tokens],
            )
            # deepseek_mtp returns one tensor; eagle/eagle3 return a pair.
            if self.method == "deepseek_mtp":
                last_hidden_states = ret_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)
        draft_token_ids = logits.argmax(dim=-1)

        if reject_sample_opt:
            draft_prob = logits.softmax(dim=-1, dtype=torch.float32)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            # [batch_size, 1]
            if reject_sample_opt:
                return draft_token_ids.view(-1, 1), draft_prob.view(
                    -1, 1, logits.shape[-1])
            return draft_token_ids.view(-1, 1)

        if reject_sample_opt:
            draft_probs_list = [draft_prob]

        # TODO: Currently, MTP module released by deepseek only has
        # one layer. Adapt this code to support multiple layers once
        # there's a multi-layer MTP module.

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        # Advanced indexing yields a copy, so the in-place `+= 1` below does
        # not touch target_positions.
        positions = target_positions[last_token_indices]
        if self.method == "deepseek_mtp":
            hidden_states = last_hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]

        if self.use_cuda_graph and self._fits_in_cudagraph(batch_size):
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
        else:
            input_batch_size = batch_size

        # dp attention need all dp rank process same number tokens
        if self.enable_dp_attention:
            input_batch_size = round_up(input_batch_size, self.attn_tp_size)
            num_pad, _ = self.get_dp_padding(input_batch_size)
            input_batch_size += num_pad

        # From here on every request contributes exactly one query token.
        attn_metadata.num_actual_tokens = batch_size
        attn_metadata.max_query_len = 1
        attn_metadata.query_start_loc = self.arange[:batch_size + 1]

        if isinstance(attn_metadata, MLACommonMetadata):
            assert self.runner is not None
            attn_metadata.num_decodes = batch_size
            attn_metadata.num_decode_tokens = batch_size
            attn_metadata.num_prefills = 0
            builder = self.runner.attn_metadata_builders[0]
            # Rebinds the local block_table; the slot-mapping computation in
            # the loop below gathers from this runner block table.
            block_table = builder.block_table.get_device_tensor()[
                :batch_size, ...]
            attn_metadata.decode = builder._build_decode(
                block_table_tensor=block_table,
                seq_lens=seq_lens,
            )

        for i in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            positions += 1

            # NOTE(woosuk): We should handle the case where the draft model
            # generates tokens beyond the max model length. Since it is complex
            # to remove such requests from the batch, we keep them in the batch
            # but adjust the position ids and slot mappings to avoid the
            # out-of-range access during the model execution. The draft tokens
            # generated with this adjustment should be ignored.
            exceeds_max_model_len = positions >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            clamped_positions = torch.where(exceeds_max_model_len, 0,
                                            positions)

            if isinstance(attn_metadata, MLACommonMetadata):
                # NOTE(review): unlike the branch below, requests past
                # max_model_len are not masked here — confirm intended.
                attn_metadata.decode.seq_lens += 1
            else:
                # Increment the sequence lengths.
                attn_metadata.seq_lens += 1
                attn_metadata.max_seq_len += 1
                # Consider max model length.
                attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
                                                self.max_model_len)
                # For the requests that exceed the max model length, we set the
                # sequence length to 1 to minimize their overheads in attention.
                attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            block_numbers = clamped_positions // self.block_size
            block_ids = block_table.gather(dim=1,
                                           index=block_numbers.view(-1, 1))
            block_ids = block_ids.view(-1)
            attn_metadata.slot_mapping = (block_ids * self.block_size +
                                          clamped_positions % self.block_size)
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
                                                    PADDING_SLOT_ID)

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self.positions[:batch_size] = clamped_positions
            self.hidden_states[:batch_size] = hidden_states

            if self.use_full_cuda_graph and self._fits_in_cudagraph(
                    batch_size):
                assert self.attn_metadata_cudagraph is not None
                graph_md = self.attn_metadata_cudagraph
                if self.method in ["eagle", "eagle3"]:
                    graph_md.seq_lens[:batch_size] = attn_metadata.seq_lens
                    graph_md.slot_mapping[:batch_size] = (
                        attn_metadata.slot_mapping)
                    # query_start_loc / block_table do not change across the
                    # single-token steps, so they are copied once.
                    if i == 0:
                        graph_md.query_start_loc[:batch_size + 1] = (
                            attn_metadata.query_start_loc)
                        graph_md.block_table[:batch_size] = (
                            attn_metadata.block_table)
                elif self.method == "deepseek_mtp":
                    num_decode_tokens = attn_metadata.num_decode_tokens
                    graph_md.num_actual_tokens = (
                        attn_metadata.num_actual_tokens)
                    graph_md.slot_mapping[:num_decode_tokens] = (
                        attn_metadata.slot_mapping)
                    graph_md.num_decodes = attn_metadata.num_decodes
                    graph_md.num_decode_tokens = num_decode_tokens
                    graph_md.num_prefills = attn_metadata.num_prefills
                    graph_md.decode.seq_lens[:num_decode_tokens] = (
                        attn_metadata.decode.seq_lens)
                    if i == 0:
                        graph_md.query_start_loc[:batch_size + 1] = (
                            attn_metadata.query_start_loc)
                        graph_md.decode.block_table[:num_decode_tokens] = (
                            attn_metadata.decode.block_table)

            # Run the model.
            with set_forward_context(per_layer_attn_metadata,
                                     self.vllm_config,
                                     num_tokens=input_batch_size):
                ret_hidden_states = self.model(
                    self.input_ids[:input_batch_size],
                    self.positions[:input_batch_size],
                    self.hidden_states[:input_batch_size],
                )
                if self.method == "deepseek_mtp":
                    last_hidden_states = ret_hidden_states
                    hidden_states = last_hidden_states[:batch_size]
                else:
                    last_hidden_states, hidden_states = ret_hidden_states
                    hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size],
                                               None)

            # TODO(wenlong): get more than one token for tree attention
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

            if reject_sample_opt:
                draft_probs_list.append(
                    logits.softmax(dim=-1, dtype=torch.float32))

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)

        if reject_sample_opt:
            draft_probs = torch.stack(draft_probs_list, dim=1).contiguous()
            return draft_token_ids, draft_probs
        return draft_token_ids

    @staticmethod
    def prepare_inputs(
        # [batch_size + 1]
        cu_target_query_lens: torch.Tensor,
        # [batch_size]
        num_accepted_tokens_tensor: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Select one token per request for the next draft pass.

        Args:
            cu_target_query_lens: cumulative query lengths ``[0, a, a+b, ...]``.
            num_accepted_tokens_tensor: per-request offset into that request's
                query span.

        Returns:
            ``(cu_num_tokens, token_indices)``. ``cu_num_tokens`` is
            ``[0, 1, ..., batch_size]`` (int32), because each request keeps
            exactly one token. ``token_indices[r]`` is
            ``cu_target_query_lens[r] + num_accepted_tokens_tensor[r]``.
        """
        cu_num_tokens = torch.arange(cu_target_query_lens.shape[0],
                                     device=cu_target_query_lens.device,
                                     dtype=torch.int32)
        token_indices = num_accepted_tokens_tensor + cu_target_query_lens[:-1]
        return cu_num_tokens, token_indices

    def load_model(self, target_model: nn.Module) -> None:
        """Load the draft model and share weights with ``target_model``.

        Also records in ``self.attn_layer_names`` the attention layers that
        appear only after the draft model is loaded.
        """
        draft_model_config = \
            self.vllm_config.speculative_config.draft_model_config
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, Attention).keys())

        from vllm.compilation.backends import set_model_tag
        with set_model_tag("eagle_head"):
            self.model = get_model(vllm_config=self.vllm_config,
                                   model_config=draft_model_config)

        # Draft layers = every attention layer now registered minus the
        # target model's layers recorded above.
        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
            target_attn_layer_names)
        self.attn_layer_names = list(draft_attn_layer_names)

        if supports_multimodal(target_model):
            # handle multimodality
            self.model.config.image_token_index = (
                target_model.config.image_token_index)
            target_language_model = target_model.get_language_model()
        else:
            target_language_model = target_model

        # share embed_tokens with the target model if needed
        if (get_pp_group().world_size == 1
                and self.method != "deepseek_mtp"
                and self.model.model.embed_tokens.weight.shape
                == target_language_model.model.embed_tokens.weight.shape):
            logger.info(
                "Assuming the EAGLE head shares the same vocab embedding"
                " with the target model.")
            del self.model.model.embed_tokens
            self.model.model.embed_tokens = (
                target_language_model.model.embed_tokens)
        else:
            logger.info(
                "The EAGLE head's vocab embedding will be loaded separately"
                " from the target model.")

        # share lm_head with the target model if needed
        # some model definition do not define lm_head explicitly
        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
        if (self.vllm_config.speculative_config.method != "eagle3"
                and hasattr(target_language_model, "lm_head")):
            logger.info("Loading EAGLE LM head weights from the target model.")
            self.model.lm_head = target_language_model.lm_head

    def get_dp_padding(self,
                       num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
        """Return ``(num_pad, num_tokens_after_padding)`` for data parallel.

        ``num_pad`` brings ``num_tokens`` up to the maximum token count across
        DP ranks. ``num_tokens_after_padding`` is a CPU int32 tensor holding
        that maximum once per rank. The result is ``(0, None)`` when DP is
        disabled, when enforce_eager is set, or when the DP metadata query
        fails.
        """
        dp_size = self.vllm_config.parallel_config.data_parallel_size
        dp_rank = self.vllm_config.parallel_config.data_parallel_rank

        # For DP: Don't pad when setting enforce_eager.
        # This lets us set enforce_eager on the prefiller in a P/D setup and
        # still use CUDA graphs (enabled by this padding) on the decoder.
        #
        # TODO(tms) : There are many cases where padding is enabled for
        # prefills, causing unnecessary and excessive padding of activations.
        if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
            # Early exit.
            return 0, None

        try:
            num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
                num_tokens, dp_size, dp_rank)
            max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
            num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
                                                    dp_size,
                                                    device="cpu",
                                                    dtype=torch.int32)
            return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
        except (RuntimeError, AttributeError) as e:
            # DP group may not be initialized yet during dummy run
            # Skip padding in this case
            logger.debug(
                "Skipping DP padding in eagle get_dp_padding due to: %s", e)
            return 0, None

    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        attn_metadata: Optional[dict[str, Any]] = None,
        num_tokens_across_dp: Optional[torch.Tensor] = None,
    ) -> None:
        """Run the draft model on the persistent buffers without real inputs.

        The first ``attn_metadata`` dict seen is remembered: its entry for the
        first draft layer becomes ``self.attn_metadata_cudagraph``. With
        DP > 1, expert parallelism or DP attention enabled, and more than one
        speculative token, additional single-step forwards follow. They are
        presumably there to match the single-token steps of propose() across
        DP ranks. NOTE(review): confirm.
        """
        if attn_metadata is not None and self.attn_metadata_cudagraph is None:
            self.attn_metadata_cudagraph = attn_metadata[
                self.attn_layer_names[0]]

        # NOTE(review): DP padding (get_dp_padding) is currently not applied
        # to this first dummy pass.
        num_input_tokens = num_tokens

        with set_forward_context(attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_tokens,
                                 num_tokens_across_dp=num_tokens_across_dp):
            self.model(
                self.input_ids[:num_input_tokens],
                self.positions[:num_input_tokens],
                self.hidden_states[:num_input_tokens],
            )

        if not (self.dp_size > 1
                and (self.enable_expert_parallel or self.enable_dp_attention)
                and self.num_speculative_tokens > 1):
            return

        # Single-token steps: one token per (dummy) request.
        num_tokens = 1
        if self.enable_dp_attention or self.ep_sp:
            num_tokens = round_up(num_tokens, self.attn_tp_size)

        # dp attention need all dp rank process same number tokens
        if self.enable_dp_attention:
            num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
            num_tokens += num_pad

        if get_warming_up():
            return

        # Check before the runner's buffers are dereferenced below.
        assert self.runner is not None
        common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=self.runner.query_start_loc[:num_tokens + 1],
            seq_lens=self.runner.seq_lens[:num_tokens],
            num_reqs=num_tokens,
            num_actual_tokens=num_tokens,
            max_query_len=num_tokens,
            slot_mapping=self.runner.slot_mapping[:num_tokens],
            spec_layer_decoding=True,
        )

        # FIXME: need to consider multiple kv_cache_groups
        attn_metadata = self.runner.attn_metadata_builders[
            0].build_for_cudagraph_capture(
                common_attn_metadata=common_attn_metadata)

        # Copy into the captured metadata once, before the first step (this
        # used to run inside the loop guarded by `i == 0`; the loop always
        # runs at least once here because num_speculative_tokens > 1).
        if self.attn_metadata_cudagraph is not None:
            graph_md = self.attn_metadata_cudagraph
            graph_md.num_actual_tokens = num_tokens
            graph_md.num_decodes = num_tokens
            graph_md.num_decode_tokens = num_tokens
            graph_md.slot_mapping[:num_tokens] = attn_metadata.slot_mapping
            graph_md.decode.seq_lens[:num_tokens] = (
                attn_metadata.decode.seq_lens)
            graph_md.query_start_loc[:num_tokens + 1] = (
                attn_metadata.query_start_loc)
            graph_md.decode.block_table[:num_tokens] = (
                attn_metadata.decode.block_table)

        for _ in range(self.num_speculative_tokens - 1):
            with set_forward_context(attn_metadata,
                                     self.vllm_config,
                                     num_tokens=num_tokens,
                                     num_tokens_across_dp=num_tokens_across_dp):
                self.model(
                    self.input_ids[:num_tokens],
                    self.positions[:num_tokens],
                    self.hidden_states[:num_tokens],
                )

    def validate_same_kv_cache_group(self,
                                     kv_cache_config: KVCacheConfig) -> None:
        """
        Validate that all eagle layers belong to the same KVCacheGroup.
        Need this assumption to ensure all eagle layers can use the
        same AttentionMetadata.
        May extend to multiple AttentionMetadata in the future.
        """
        layer_to_group: dict[str, int] = {}
        for group_id, kv_cache_group in enumerate(
                kv_cache_config.kv_cache_groups):
            for layer_name in kv_cache_group.layer_names:
                layer_to_group[layer_name] = group_id
        eagle_groups = {
            layer_to_group[layer_name] for layer_name in self.attn_layer_names
        }
        assert len(eagle_groups) == 1, (
            "All eagle layers should belong to the same kv cache group")
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one draft token per row of ``logits`` and return its probs.

    Args:
        logits: draft logits with one row per request. On the non-all-greedy
            path this tensor is divided **in place** by the per-request
            temperature.
        sampling_metadata: reads ``all_greedy``, ``all_random`` and
            ``temperature``; a temperature of ``-1`` marks a greedy request.

    Returns:
        ``(next_token_ids, probs)``. On the all-greedy fast path ``probs`` is
        the ``logits`` tensor itself, because draft probs are not used in
        rejection sampling for greedy requests. Otherwise it is the float32
        softmax of the temperature-scaled logits.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        probs = logits
        next_token_ids = logits.argmax(dim=-1)
        return next_token_ids, probs

    is_greedy = sampling_metadata.temperature == -1
    # Greedy rows get temperature 1.0 so the division below is a no-op.
    temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    q = torch.empty_like(probs)
    q.exponential_()
    # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
    # will be used later for rejection sampling.
    next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
    if not sampling_metadata.all_random:
        greedy_token_ids = probs.argmax(dim=-1)
        next_token_ids = torch.where(
            is_greedy,
            greedy_token_ids,
            next_token_ids,
        )
    return next_token_ids, probs