eagle.py 20.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
import torch
import torch.nn as nn

6
7
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
8
                         get_layers_from_vllm_config)
9
from vllm.distributed.parallel_state import get_pp_group
10
from vllm.forward_context import set_forward_context
11
from vllm.logger import init_logger
12
from vllm.model_executor.model_loader import get_model
13
from vllm.model_executor.models import supports_multimodal
14
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
Jiayi Yao's avatar
Jiayi Yao committed
15
16
from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
                                                   FlashAttentionMetadata)
17
from vllm.v1.attention.backends.mla.common import MLACommonMetadata, MLACommonDecodeMetadata
18
from vllm.v1.kv_cache_interface import KVCacheConfig
19
from vllm.v1.sample.metadata import SamplingMetadata
Jiayi Yao's avatar
Jiayi Yao committed
20
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
21

22
23
# Slot-mapping value written for draft positions at or beyond
# ``max_model_len``; per the comment at its use in ``EagleProposer.propose``,
# this keeps the KV cache from being updated with those padding tokens.
PADDING_SLOT_ID = -1

logger = init_logger(__name__)

class EagleProposer:
    """Draft-token proposer for EAGLE, EAGLE-3 and DeepSeek MTP methods.

    ``propose`` runs the draft model once over the target tokens and then
    ``num_speculative_tokens - 1`` more times with one token per request.
    The inputs of every forward pass are copied into persistent buffers
    (``input_ids``, ``positions``, ``hidden_states``) so the passes can be
    used with CUDA graphs when ``use_cuda_graph`` is set.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        """Read the speculative/draft configuration and allocate buffers.

        Args:
            vllm_config: Global engine config; the speculative, model, cache,
                scheduler and compilation sub-configs are read here.
            device: Device on which the persistent buffers are allocated.
            runner: Model runner whose ``attn_metadata_builders`` are used by
                the ``"deepseek_mtp"`` path. May be ``None`` for the EAGLE
                methods, which build their attention metadata directly.
        """
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method

        self.runner = runner

        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.block_size = vllm_config.cache_config.block_size
        self.num_speculative_tokens = (
            self.speculative_config.num_speculative_tokens)
        self.max_num_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)
        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()

        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                               == CompilationLevel.PIECEWISE and
                               not self.vllm_config.model_config.enforce_eager)
        self.cudagraph_batch_sizes = list(
            reversed(
                self.vllm_config.compilation_config.cudagraph_capture_sizes))

        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=device)
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=device)
        # We need +1 here because the arange is used to set query_start_loc,
        # which has one more element than batch_size.
        self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
                                   1,
                                   device=device,
                                   dtype=torch.int32)

    def _num_padded_tokens(self, num_tokens: int) -> int:
        """Return the number of input tokens to run the draft model with.

        With CUDA graphs enabled and ``num_tokens`` no larger than the last
        entry of ``cudagraph_batch_sizes`` (presumably the largest captured
        size -- assumes capture sizes are stored in descending order), the
        count is padded via ``vllm_config.pad_for_cudagraph``. Otherwise the
        exact count is used. An empty capture-size list disables padding
        instead of raising ``IndexError``.
        """
        if (self.use_cuda_graph and self.cudagraph_batch_sizes
                and num_tokens <= self.cudagraph_batch_sizes[-1]):
            return self.vllm_config.pad_for_cudagraph(num_tokens)
        return num_tokens

    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [num_tokens]
        target_slot_mapping: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        # [batch_size + 1] starting with 0
        cu_num_tokens: torch.Tensor,
        # [batch_size, max_num_blocks_per_req]
        block_table: torch.Tensor,
        # [batch_size]
        num_rejected_tokens_tuple: tuple[list[int], torch.Tensor],
        sampling_metadata: SamplingMetadata
    ) -> torch.Tensor:
        """Generate ``num_speculative_tokens`` argmax draft tokens per request.

        The first step runs the draft model over all target tokens, with the
        input ids shifted left by one and each request's last token replaced
        by ``next_token_ids``. Each later step feeds back the previous step's
        draft token (one query token per request).

        ``num_rejected_tokens_tuple`` is only forwarded to the attention
        metadata builder on the ``"deepseek_mtp"`` path.
        ``sampling_metadata`` is currently unused; drafts are argmax-sampled.

        Draft tokens produced for requests whose position reaches
        ``max_model_len`` use clamped positions and should be ignored.

        Returns:
            Draft token ids of shape [batch_size, num_speculative_tokens].

        Raises:
            ValueError: if ``self.method`` is not "eagle", "eagle3" or
                "deepseek_mtp".
        """
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]
        last_token_indices = cu_num_tokens[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states)
            assert target_hidden_states.shape[-1] == self.hidden_size

        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        # FA requires seq_len to have dtype int32.
        seq_lens = (target_positions[last_token_indices] + 1).int()

        if self.method in ["eagle", "eagle3"]:
            # FIXME(woosuk): The below two ops cause synchronization. Optimize.
            max_seq_len = seq_lens.max().item()
            max_query_len = (cu_num_tokens[1:] -
                             cu_num_tokens[:-1]).max().item()
            attn_metadata = FlashAttentionMetadata(
                num_actual_tokens=num_tokens,
                max_query_len=max_query_len,
                query_start_loc=cu_num_tokens,
                max_seq_len=max_seq_len,
                seq_lens=seq_lens,
                block_table=block_table,
                slot_mapping=target_slot_mapping,
                # TODO(woosuk): Support cascade attention.
                use_cascade=False,
                common_prefix_len=0,
                cu_prefix_query_lens=None,
                prefix_kv_lens=None,
                suffix_kv_lens=None,
            )
        elif self.method == "deepseek_mtp":
            query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
            max_query_len = query_lens.max().item()

            common_attn_metadata = CommonAttentionMetadata(
                query_start_loc=cu_num_tokens,
                seq_lens=seq_lens,
                num_rejected_tokens_tuple=num_rejected_tokens_tuple)

            assert self.runner is not None

            # FIXME: need to consider multiple kv_cache_groups
            attn_metadata = self.runner.attn_metadata_builders[0].build(
                num_reqs=batch_size,
                num_actual_tokens=num_tokens,
                max_query_len=max_query_len,
                common_prefix_len=0,
                common_attn_metadata=common_attn_metadata,
            )
        else:
            raise ValueError(f"Unsupported method: {self.method}")

        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {
            layer_name: attn_metadata
            for layer_name in self.attn_layer_names
        }

        num_input_tokens = self._num_padded_tokens(num_tokens)
        # copy inputs to buffer for cudagraph
        self.positions[:num_tokens] = target_positions
        self.hidden_states[:num_tokens] = target_hidden_states

        with set_forward_context(per_layer_attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_input_tokens):
            ret_hidden_states = self.model(
                self.input_ids[:num_input_tokens],
                self.positions[:num_input_tokens],
                self.hidden_states[:num_input_tokens],
            )
            if self.method == "deepseek_mtp":
                # The MTP module returns a single tensor, used both for the
                # logits and as the hidden state fed to the next step.
                last_hidden_states = ret_hidden_states
                hidden_states = ret_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)
        draft_token_ids = logits.argmax(dim=-1)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            # [batch_size, 1]
            return draft_token_ids.view(-1, 1)

        # TODO: Currently, MTP module released by deepseek only has
        # one layer. Adapt this code to support multiple layers once
        # there's a multi-layer MTP module.

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        positions = target_positions[last_token_indices]
        hidden_states = hidden_states[last_token_indices]
        input_batch_size = self._num_padded_tokens(batch_size)

        # From here on every request contributes exactly one query token.
        attn_metadata.num_actual_tokens = batch_size
        attn_metadata.max_query_len = 1
        attn_metadata.query_start_loc = self.arange[:batch_size + 1]

        use_mla = isinstance(attn_metadata, MLACommonMetadata)
        if use_mla:
            # Switch the MLA metadata to a decode-only layout.
            attn_metadata.num_decodes = batch_size
            attn_metadata.num_decode_tokens = batch_size
            attn_metadata.num_prefills = 0
            mla_builder = self.runner.attn_metadata_builders[0]
            # NOTE: this rebinds the ``block_table`` argument; the slot-mapping
            # computation in the loop below gathers from this tensor as well.
            block_table = mla_builder.block_table.get_device_tensor(
            )[:batch_size, ...]
            # Decode sequence lengths. They are advanced out of place each step
            # so the tensor held by the previous step's metadata is not mutated.
            mla_seq_lens = seq_lens

        for _ in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            positions += 1

            # NOTE(woosuk): We should handle the case where the draft model
            # generates tokens beyond the max model length. Since it is complex
            # to remove such requests from the batch, we keep them in the batch
            # but adjust the position ids and slot mappings to avoid the
            # out-of-range access during the model execution. The draft tokens
            # generated with this adjustment should be ignored.
            exceeds_max_model_len = positions >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            clamped_positions = torch.where(exceeds_max_model_len, 0,
                                            positions)

            if use_mla:
                # The query token sits at ``positions``, so the attended
                # context is ``positions + 1`` tokens: exactly one more than in
                # the previous step. The decode metadata is rebuilt from the
                # new lengths in case the builder derives other fields from
                # them.
                mla_seq_lens = mla_seq_lens + 1
                # As in the non-MLA path, requests that exceed the max model
                # length attend to a single token.
                mla_seq_lens.masked_fill_(exceeds_max_model_len, 1)
                attn_metadata.decode = mla_builder._build_decode(
                    block_table_tensor=block_table,
                    seq_lens=mla_seq_lens,
                )
            else:
                # Increment the sequence lengths.
                attn_metadata.max_seq_len += 1
                attn_metadata.seq_lens += 1
                # Consider max model length.
                attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
                                                self.max_model_len)
                # For the requests that exceed the max model length, we set the
                # sequence length to 1 to minimize their overheads in attention.
                attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            block_numbers = clamped_positions // self.block_size
            block_ids = block_table.gather(dim=1,
                                           index=block_numbers.view(-1, 1))
            block_ids = block_ids.view(-1)
            attn_metadata.slot_mapping = (block_ids * self.block_size +
                                          clamped_positions % self.block_size)
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
                                                    PADDING_SLOT_ID)

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self.positions[:batch_size] = clamped_positions
            self.hidden_states[:batch_size] = hidden_states

            # Run the model.
            with set_forward_context(per_layer_attn_metadata,
                                     self.vllm_config,
                                     num_tokens=input_batch_size):
                ret_hidden_states = self.model(
                    self.input_ids[:input_batch_size],
                    self.positions[:input_batch_size],
                    self.hidden_states[:input_batch_size],
                )
                if self.method == "deepseek_mtp":
                    last_hidden_states = ret_hidden_states
                    hidden_states = last_hidden_states[:batch_size]
                else:
                    last_hidden_states, hidden_states = ret_hidden_states
                    hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size],
                                               None)

            # TODO(wenlong): get more than one token for tree attention
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids

    @staticmethod
    def prepare_inputs(
        # [batch_size + 1]
        cu_target_query_lens: torch.Tensor,
        # [batch_size]
        num_rejected_tokens: torch.Tensor,
        num_tokens: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Drop rejected tokens from the target batch layout.

        Args:
            cu_target_query_lens: Cumulative query lengths of the target
                batch, starting with 0.
            num_rejected_tokens: Number of rejected tokens per request.
            num_tokens: Length of the returned ``token_indices`` tensor.

        Returns:
            ``(cu_num_tokens, token_indices)``: cumulative per-request token
            counts after rejection, and int32 indices of the kept tokens, filled
            by ``prepare_eagle_input_kernel`` (launched with a grid of
            ``(batch_size,)``). See the worked example below.
        """
        # cu_target_query_lens: [0, a, a + b, a + b + c]
        # num_rejected_tokens: [n1, n2, n3]
        # num_tokens_per_req: [a - n1, b - n2, c - n3]
        # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
        # token_indices: [0, 1, ..., a - n1 - 1,
        #                 a, a + 1, ..., a + b - n2 - 1,
        #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]

        # [0, a, a + b, a + b + c] -> [a, b, c]
        query_len_per_req = (cu_target_query_lens[1:] -
                             cu_target_query_lens[:-1])
        # [a, b, c] -> [a - n1, b - n2, c - n3]
        num_tokens_per_req = query_len_per_req - num_rejected_tokens

        # [a - n1, b - n2, c - n3] ->
        # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
        cu_num_tokens = torch.zeros_like(cu_target_query_lens)
        torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
        token_indices = torch.empty(
            num_tokens,
            dtype=torch.int32,
            device=cu_target_query_lens.device,
        )
        batch_size = num_rejected_tokens.shape[0]
        BLOCK_SIZE = 1024
        prepare_eagle_input_kernel[(batch_size, )](
            token_indices,
            cu_target_query_lens,
            cu_num_tokens,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return cu_num_tokens, token_indices

    def load_model(self, target_model: nn.Module) -> None:
        """Load the draft model and share weights with ``target_model``.

        The draft's attention layers are the ``Attention`` layers registered
        in the config after loading that were not present before; their names
        are stored in ``self.attn_layer_names``. ``embed_tokens`` is shared
        when pipeline parallelism is off, the method is not "deepseek_mtp" and
        the embedding shapes match. ``lm_head`` is shared for non-"eagle3"
        methods when the target model defines one.
        """
        draft_model_config = (
            self.vllm_config.speculative_config.draft_model_config)
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, Attention).keys())

        self.model = get_model(vllm_config=self.vllm_config,
                               model_config=draft_model_config)

        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
            target_attn_layer_names)

        self.attn_layer_names = list(draft_attn_layer_names)

        # share embed_tokens with the target model if needed.
        # The shape comparison comes last so it is skipped for deepseek_mtp.
        if (get_pp_group().world_size == 1
                and self.method != "deepseek_mtp"
                and self.model.model.embed_tokens.weight.shape
                == target_model.model.embed_tokens.weight.shape):
            logger.info(
                "Assuming the EAGLE head shares the same vocab embedding"
                " with the target model.")
            del self.model.model.embed_tokens
            self.model.model.embed_tokens = target_model.model.embed_tokens
        else:
            logger.info(
                "The EAGLE head's vocab embedding will be loaded separately"
                " from the target model.")

        # share lm_head with the target model if needed
        # some model definition do not define lm_head explicitly
        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
        if (self.vllm_config.speculative_config.method != "eagle3"
                and hasattr(target_model, "lm_head")):
            logger.info("Loading EAGLE LM head weights from the target model.")
            if supports_multimodal(target_model):
                self.model.lm_head = target_model.get_language_model().lm_head
            else:
                self.model.lm_head = target_model.lm_head

    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
    ) -> None:
        """Run the draft model once on the first ``num_tokens`` buffer rows.

        No attention metadata is passed to the forward context.
        """
        with set_forward_context(None, self.vllm_config,
                                 num_tokens=num_tokens):
            self.model(
                self.input_ids[:num_tokens],
                self.positions[:num_tokens],
                self.hidden_states[:num_tokens],
            )

    def validate_same_kv_cache_group(self,
                                     kv_cache_config: KVCacheConfig) -> None:
        """
        Validate that all eagle layers belong to the same KVCacheGroup.
        Need this assumption to ensure all eagle layers can use the
        same AttentionMetadata.
        May extend to multiple AttentionMetadata in the future.
        """
        # layer name -> index of the KV cache group containing it
        kv_cache_groups: dict[str, int] = {}
        for group_id, kv_cache_group in enumerate(
                kv_cache_config.kv_cache_groups):
            for layer_name in kv_cache_group.layer_names:
                kv_cache_groups[layer_name] = group_id
        assert len({
            kv_cache_groups[layer_name]
            for layer_name in self.attn_layer_names
        }) == 1, "All eagle layers should belong to the same kv cache group"
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one draft token per row of ``logits``.

    Returns:
        ``(next_token_ids, probs)``. When every request is greedy, ``probs``
        is ``logits`` itself and the ids are its row-wise argmax. Otherwise
        ``logits`` is divided IN PLACE by the per-request temperature (1.0
        for greedy rows, i.e. those with temperature == -1), ``probs`` is its
        float32 softmax, random rows are sampled from ``probs`` and greedy
        rows take the argmax of ``probs``.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        return logits.argmax(dim=-1), logits

    greedy_rows = sampling_metadata.temperature == -1
    # Greedy rows are divided by 1.0, i.e. left unscaled.
    safe_temperature = torch.where(greedy_rows, 1.0,
                                   sampling_metadata.temperature)
    logits.div_(safe_temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    # argmax(probs / E) with E ~ Exp(1) i.i.d. draws a sample from ``probs``.
    exp_noise = torch.empty_like(probs).exponential_()
    # NOTE(woosuk): We shouldn't use `probs.div_(exp_noise)` because the
    # draft_probs will be used later for rejection sampling.
    sampled_ids = probs.div(exp_noise).argmax(dim=-1).view(-1)

    if sampling_metadata.all_random:
        return sampled_ids, probs

    return torch.where(greedy_rows, probs.argmax(dim=-1), sampled_ids), probs