# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional
from unittest import mock

import pytest
import torch

from tests.utils import get_attn_backend_list_based_on_platform
from tests.v1.attention.utils import (BatchSpec, _Backend,
                                      create_common_attn_metadata,
                                      create_standard_kv_cache_spec,
                                      get_attention_backend)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
                         VllmConfig)
from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

# Target model and the matching EAGLE / EAGLE3 draft checkpoints.
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"


def _create_proposer(
    method: str,
    num_speculative_tokens: int,
    speculative_token_tree: Optional[list[tuple[int, ...]]] = None,
) -> EagleProposer:
    """Build an ``EagleProposer`` on top of a Llama-3.1-8B target config.

    Args:
        method: Speculative method, ``"eagle"`` or ``"eagle3"``. It selects
            which draft checkpoint is used and is forwarded to
            ``SpeculativeConfig``.
        num_speculative_tokens: Number of draft tokens proposed per step.
        speculative_token_tree: Optional list of draft-token tree paths.
            When given, its length must equal ``num_speculative_tokens``.
            It is passed to ``SpeculativeConfig`` in its ``str()`` form.

    Returns:
        An ``EagleProposer`` whose config and device both use
        ``current_platform.device_type``.

    Raises:
        ValueError: If ``method`` is neither ``"eagle"`` nor ``"eagle3"``.
    """
    model_config = ModelConfig(model=model_dir,
                               runner="generate",
                               max_model_len=100)

    # Choose the draft checkpoint that matches the method. Fail loudly on an
    # unknown method instead of silently falling back to the EAGLE3 weights.
    if method == "eagle":
        draft_model_dir = eagle_dir
    elif method == "eagle3":
        draft_model_dir = eagle3_dir
    else:
        raise ValueError(f"Unsupported speculative method: {method!r}")

    spec_token_tree_str = None
    if speculative_token_tree is not None:
        # Each tree node corresponds to one speculative token.
        assert num_speculative_tokens == len(speculative_token_tree)
        spec_token_tree_str = str(speculative_token_tree)

    speculative_config = SpeculativeConfig(
        target_model_config=model_config,
        target_parallel_config=ParallelConfig(),
        model=draft_model_dir,
        method=method,
        num_speculative_tokens=num_speculative_tokens,
        speculative_token_tree=spec_token_tree_str,
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(),
        speculative_config=speculative_config,
        device_config=DeviceConfig(device=current_platform.device_type),
        parallel_config=ParallelConfig(),
        load_config=LoadConfig(),
        scheduler_config=SchedulerConfig())

    return EagleProposer(vllm_config=vllm_config,
                         device=current_platform.device_type)


def test_prepare_next_token_ids():
    """
    Test for prepare_next_token_ids_cpu and prepare_next_token_ids_padded.

    Each produces a device tensor of next_token_ids from one of two inputs:
    - the GPU tensor of sampled_token_ids, with -1 marking rejected tokens;
    - the CPU python list[list[int]], with the rejected tokens removed.

    Requests with no sampled tokens are expected to use the backup token
    from ``CachedRequestState.get_token_id``, mocked here as
    10 * request number. Both code paths must produce identical results.
    """
    device = torch.device(current_platform.device_type)

    num_requests = 4
    num_speculative_tokens = 4
    batch_spec = BatchSpec(
        seq_lens=[num_speculative_tokens + 1] * num_requests,
        query_lens=[num_speculative_tokens + 1] * num_requests,
    )

    req_ids = [f"req_{i+1}" for i in range(num_requests)]
    mock_input_batch = mock.MagicMock(spec=InputBatch)
    mock_input_batch.req_ids = req_ids
    mock_input_batch.num_reqs = num_requests
    mock_input_batch.vocab_size = 100

    mock_num_scheduled_tokens = {req_id: 0 for req_id in req_ids}
    mock_requests = {}
    for req_id in req_ids:
        mock_request = mock.MagicMock(spec=CachedRequestState)
        # Each request will have a backup next token id of 10, 20, 30, 40
        mock_request.get_token_id.return_value = int(req_id.split("_")[1]) * 10
        mock_request.num_computed_tokens = 0
        mock_requests[req_id] = mock_request

    sampled_token_ids = [
        [0, 1, -1, -1, -1],  # 1 accepted, 3 rejected, "1" sampled
        [0, 1, 2, 3, 4],  # all accepted, "4" sampled
        [-1, -1, -1, -1, -1],  # sampling skipped, use backup token "30"
        [-1, -1, -1, -1, -1]  # this request will be discarded
    ]
    sampled_token_ids_tensor = torch.tensor(sampled_token_ids,
                                            dtype=torch.int32,
                                            device=device)
    # The CPU path receives the same data with the -1 placeholders removed.
    sampled_token_ids_cpu = [[i for i in seq if i != -1]
                             for seq in sampled_token_ids]

    expected_next_token_ids_cpu = [1, 4, 30, 40]
    expected_next_token_ids_tensor = torch.tensor(expected_next_token_ids_cpu,
                                                  dtype=torch.int32,
                                                  device=device)

    proposer = _create_proposer("eagle", num_speculative_tokens)

    next_token_ids_from_cpu = proposer.prepare_next_token_ids_cpu(
        sampled_token_ids_cpu, mock_requests, mock_input_batch,
        mock_num_scheduled_tokens)

    assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor)

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # Request index 3 is discarded; the padded path reports 0 valid tokens
    # for it but still expects its backup token (40) as next token.
    discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device)
    num_discarded_reqs = 1

    # Count of non-(-1) entries per row, with the discarded row forced to 0.
    expected_valid_sampled_tokens_count = torch.tensor([2, 5, 0, 0],
                                                       dtype=torch.int32,
                                                       device=device)

    next_token_ids_from_padded, valid_sampled_tokens_count = \
        proposer.prepare_next_token_ids_padded(
            common_attn_metadata, sampled_token_ids_tensor, mock_requests,
            mock_input_batch, discarded_req_indices, num_discarded_reqs)

    assert torch.equal(next_token_ids_from_padded,
                       expected_next_token_ids_tensor)
    assert torch.equal(valid_sampled_tokens_count,
                       expected_valid_sampled_tokens_count)


def test_prepare_inputs():
    """
    Test EagleProposer.prepare_inputs on 3 requests with rejected drafts.

    Rejected draft tokens are dropped from each request's slice of the
    flattened token buffer. The remaining tokens keep their original
    positions. Notation:

    cu_target_query_lens: [0, a, a + b, a + b + c]
    num_rejected_tokens: [n1, n2, n3]
    num_tokens_per_req: [a - n1, b - n2, c - n3]
    cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
    token_indices: [0, 1, ..., a - n1 - 1,
                    a, a + 1, ..., a + b - n2 - 1,
                    a + b, a + b + 1, ..., a + b + c - n3 - 1]
    """
    device = torch.device(current_platform.device_type)

    # q1 = 4, q2 = 7, q3 = 5
    # n1 = 1, n2 = 3, n3 = 2

    batch_spec = BatchSpec(
        seq_lens=[4, 7, 5],
        query_lens=[4, 7, 5],
    )

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # If there are `k` sampled tokens, then `k-1` tokens are draft tokens
    # from the previous iteration, and the last token is the bonus token sampled
    # from the base model.
    num_draft_tokens = [3, 6, 4]  # one less than query_lens
    # num rejected tokens is [1, 3, 2]
    ACCEPT_TOKEN = 0
    BONUS_TOKEN = 1
    REJECT_TOKEN = -1
    sampled_token_ids = [
        [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, BONUS_TOKEN],
        [
            ACCEPT_TOKEN, ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN,
            REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN
        ],
        [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN]
    ]
    # prepare_inputs takes the CPU form, with rejected tokens already removed.
    sampled_token_ids = [[i for i in seq if i != REJECT_TOKEN]
                         for seq in sampled_token_ids]

    # Expected calculations:
    # query_len_per_req = [4, 7, 5]
    # num_tokens_per_req = [3, 4, 3]  (after subtracting rejected tokens)
    # Expected cumulative counts: [0, 3, 7, 10]
    expected_cu_num_tokens = torch.tensor([0, 3, 7, 10],
                                          dtype=torch.int32,
                                          device=device)

    # Expected token indices (mapped from original positions):
    # First request: indices 0, 1, 2      (keeping first 3 from positions 0-3)
    # Second request: indices 4, 5, 6, 7  (keeping first 4 from positions 4-10)
    # Third request: indices 11, 12, 13   (keeping first 3 from positions 11-15)
    expected_token_indices = torch.tensor(
        [
            0,
            1,
            2,  # First request: 3 tokens (4-1)
            4,
            5,
            6,
            7,  # Second request: 4 tokens (7-3)
            11,
            12,
            13  # Third request: 3 tokens (5-2)
        ],
        dtype=torch.int32,
        device=device)
    proposer = _create_proposer("eagle", 1)

    updated_metadata, token_indices = proposer.prepare_inputs(
        common_attn_metadata, sampled_token_ids, num_draft_tokens)

    assert torch.equal(updated_metadata.query_start_loc,
                       expected_cu_num_tokens)
    assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
    assert torch.equal(token_indices, expected_token_indices)


def test_prepare_inputs_padded():
    """
    Input scenario is 3 requests with num_speculative_tokens == 2 and:
    - Request 1: query_len = 3, rejected = 1
    - Request 2: query_len = 3, rejected = 0
    - Request 3: query_len = 3, rejected = 2

    Expected outputs:
    token_indices: [0, 1, 2,
                    3, 4, 5,
                    6, 7, 8]
    Reason: Deferred computation should not disturb the original indices.

    token_indices_to_sample: [1, 5, 6]
    Reason: After accounting for rejections, these are the valid token positions
            from the original indices to sample from.
            Each is query_start_loc[i] + valid_sampled_tokens_count[i] - 1.
    """

    device = torch.device(current_platform.device_type)

    expected_token_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8],
                                          dtype=torch.int32,
                                          device=device)
    expected_token_indices_to_sample = torch.tensor([1, 5, 6],
                                                    dtype=torch.int32,
                                                    device=device)

    num_speculative_tokens = 2
    batch_spec = BatchSpec(
        seq_lens=[3, 3, 3],
        query_lens=[3, 3, 3],
    )

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # Each request has num_speculative_tokens (2) draft tokens plus one bonus
    # token, i.e. 3 tokens per request, so the query start locations are
    # [0, 3, 6, 9]. (The cumulative draft-token counts alone are [2, 4, 6].)
    expected_query_start_loc = torch.tensor([0, 3, 6, 9],
                                            dtype=torch.int32,
                                            device=device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        draft_token_ids=[[0] * num_speculative_tokens] * 3,
        device=device,
    )

    # num_rejected_tokens = [1, 0, 2]
    # num_draft_tokens = [2, 2, 2]
    # valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens
    valid_sampled_tokens_count = torch.tensor([2, 3, 1],
                                              dtype=torch.int32,
                                              device=device)

    proposer = _create_proposer("eagle", num_speculative_tokens)

    output_metadata, token_indices, token_indices_to_sample = \
        proposer.prepare_inputs_padded(
            common_attn_metadata,
            spec_decode_metadata,
            valid_sampled_tokens_count)

    assert output_metadata.max_query_len == 3
    assert torch.equal(output_metadata.query_start_loc,
                       expected_query_start_loc)
    assert torch.equal(token_indices, expected_token_indices)
    assert torch.equal(token_indices_to_sample,
                       expected_token_indices_to_sample)


@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend",
                         get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("pp_size", [1, 2])
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
                    attn_backend, pp_size, use_distinct_embed_tokens,
                    monkeypatch):
    """Check that ``EagleProposer.load_model`` wires up the draft model.

    ``mock.patch`` decorators are applied bottom-up, so the injected mocks
    arrive in the order ``get_model``, ``get_layers_from_vllm_config``,
    ``get_pp_group``.

    Verified behavior:
    - the draft model is obtained via exactly one ``get_model`` call;
    - for ``"eagle"``, the draft model reuses the target's ``lm_head``;
    - the draft's ``embed_tokens`` is shared with the target only when
      ``pp_size == 1`` and both embeddings have the same shape.
    """

    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

    if (attn_backend == "TRITON_ATTN_VLLM_V1"
            and not current_platform.is_rocm()):
        pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
                    "multi-token eagle spec decode on current platform")

    if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # Setup draft model mock
    mock_model = mock.MagicMock()
    if use_distinct_embed_tokens:
        # Some models can have a different hidden size than the target model,
        # so we test that their embed_tokens doesn't get overwritten
        mock_model.model.embed_tokens.weight.shape = (131072, 2048)
    else:
        mock_model.model.embed_tokens.weight.shape = (131072, 4096)

    mock_get_model.return_value = mock_model

    # Setup mocks for attention layers
    target_attn_layers = {
        "target_attn_1": mock.MagicMock(),
        "target_attn_2": mock.MagicMock()
    }
    # Draft model has one extra attention layer compared to target model
    all_attn_layers = {
        **target_attn_layers, "draft_extra_attn": mock.MagicMock()
    }

    # Make mock_get_layers return different values for each call: first the
    # target's layers, then the target's plus the draft's extra layer.
    mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]

    # Setup mock for pp group to return the appropriate value for world size
    mock_pp_group = mock.MagicMock()
    mock_pp_group.world_size = pp_size
    mock_get_pp_group.return_value = mock_pp_group

    # Set up the target model mock with a custom class so that
    # isinstance() checks match the expected type.
    class _TargetModelStub(LlamaForCausalLM):
        model: mock.MagicMock
        lm_head: mock.MagicMock

    target_model = mock.create_autospec(_TargetModelStub, instance=True)
    target_model.model = mock.MagicMock()
    target_model.model.embed_tokens.weight.shape = (131072, 4096)

    from vllm.model_executor.models import SupportsMultiModal
    assert not isinstance(target_model, SupportsMultiModal)

    if method == "eagle":
        target_model.lm_head = mock.MagicMock()

    # Create proposer using the helper function
    proposer = _create_proposer(method, num_speculative_tokens=8)

    # Call the method under test
    proposer.load_model(target_model)

    # Verify common interactions
    mock_get_model.assert_called_once()

    # Verify that EAGLE models gain the lm head from the target model
    if method == "eagle":
        assert proposer.model.lm_head == target_model.lm_head

    # Verify that the embed tokens are set correctly
    # If pp_size is > 1, the embed tokens should be distinct
    if pp_size > 1 or use_distinct_embed_tokens:
        assert proposer.model.model.embed_tokens != \
            target_model.model.embed_tokens
    else:
        # When pp_size is 1 and the draft and target models have
        # embed_tokens of the same shape, they should be shared.
        assert proposer.model.model.embed_tokens == \
            target_model.model.embed_tokens


@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend",
                         get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
    """Check that ``EagleProposer.propose`` returns the mocked draft tokens.

    The draft model is replaced by a mock. At speculative step ``j``, its
    ``compute_logits`` puts the largest logit on ``base_token_ids[i] + j``
    for request ``i``. Request 0 should therefore yield 42, 43, ... and
    request 1 should yield 60, 61, ...
    """

    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

    if (attn_backend == "TRITON_ATTN_VLLM_V1"
            and not current_platform.is_rocm()):
        pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
                    "multi-token eagle spec decode on current platform")

    if attn_backend == "TREE_ATTN":
        # Trailing space keeps the two literals from fusing into one word.
        pytest.skip("TREE_ATTN is tested separately in test_propose_tree "
                    "because it requires special input mocking.")

    if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # Use the current platform's device (GPU)
    device = torch.device(current_platform.device_type)

    # Setup test parameters
    batch_size = 2
    seq_len_1 = 5
    seq_len_2 = 3
    total_tokens = seq_len_1 + seq_len_2
    vocab_size = 100
    seq_lens = [seq_len_1, seq_len_2]

    # Create proposer first so we can use its actual hidden_size.
    # NOTE(review): `method` is parametrized, but the proposer is always built
    # with "eagle", so the "eagle3" cases repeat the "eagle" setup. Confirm
    # whether eagle3 coverage is intended here.
    proposer = _create_proposer("eagle", num_speculative_tokens)
    # Get the hidden_size from the proposer to ensure consistency
    hidden_size = proposer.hidden_size

    # Helper to create deterministic logits that will produce specific tokens
    def create_deterministic_logits(token_ids):
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        for i, token_id in enumerate(token_ids):
            logits[i, token_id] = 100.0
        return logits

    # We mock a model that returns deterministic logits
    # Sequence 1: 42, 43, 44, ...
    # Sequence 2: 60, 61, 62, ...
    base_token_ids = [42, 60]

    # Skip loading the model and replace it with a mock directly
    # Create the mock model with deterministic outputs
    model_mock = mock.MagicMock()

    # Setup for model forward calls
    forward_returns = []
    for i in range(num_speculative_tokens):
        if i == 0:
            # First call uses all tokens
            h_logits = torch.zeros(total_tokens, hidden_size, device=device)
            h_states = torch.zeros(total_tokens, hidden_size, device=device)
        else:
            # Subsequent calls use batch_size tokens
            h_logits = torch.zeros(batch_size, hidden_size, device=device)
            h_states = torch.zeros(batch_size, hidden_size, device=device)
        forward_returns.append((h_logits, h_states))

    # For single token case, we only need the first item;
    # for multi-token, we need the sequence
    if num_speculative_tokens == 1:
        model_mock.return_value = forward_returns[0]
    else:
        model_mock.side_effect = forward_returns

    # Setup for compute_logits calls
    logits_returns = []
    for i in range(num_speculative_tokens):
        # For each call, increment the base token IDs
        current_tokens = [base_id + i for base_id in base_token_ids]
        logits_returns.append(create_deterministic_logits(current_tokens))

    if num_speculative_tokens == 1:
        model_mock.compute_logits.return_value = logits_returns[0]
    else:
        model_mock.compute_logits.side_effect = logits_returns

    # Assign the mock to the proposer
    proposer.model = model_mock

    # Assign draft attn_layer_names since load_model is not invoked
    proposer.attn_layer_names = ["layer.0"]

    # Create input tensors
    batch_spec = BatchSpec(
        seq_lens=seq_lens,
        query_lens=seq_lens,
    )

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    target_token_ids = torch.randint(0,
                                     vocab_size, (total_tokens, ),
                                     device=device)
    target_positions = torch.cat([
        torch.arange(seq_len_1, device=device),
        torch.arange(seq_len_2, device=device)
    ])
    target_hidden_states = torch.randn(total_tokens,
                                       hidden_size,
                                       device=device)
    next_token_ids = torch.randint(0,
                                   vocab_size, (batch_size, ),
                                   dtype=torch.int32,
                                   device=device)
    sampling_metadata = mock.MagicMock()

    if attn_backend == "FLASH_ATTN_VLLM_V1":
        attn_metadata_builder_cls, _ = get_attention_backend(
            _Backend.FLASH_ATTN_VLLM_V1)
    elif attn_backend == "TRITON_ATTN_VLLM_V1":
        attn_metadata_builder_cls, _ = get_attention_backend(
            _Backend.TRITON_ATTN_VLLM_V1)
    elif attn_backend == "TREE_ATTN":
        attn_metadata_builder_cls, _ = get_attention_backend(
            _Backend.TREE_ATTN)
    else:
        raise ValueError(f"Unsupported attention backend: {attn_backend}")

    attn_metadata_builder = attn_metadata_builder_cls(
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
        layer_names=proposer.attn_layer_names,
        vllm_config=proposer.vllm_config,
        device=device,
    )

    # Mock runner for attention metadata building
    proposer.runner = mock.MagicMock()
    proposer.runner.attn_groups.append([mock.MagicMock()])
    proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
        attn_metadata_builder
    proposer._get_attention_metadata_builder = mock.MagicMock(
        return_value=attn_metadata_builder)

    result = proposer.propose(target_token_ids=target_token_ids,
                              target_positions=target_positions,
                              target_hidden_states=target_hidden_states,
                              next_token_ids=next_token_ids,
                              last_token_indices=None,
                              common_attn_metadata=common_attn_metadata,
                              sampling_metadata=sampling_metadata)

    assert result.shape == (batch_size, num_speculative_tokens)

    # Expected draft tokens are consecutive integers starting at each
    # request's base token, e.g. for num_speculative_tokens=1: [[42], [60]];
    # for num_speculative_tokens=3: [[42, 43, 44], [60, 61, 62]]
    base_ids_tensor = torch.tensor(base_token_ids,
                                   dtype=torch.int64,
                                   device=device)
    expected_tokens = base_ids_tensor[:, None] + torch.arange(
        num_speculative_tokens, dtype=torch.int64, device=device)

    # Verify all tokens match our expectations
    assert torch.equal(result, expected_tokens)


@pytest.mark.parametrize(
    "spec_token_tree",
    [
        [(0, )],  # A single token
        [(0, ), (0, 0), (0, 0, 0)],  # Chain
        [(0, ), (1, ), (2, )],  # Parallel
        [(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0),
         (2, 1)],  # Tree
    ])
def test_propose_tree(spec_token_tree):
    """Check tree-structured drafting in ``EagleProposer.propose``.

    The draft model is mocked level by level. At each tree level, the logits
    rank ``num_children`` consecutive token ids highest (in decreasing order)
    for every parent. The flattened draft output should therefore be
    consecutive integers starting at each request's base token.
    """
    # Get GPU device.
    device = torch.device(current_platform.device_type)

    # Setup test parameters.
    batch_size = 2
    seq_len_1 = 5
    seq_len_2 = 3
    total_tokens = seq_len_1 + seq_len_2
    vocab_size = 100
    seq_lens = [seq_len_1, seq_len_2]
    num_speculative_tokens = len(spec_token_tree)

    # Create proposer first so we can use its actual hidden_size.
    proposer = _create_proposer("eagle",
                                num_speculative_tokens,
                                speculative_token_tree=spec_token_tree)
    # Get the hidden_size from the proposer to ensure consistency.
    hidden_size = proposer.hidden_size

    # Helper to create deterministic logits that will produce specific tokens
    def create_deterministic_logits(token_ids, k: int):
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        for i, token_id in enumerate(token_ids):
            # Assign decreasing values to the k, consecutive, tokens.
            for j in range(k):
                logits[i, token_id + j] = 100.0 - j
        return logits

    # Mock a model that returns deterministic logits.
    base_token_ids = torch.tensor([42, 60], dtype=torch.int64, device=device)

    # Skip loading the model and replace it with a mock that returns
    # deterministic outputs.
    model_mock = mock.MagicMock()

    # Mock the model forward calls: the first call covers all target tokens,
    # then one call per tree level sized by the cumulative drafts so far.
    forward_returns = [(torch.zeros(total_tokens, hidden_size, device=device),
                        torch.zeros(total_tokens, hidden_size, device=device))]
    for cu_num_drafts in proposer.cu_drafts_per_level:
        h_logits = torch.zeros(batch_size * cu_num_drafts,
                               hidden_size,
                               device=device)
        h_states = torch.zeros(batch_size * cu_num_drafts,
                               hidden_size,
                               device=device)
        forward_returns.append((h_logits, h_states))
    model_mock.side_effect = forward_returns

    # Mock the compute_logits calls.
    cu_num_drafts_tensor = torch.tensor([0] + proposer.cu_drafts_per_level,
                                        dtype=torch.int32,
                                        device=device)
    logits_returns = []
    for level, num_children in enumerate(proposer.child_drafts_per_level):
        token_ids = base_token_ids + cu_num_drafts_tensor[level]
        level_num_drafts = cu_num_drafts_tensor[
            level + 1] - cu_num_drafts_tensor[level]
        level_logits = []
        for i in range(level_num_drafts // num_children):
            level_logits.append(
                create_deterministic_logits(token_ids + i * num_children,
                                            num_children))
        logits_returns.append(torch.stack(level_logits, dim=1))
    model_mock.compute_logits.side_effect = logits_returns

    # Assign the mock to the proposer
    proposer.model = model_mock

    # Assign draft attn_layer_names since load_model is not invoked
    proposer.attn_layer_names = ["layer.0"]

    # Get the tree attention metadata builder.
    attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN)
    attn_metadata_builder = attn_metadata_builder_cls(
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
        layer_names=proposer.attn_layer_names,
        vllm_config=proposer.vllm_config,
        device=device,
    )

    # Mock runner for attention metadata building.
    proposer.runner = mock.MagicMock()
    proposer.runner.attn_groups.append([mock.MagicMock()])
    proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
        attn_metadata_builder
    proposer._get_attention_metadata_builder = mock.MagicMock(
        return_value=attn_metadata_builder)

    # Setup inputs for the proposer.
    target_token_ids = torch.randint(0,
                                     vocab_size, (total_tokens, ),
                                     device=device)
    target_positions = torch.cat([
        torch.arange(seq_len_1, device=device),
        torch.arange(seq_len_2, device=device)
    ])
    target_hidden_states = torch.randn(total_tokens,
                                       hidden_size,
                                       device=device)
    next_token_ids = torch.randint(0,
                                   vocab_size, (batch_size, ),
                                   dtype=torch.int32,
                                   device=device)
    batch_spec = BatchSpec(
        seq_lens=seq_lens,
        query_lens=seq_lens,
    )
    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )
    sampling_metadata = mock.MagicMock()

    # Propose draft tokens.
    result = proposer.propose(target_token_ids=target_token_ids,
                              target_positions=target_positions,
                              target_hidden_states=target_hidden_states,
                              next_token_ids=next_token_ids,
                              last_token_indices=None,
                              common_attn_metadata=common_attn_metadata,
                              sampling_metadata=sampling_metadata)
    assert result.shape == (batch_size, num_speculative_tokens)

    # The tokens are expected to be consecutive integers starting
    # from the base token IDs.
    expected_tokens = base_token_ids[:, None] + torch.arange(
        num_speculative_tokens, dtype=torch.int64, device=device)

    # Verify that the draft tokens match our expectations.
    assert torch.equal(result, expected_tokens)