eagle.py 22.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
from typing import Optional

5
import numpy as np
6
7
8
import torch
import torch.nn as nn

9
10
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
11
                         get_layers_from_vllm_config)
12
from vllm.distributed.parallel_state import get_pp_group
13
from vllm.forward_context import set_forward_context
14
from vllm.logger import init_logger
15
from vllm.model_executor.model_loader import get_model
16
from vllm.model_executor.models import supports_multimodal
17
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
18
from vllm.utils import is_pin_memory_available
19
20
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
21
from vllm.v1.kv_cache_interface import KVCacheConfig
22
23
from vllm.v1.sample.metadata import SamplingMetadata

24
25
logger = init_logger(__name__)

26
27
PADDING_SLOT_ID = -1

28
29
30
31
32
33
34

class EagleProposer:
    """Proposes draft tokens for speculative decoding with a draft model.

    The draft model variant is selected by ``speculative_config.method``;
    this class special-cases ``"eagle3"`` (hidden states are combined before
    use) and ``"deepseek_mtp"`` (the model returns a single tensor instead
    of a ``(last_hidden_states, hidden_states)`` pair).

    ``load_model`` must be called before ``propose`` or ``dummy_run``: it is
    the method that sets ``self.model`` and ``self.attn_layer_names``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        """Cache configuration values and allocate persistent buffers.

        Args:
            vllm_config: Global config; the speculative, model, cache,
                scheduler and compilation sub-configs are read here.
            device: Device on which the persistent input buffers live.
            runner: Object exposing ``attn_metadata_builders``. Optional at
                construction time, but ``propose`` asserts it is set.
        """
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method

        self.runner = runner
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.block_size = vllm_config.cache_config.block_size
        self.num_speculative_tokens = (
            self.speculative_config.num_speculative_tokens)
        self.max_num_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)
        # Reused by prepare_inputs() to avoid re-creating an arange per call.
        self.token_arange_np = np.arange(self.max_num_tokens)
        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()

        self.is_multimodal_model = vllm_config.model_config \
            .is_multimodal_model

        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                               == CompilationLevel.PIECEWISE and
                               not self.vllm_config.model_config.enforce_eager)
        # Stored in reverse of the configured order; propose() compares
        # against the last element to decide whether to pad for cudagraph.
        self.cudagraph_batch_sizes = list(
            reversed(
                self.vllm_config.compilation_config.cudagraph_capture_sizes))

        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=device)
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=device)
        # We need +1 here because the arange is used to set query_start_loc,
        # which has one more element than batch_size.
        self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
                                   1,
                                   device=device,
                                   dtype=torch.int32)

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=device)

    def _split_model_output(self, model_output):
        """Normalize the draft model output to a two-tensor pair.

        Returns:
            ``(last_hidden_states, hidden_states)``. For ``"deepseek_mtp"``
            the model returns a single tensor, which is used for both roles;
            every other method is expected to return the pair directly.
        """
        if self.method == "deepseek_mtp":
            return model_output, model_output
        last_hidden_states, hidden_states = model_output
        return last_hidden_states, hidden_states

    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embeds: Optional[list[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Generate ``num_speculative_tokens`` draft tokens per request.

        Draft tokens are always chosen by argmax over the draft logits;
        ``sampling_metadata`` is accepted but not read by this method.

        Args:
            target_token_ids: Token ids fed to the target model.
            target_positions: Positions of those tokens.
            target_hidden_states: Target hidden states for those tokens.
            next_token_ids: One newly sampled token per request.
            common_attn_metadata: Metadata used to build the attention
                metadata for the draft forward pass.
            sampling_metadata: Unused here.
            mm_embeds: Optional multimodal embeddings, only used when the
                model is multimodal.

        Returns:
            Tensor of shape ``[batch_size, num_speculative_tokens]``.
        """
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]
        last_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states)
            assert target_hidden_states.shape[-1] == self.hidden_size

        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        assert self.runner is not None

        # FIXME: need to consider multiple kv_cache_groups
        attn_metadata = self.runner.attn_metadata_builders[0].build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
            fast_build=True,
        )

        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {
            layer_name: attn_metadata
            for layer_name in self.attn_layer_names
        }
        if self.use_cuda_graph and \
            num_tokens <= self.cudagraph_batch_sizes[-1]:
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
        else:
            num_input_tokens = num_tokens
        # copy inputs to buffer for cudagraph
        self.positions[:num_tokens] = target_positions
        self.hidden_states[:num_tokens] = target_hidden_states
        if self.is_multimodal_model:
            input_ids = self.input_ids[:num_tokens]
            inputs_embeds = self.model.get_input_embeddings(
                input_ids,
                multimodal_embeddings=mm_embeds or None,
            )
            self.inputs_embeds[:num_tokens] = inputs_embeds
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            input_ids = None
        else:
            inputs_embeds = None
            input_ids = self.input_ids[:num_input_tokens]

        with set_forward_context(per_layer_attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_input_tokens):
            ret_hidden_states = self.model(
                input_ids=input_ids,
                positions=self.positions[:num_input_tokens],
                hidden_states=self.hidden_states[:num_input_tokens],
                inputs_embeds=inputs_embeds,
            )
            # Always bind `hidden_states`, including for deepseek_mtp, so the
            # multi-token path below never hits an unbound local.
            last_hidden_states, hidden_states = self._split_model_output(
                ret_hidden_states)
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)
        draft_token_ids = logits.argmax(dim=-1)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            # [batch_size, 1]
            return draft_token_ids.view(-1, 1)

        # TODO: Currently, MTP module released by deepseek only has
        # one layer. Adapt this code to support multiple layers once
        # there's a multi-layer MTP module.

        # Currently FlashAttention is the only backend that supports
        # multi-token eagle spec decode. This is because the code below
        # makes assumptions about attn_metadata attributes available.
        assert isinstance(attn_metadata, FlashAttentionMetadata)

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        # Indexing with a tensor yields a copy, so the in-place `+= 1` in the
        # loop does not modify `target_positions`.
        positions = target_positions[last_token_indices]
        hidden_states = hidden_states[last_token_indices]
        if self.use_cuda_graph and \
            batch_size <= self.cudagraph_batch_sizes[-1]:
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
        else:
            input_batch_size = batch_size
        # From here on every request contributes exactly one query token.
        attn_metadata.num_actual_tokens = batch_size
        attn_metadata.max_query_len = 1
        attn_metadata.query_start_loc = self.arange[:batch_size + 1]
        for _ in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            positions += 1

            # NOTE(woosuk): We should handle the case where the draft model
            # generates tokens beyond the max model length. Since it is complex
            # to remove such requests from the batch, we keep them in the batch
            # but adjust the position ids and slot mappings to avoid the
            # out-of-range access during the model execution. The draft tokens
            # generated with this adjustment should be ignored.
            exceeds_max_model_len = positions >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            clamped_positions = torch.where(exceeds_max_model_len, 0,
                                            positions)

            # Increment the sequence lengths.
            attn_metadata.max_seq_len += 1
            attn_metadata.seq_lens += 1
            # Consider max model length.
            attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
                                            self.max_model_len)
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            block_numbers = clamped_positions // self.block_size
            block_ids = attn_metadata.block_table.gather(
                dim=1, index=block_numbers.view(-1, 1))
            block_ids = block_ids.view(-1)
            attn_metadata.slot_mapping = (block_ids * self.block_size +
                                          clamped_positions % self.block_size)
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
                                                    PADDING_SLOT_ID)

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self.positions[:batch_size] = clamped_positions
            self.hidden_states[:batch_size] = hidden_states
            if self.is_multimodal_model:
                inputs_embeds = self.model.get_input_embeddings(input_ids)
                self.inputs_embeds[:batch_size] = inputs_embeds
                inputs_embeds = self.inputs_embeds[:input_batch_size]
                input_ids = None
            else:
                inputs_embeds = None
                input_ids = self.input_ids[:input_batch_size]

            # Run the model.
            with set_forward_context(per_layer_attn_metadata,
                                     self.vllm_config,
                                     num_tokens=input_batch_size):
                ret_hidden_states = self.model(
                    input_ids=input_ids,
                    positions=self.positions[:input_batch_size],
                    hidden_states=self.hidden_states[:input_batch_size],
                    inputs_embeds=inputs_embeds,
                )
            # Same normalization as the first step: deepseek_mtp returns a
            # single tensor, which cannot be tuple-unpacked.
            last_hidden_states, hidden_states = self._split_model_output(
                ret_hidden_states)
            # Drop the rows that were only added as cudagraph padding.
            hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size],
                                               None)

            # TODO(wenlong): get more than one token for tree attention
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids

    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        # [batch_size]
        num_rejected_tokens: torch.Tensor
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        This function is used to prepare the inputs for the spec decode.
        It builds a new CommonAttentionMetadata that accounts for the
        rejected tokens. It also returns the token indices of the tokens
        that should be fed to the speculator.

        `num_rejected_tokens` is combined with the `*_cpu` tensors and the
        result is converted with `.numpy()`, so it must be a CPU tensor.
        """
        # E.g.
        #  common_attn_metadata.query_start_loc{_cpu}:
        #         [0, q1, q1 + q2, q1 + q2 + q3]
        #  common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
        #  num_rejected_tokens: [n1, n2, n3]
        # This function computes the intermediate values:
        #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        # And returns:
        #  common_attn_metadata.query_start_loc{_cpu}:
        #         [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        #  common_attn_metadata.seq_lens{_cpu}:
        #         [s1 - n1, s2 - n2, s3 - n3]
        #  token_indices: [0, 1, ..., q1 - n1 - 1,
        #                  q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                  q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

        device = common_attn_metadata.query_start_loc.device
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \
            - num_rejected_tokens

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        new_query_len_per_req = (query_start_loc_cpu[1:] -
                                 query_start_loc_cpu[:-1])
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available())
        # The numpy array shares memory with the tensor, so the cumsum below
        # fills `new_query_start_loc_cpu` in place (element 0 stays 0).
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        total_num_tokens = new_query_start_loc_np[-1]
        # Example assuming num_tokens_per_req_np = [2, 4, 3]
        # this implies that `new_query_start_locs` is:
        # [0, 2, 6, 9] ->
        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #  _r1_  ____r2____  ___r3__
        new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
                                                  new_num_tokens_per_req_np)
        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
        #  _r1_  ____r2____  ___r3__
        token_offsets = self.token_arange_np[:total_num_tokens] \
            - new_query_start_locs_expanded

        # Expand starting positions to match token pattern
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        #  _r1_  _____r2_______  ___________r3____________
        old_query_start_locs_expanded = np.repeat(
            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
        # Final token indices are:
        # [0, 1,                                   // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,         // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2]  // req 3
        token_indices_np = token_offsets + old_query_start_locs_expanded
        token_indices = torch.from_numpy(token_indices_np).to(
            device, non_blocking=True)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc_cpu.to(device,
                                                       non_blocking=True),
            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
            query_start_loc_cpu=new_query_start_loc_cpu,
            seq_lens_cpu=new_seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.
            num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            # NOTE: this is the maximum of the query lengths *before*
            # rejected tokens are subtracted, i.e. an upper bound.
            max_query_len=new_query_len_per_req.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
        )

        return spec_common_attn_metadata, token_indices

    def load_model(self, target_model: nn.Module) -> None:
        """Load the draft model and wire it up to the target model.

        Sets ``self.model`` and ``self.attn_layer_names`` (the attention
        layers that appear in the config only after the draft model has been
        loaded), and shares the embedding / LM head with ``target_model``
        when the conditions below hold.
        """
        draft_model_config = \
            self.vllm_config.speculative_config.draft_model_config
        # Snapshot the attention layers before loading the draft model so
        # the draft-only layers can be found by set difference afterwards.
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, Attention).keys())

        from vllm.compilation.backends import set_model_tag
        with set_model_tag("eagle_head"):
            self.model = get_model(vllm_config=self.vllm_config,
                                   model_config=draft_model_config)

        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
            target_attn_layer_names)

        self.attn_layer_names = list(draft_attn_layer_names)

        if supports_multimodal(target_model):
            # handle multimodality
            self.model.config.image_token_index = (
                target_model.config.image_token_index)
            target_language_model = target_model.get_language_model()
        else:
            target_language_model = target_model
        # share embed_tokens with the target model if needed
        if get_pp_group().world_size == 1 \
            and self.model.model.embed_tokens.weight.shape \
                == target_language_model.model.embed_tokens.weight.shape:
            logger.info(
                "Assuming the EAGLE head shares the same vocab embedding" \
                " with the target model."
            )
            del self.model.model.embed_tokens
            self.model.model.embed_tokens = (
                target_language_model.model.embed_tokens)
        else:
            logger.info(
                "The EAGLE head's vocab embedding will be loaded separately" \
                " from the target model."
            )

        # share lm_head with the target model if needed
        # some model definition do not define lm_head explicitly
        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
        if self.vllm_config.speculative_config.method != "eagle3" and \
                hasattr(target_language_model, "lm_head"):
            logger.info("Loading EAGLE LM head weights from the target model.")
            self.model.lm_head = target_language_model.lm_head

    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
    ) -> None:
        """Run one draft forward pass over the first ``num_tokens`` buffer rows.

        The output is discarded and no attention metadata is supplied.
        """
        with set_forward_context(None, self.vllm_config,
                                 num_tokens=num_tokens):
            if self.is_multimodal_model:
                input_ids = None
                inputs_embeds = self.inputs_embeds[:num_tokens]
            else:
                input_ids = self.input_ids[:num_tokens]
                inputs_embeds = None

            self.model(
                input_ids=input_ids,
                positions=self.positions[:num_tokens],
                hidden_states=self.hidden_states[:num_tokens],
                inputs_embeds=inputs_embeds,
            )

    def validate_same_kv_cache_group(self,
                                     kv_cache_config: KVCacheConfig) -> None:
        """
        Validate that all eagle layers belong to the same KVCacheGroup.
        Need this assumption to ensure all eagle layers can use the
        same AttentionMetadata.
        May extend to multiple AttentionMetadata in the future.

        Raises AssertionError when the eagle layers span zero or several
        groups, and KeyError when an eagle layer is in no group at all.
        """
        # Map every layer name to the index of the group that contains it.
        layer_to_group: dict[str, int] = {
            layer_name: group_idx
            for group_idx, kv_cache_group in enumerate(
                kv_cache_config.kv_cache_groups)
            for layer_name in kv_cache_group.layer_names
        }
        eagle_groups = {
            layer_to_group[layer_name]
            for layer_name in self.attn_layer_names
        }
        assert len(eagle_groups) == 1, (
            "All eagle layers should belong to the same kv cache group")

460

461
462
463
464
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one token per row of `logits` and return the draft probs.

    Only `all_greedy`, `all_random` and `temperature` are read from
    `sampling_metadata`. Unless the batch is all-greedy, `logits` is
    divided by the temperature in place.

    Returns:
        `(next_token_ids, probs)`. For an all-greedy batch `probs` is the
        `logits` tensor itself; otherwise it is the float32 softmax of the
        temperature-scaled logits.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling,
        # so the raw logits are handed back in its place.
        return logits.argmax(dim=-1), logits

    # A temperature of -1 marks a greedy row; scale those rows by 1.0 so the
    # in-place division leaves them untouched.
    greedy_rows = sampling_metadata.temperature == -1
    row_temperature = torch.where(greedy_rows, 1.0,
                                  sampling_metadata.temperature)
    logits.div_(row_temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    noise = torch.empty_like(probs)
    noise.exponential_()
    # NOTE(woosuk): The division must stay out-of-place (no `probs.div_`)
    # because the draft probs are used later for rejection sampling.
    sampled_token_ids = probs.div(noise).argmax(dim=-1).view(-1)

    if sampling_metadata.all_random:
        return sampled_token_ids, probs

    # Mixed batch: greedy rows take the argmax, the rest keep their sample.
    next_token_ids = torch.where(
        greedy_rows,
        probs.argmax(dim=-1),
        sampled_token_ids,
    )
    return next_token_ids, probs