from typing import Union

import torch
import torch.nn.functional as F

import vllm.envs as envs
from vllm.forward_context import set_forward_context
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils import round_up
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer


class V1ZeroEagleProposer(EagleProposer):
    """EAGLE / EAGLE-3 / DeepSeek-MTP draft proposer.

    Extends ``EagleProposer`` with:
    - MLA attention metadata for ``deepseek_mtp``, built through the
      runner's first attention-metadata builder;
    - copying of attention metadata into ``self.attn_metadata_cudagraph``
      when full CUDA graphs are in use;
    - data-parallel attention padding;
    - optional draft-probability output gated by
      ``envs.VLLM_REJECT_SAMPLE_OPT``.
    """

    def __init__(self, vllm_config, device, runner=None):
        super().__init__(vllm_config, device, runner)
        # Query-length bound handed to the MLA metadata builder for
        # ``deepseek_mtp``. Starts at 0; presumably updated externally
        # before ``propose`` is called -- TODO confirm.
        self.spec_scheduler_max_num_tokens = 0

    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [num_tokens]
        target_slot_mapping: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        # [batch_size + 1] starting with 0
        cu_num_tokens: torch.Tensor,
        # [batch_size, max_num_blocks_per_req]
        block_table: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        decoding: bool = False,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Generate ``num_speculative_tokens`` draft tokens per request.

        The first draft step runs the draft model over every target token.
        The inputs are shifted left by one, and each request's last slot is
        replaced by its ``next_token_ids`` entry. Each later step feeds back
        exactly one token (the previous draft) per request.

        Args:
            target_token_ids: Token ids processed by the target model.
            target_positions: Positions of ``target_token_ids``.
            target_hidden_states: Target-model hidden states. For
                ``eagle3`` they are first passed through
                ``self.model.combine_hidden_states``.
            target_slot_mapping: KV-cache slots of the target tokens.
            next_token_ids: Target-sampled next token for each request.
            cu_num_tokens: Cumulative per-request token counts, starting
                at 0.
            block_table: Block table used for the ``eagle``/``eagle3``
                attention metadata. For non-MLA metadata it is also used
                for the decode-step slot mapping; MLA uses the builder's
                own block table instead.
            sampling_metadata: Not used by this implementation; kept for
                interface compatibility.
            decoding: Forwarded to the MLA metadata as
                ``spec_layer_decoding``. When True, the first pass also
                copies metadata into the full-CUDA-graph buffers (if enabled
                and the batch fits). CUDA graphs are skipped for the first
                forward unless ``decoding`` is True or expert parallelism is
                enabled.

        Returns:
            ``draft_token_ids`` with shape
            ``[batch_size, num_speculative_tokens]``. If
            ``envs.VLLM_REJECT_SAMPLE_OPT`` is set, returns the tuple
            ``(draft_token_ids, draft_probs)`` instead. ``draft_probs`` is
            the float32 softmax of each step's draft logits, with shape
            ``[batch_size, num_speculative_tokens, V]`` (V = logits width).

        Raises:
            ValueError: If ``self.method`` is not ``eagle``, ``eagle3`` or
                ``deepseek_mtp``.
        """
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]
        last_token_indices = cu_num_tokens[1:] - 1
        # Read the flag once so the probability bookkeeping below is
        # consistent for the whole call. (vllm.envs values may be
        # re-evaluated on every attribute access.)
        return_probs = envs.VLLM_REJECT_SAMPLE_OPT

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states)
            assert target_hidden_states.shape[-1] == self.hidden_size

        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        # FA requires seq_len to have dtype int32.
        seq_lens = (target_positions[last_token_indices] + 1).int()

        if self.method in ("eagle", "eagle3"):
            # FIXME(woosuk): The below two ops cause synchronization. Optimize.
            max_seq_len = seq_lens.max().item()
            max_num_tokens = (cu_num_tokens[1:] -
                              cu_num_tokens[:-1]).max().item()
            attn_metadata = FlashAttentionMetadata(
                num_actual_tokens=num_tokens,
                max_query_len=max_num_tokens,
                query_start_loc=cu_num_tokens,
                max_seq_len=max_seq_len,
                seq_lens=seq_lens,
                block_table=block_table,
                slot_mapping=target_slot_mapping,
                # TODO(woosuk): Support cascade attention.
                use_cascade=False,
                common_prefix_len=0,
                cu_prefix_query_lens=None,
                prefix_kv_lens=None,
                suffix_kv_lens=None,
            )
        elif self.method == "deepseek_mtp":
            common_attn_metadata = CommonAttentionMetadata(
                query_start_loc=cu_num_tokens,
                seq_lens=seq_lens,
                num_reqs=batch_size,
                num_actual_tokens=num_tokens,
                max_query_len=self.spec_scheduler_max_num_tokens,
                slot_mapping=target_slot_mapping,
                spec_layer_decoding=decoding,
            )

            assert self.runner is not None

            # FIXME: need to consider multiple kv_cache_groups
            attn_metadata = self.runner.attn_metadata_builders[0].build(
                common_prefix_len=0,
                common_attn_metadata=common_attn_metadata,
            )
        else:
            raise ValueError(f"Unsupported method: {self.method}")

        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {
            layer_name: attn_metadata for layer_name in self.attn_layer_names
        }

        if (self.use_cuda_graph
                and num_tokens <= self.cudagraph_batch_sizes[-1]):
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
        else:
            num_input_tokens = num_tokens

        if self.enable_dp_attention:
            # Round up to a multiple of the attention TP size.
            num_input_tokens = round_up(num_input_tokens, self.attn_tp_size)
        # NOTE(review): unlike the decode steps below, the first pass does not
        # add ``self.get_dp_padding(num_input_tokens)`` (it was commented out
        # in the original code) -- confirm this asymmetry is intended.

        # copy inputs to buffer for cudagraph
        self.positions[:num_tokens] = target_positions
        self.hidden_states[:num_tokens] = target_hidden_states

        if (decoding and self.use_full_cuda_graph
                and num_tokens <= self.cudagraph_batch_sizes[-1]):
            assert self.attn_metadata_cudagraph
            graph_md = self.attn_metadata_cudagraph
            if self.method in ("eagle", "eagle3"):
                graph_md.seq_lens[:batch_size] = attn_metadata.seq_lens
                graph_md.slot_mapping[:num_tokens] = attn_metadata.slot_mapping
                graph_md.query_start_loc[:batch_size + 1] = (
                    attn_metadata.query_start_loc)
                graph_md.block_table[:batch_size] = attn_metadata.block_table
            elif self.method == "deepseek_mtp":
                graph_md.num_actual_tokens = attn_metadata.num_actual_tokens
                graph_md.query_start_loc[:batch_size + 1] = (
                    attn_metadata.query_start_loc)
                graph_md.slot_mapping[:num_tokens] = attn_metadata.slot_mapping
                graph_md.num_decodes = attn_metadata.num_decodes
                graph_md.num_decode_tokens = attn_metadata.num_decode_tokens
                graph_md.num_prefills = attn_metadata.num_prefills

                if attn_metadata.decode is not None:
                    # NOTE(review): rows are sliced by num_decode_tokens, not
                    # num_decodes; equal only with one token per decode
                    # request -- TODO confirm for spec_layer_decoding.
                    num_decode_tokens = attn_metadata.num_decode_tokens
                    graph_md.decode.block_table[:num_decode_tokens] = (
                        attn_metadata.decode.block_table)
                    graph_md.decode.seq_lens[:num_decode_tokens] = (
                        attn_metadata.decode.seq_lens)

        use_ep = self.vllm_config.parallel_config.enable_expert_parallel
        with set_forward_context(per_layer_attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_input_tokens,
                                 skip_cuda_graphs=not (decoding or use_ep)):
            ret_hidden_states = self.model(
                self.input_ids[:num_input_tokens],
                self.positions[:num_input_tokens],
                self.hidden_states[:num_input_tokens],
            )
            if self.method == "deepseek_mtp":
                last_hidden_states = ret_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)

        draft_token_ids = logits.argmax(dim=-1)

        draft_probs_list: list[torch.Tensor] = []
        if return_probs:
            draft_probs_list.append(logits.softmax(dim=-1,
                                                   dtype=torch.float32))

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            # [batch_size, 1]
            if return_probs:
                return (draft_token_ids.view(-1, 1),
                        draft_probs_list[0].view(-1, 1, logits.shape[-1]))
            return draft_token_ids.view(-1, 1)

        # TODO: Currently, MTP module released by deepseek only has
        # one layer. Adapt this code to support multiple layers once
        # there's a multi-layer MTP module.

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        # Advanced indexing returns a copy, so the in-place ``positions += 1``
        # below does not touch ``target_positions``.
        positions = target_positions[last_token_indices]

        if self.method == "deepseek_mtp":
            hidden_states = last_hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]

        if (self.use_cuda_graph
                and batch_size <= self.cudagraph_batch_sizes[-1]):
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
        else:
            input_batch_size = batch_size

        # dp attention need all dp rank process same number tokens
        if self.enable_dp_attention:
            input_batch_size = round_up(input_batch_size, self.attn_tp_size)
            num_pad, _ = self.get_dp_padding(input_batch_size)
            input_batch_size += num_pad

        # From here on each request contributes exactly one query token.
        attn_metadata.num_actual_tokens = batch_size
        attn_metadata.max_query_len = 1
        attn_metadata.query_start_loc = self.arange[:batch_size + 1]

        is_mla = isinstance(attn_metadata, MLACommonMetadata)
        if is_mla:
            builder = self.runner.attn_metadata_builders[0]
            attn_metadata.num_decodes = batch_size
            attn_metadata.num_decode_tokens = batch_size
            attn_metadata.num_prefills = 0
            # For MLA the decode-step slot mapping below uses the builder's
            # block table instead of the ``block_table`` argument.
            block_table = (
                builder.block_table.get_device_tensor()[:batch_size, ...])
            attn_metadata.decode = builder._build_decode(
                block_table_tensor=block_table,
                seq_lens=seq_lens,
            )

        for i in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            positions += 1

            # NOTE(woosuk): We should handle the case where the draft model
            # generates tokens beyond the max model length. Since it is complex
            # to remove such requests from the batch, we keep them in the batch
            # but adjust the position ids and slot mappings to avoid the
            # out-of-range access during the model execution. The draft tokens
            # generated with this adjustment should be ignored.
            exceeds_max_model_len = positions >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            clamped_positions = torch.where(exceeds_max_model_len, 0,
                                            positions)

            if is_mla:
                attn_metadata.decode.seq_lens += 1
                # Mirror the non-MLA branch: requests past the max model
                # length attend over a single token (their drafts are
                # discarded anyway, per the NOTE above).
                attn_metadata.decode.seq_lens.masked_fill_(
                    exceeds_max_model_len, 1)
            else:
                # Increment the sequence lengths, capped by max model length.
                attn_metadata.seq_lens += 1
                attn_metadata.max_seq_len = min(attn_metadata.max_seq_len + 1,
                                                self.max_model_len)
                # For the requests that exceed the max model length, we set the
                # sequence length to 1 to minimize their overheads in attention.
                attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            block_numbers = clamped_positions // self.block_size
            block_ids = block_table.gather(
                dim=1, index=block_numbers.view(-1, 1)).view(-1)
            attn_metadata.slot_mapping = (block_ids * self.block_size +
                                          clamped_positions % self.block_size)
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
                                                    PADDING_SLOT_ID)

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self.positions[:batch_size] = clamped_positions
            self.hidden_states[:batch_size] = hidden_states

            if (self.use_full_cuda_graph
                    and batch_size <= self.cudagraph_batch_sizes[-1]):
                assert self.attn_metadata_cudagraph
                graph_md = self.attn_metadata_cudagraph
                # query_start_loc and block tables do not change between
                # steps, so they are copied only on the first step.
                first_step = i == 0
                if self.method in ("eagle", "eagle3"):
                    graph_md.seq_lens[:batch_size] = attn_metadata.seq_lens
                    graph_md.slot_mapping[:batch_size] = (
                        attn_metadata.slot_mapping)
                    if first_step:
                        graph_md.query_start_loc[:batch_size + 1] = (
                            attn_metadata.query_start_loc)
                        graph_md.block_table[:batch_size] = (
                            attn_metadata.block_table)
                elif self.method == "deepseek_mtp":
                    num_decode_tokens = attn_metadata.num_decode_tokens
                    graph_md.num_actual_tokens = (
                        attn_metadata.num_actual_tokens)
                    graph_md.slot_mapping[:num_decode_tokens] = (
                        attn_metadata.slot_mapping)
                    graph_md.num_decodes = attn_metadata.num_decodes
                    graph_md.num_decode_tokens = num_decode_tokens
                    graph_md.num_prefills = attn_metadata.num_prefills
                    graph_md.decode.seq_lens[:num_decode_tokens] = (
                        attn_metadata.decode.seq_lens)

                    if first_step:
                        graph_md.query_start_loc[:batch_size + 1] = (
                            attn_metadata.query_start_loc)
                        graph_md.decode.block_table[:num_decode_tokens] = (
                            attn_metadata.decode.block_table)

            # Run the model.
            with set_forward_context(per_layer_attn_metadata,
                                     self.vllm_config,
                                     num_tokens=input_batch_size):
                ret_hidden_states = self.model(
                    self.input_ids[:input_batch_size],
                    self.positions[:input_batch_size],
                    self.hidden_states[:input_batch_size],
                )
                if self.method == "deepseek_mtp":
                    last_hidden_states = ret_hidden_states
                    hidden_states = last_hidden_states[:batch_size]
                else:
                    last_hidden_states, hidden_states = ret_hidden_states
                    hidden_states = hidden_states[:batch_size]

            logits = self.model.compute_logits(last_hidden_states[:batch_size],
                                               None)

            # TODO(wenlong): get more than one token for tree attention
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

            if return_probs:
                draft_probs_list.append(
                    logits.softmax(dim=-1, dtype=torch.float32))

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)

        if return_probs:
            # [batch_size, num_speculative_tokens, V]
            draft_probs = torch.stack(draft_probs_list, dim=1).contiguous()
            return draft_token_ids, draft_probs

        return draft_token_ids