test_eagle.py 26.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5

from unittest import mock

zhuwenwen's avatar
zhuwenwen committed
6
import os
7
8
9
import pytest
import torch

10
from tests.utils import get_attn_backend_list_based_on_platform
11
12
13
14
from tests.v1.attention.utils import (
    BatchSpec,
    create_common_attn_metadata,
    create_standard_kv_cache_spec,
15
    try_get_attention_backend,
16
17
)
from vllm.config import (
18
    AttentionConfig,
19
20
21
22
23
24
25
26
    CacheConfig,
    DeviceConfig,
    ModelConfig,
    ParallelConfig,
    SchedulerConfig,
    SpeculativeConfig,
    VllmConfig,
)
27
from vllm.config.load import LoadConfig
28
from vllm.model_executor.models.llama import LlamaForCausalLM
29
from vllm.platforms import current_platform
30
from vllm.v1.attention.backends.registry import AttentionBackendEnum
31
from vllm.v1.spec_decode.eagle import EagleProposer
32

33
34
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
zhuwenwen's avatar
zhuwenwen committed
35
from ...utils import models_path_prefix
36

zhuwenwen's avatar
zhuwenwen committed
37
38
39
# Checkpoint paths used by the tests below, each joined onto the shared
# `models_path_prefix`: the target model and its EAGLE / EAGLE3 draft models.
model_dir = os.path.join(models_path_prefix, "meta-llama/Llama-3.1-8B-Instruct")
eagle_dir = os.path.join(models_path_prefix, "yuhuili/EAGLE-LLaMA3.1-Instruct-8B")
eagle3_dir = os.path.join(models_path_prefix, "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B")
40
41


42
43
44
def _create_proposer(
    method: str,
    num_speculative_tokens: int,
    attention_backend: str | None = None,
    speculative_token_tree: list[tuple[int, ...]] | None = None,
) -> EagleProposer:
    """Build an ``EagleProposer`` wired to the test target and draft models.

    Args:
        method: Speculative method name. ``"eagle"`` selects ``eagle_dir`` as
            the draft model; any other value selects ``eagle3_dir``.
        num_speculative_tokens: Number of speculative tokens to configure.
        attention_backend: Backend name forwarded to ``AttentionConfig``.
        speculative_token_tree: Optional token tree. When given, its length
            must equal ``num_speculative_tokens``; it is passed to the
            speculative config as its ``str()`` representation.

    Returns:
        A proposer created on ``current_platform.device_type``.
    """
    target_model_config = ModelConfig(
        model=model_dir, runner="generate", max_model_len=100
    )

    # Pick the draft checkpoint that matches the requested method.
    if method == "eagle":
        draft_dir = eagle_dir
    else:
        draft_dir = eagle3_dir

    # The speculative config takes the tree in string form.
    if speculative_token_tree is None:
        tree_as_str = None
    else:
        assert num_speculative_tokens == len(speculative_token_tree)
        tree_as_str = str(speculative_token_tree)

    spec_config = SpeculativeConfig(
        target_model_config=target_model_config,
        target_parallel_config=ParallelConfig(),
        model=draft_dir,
        method=method,
        num_speculative_tokens=num_speculative_tokens,
        speculative_token_tree=tree_as_str,
    )

    full_config = VllmConfig(
        model_config=target_model_config,
        cache_config=CacheConfig(),
        speculative_config=spec_config,
        device_config=DeviceConfig(device=current_platform.device_type),
        parallel_config=ParallelConfig(),
        load_config=LoadConfig(),
        scheduler_config=SchedulerConfig(
            max_model_len=target_model_config.max_model_len,
            is_encoder_decoder=target_model_config.is_encoder_decoder,
        ),
        attention_config=AttentionConfig(backend=attention_backend),
    )

    return EagleProposer(vllm_config=full_config, device=current_platform.device_type)
82
83


84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def test_prepare_next_token_ids():
    """
    Test for prepare_next_token_ids_cpu and prepare_next_token_ids_padded.
    Each will produce a device tensor of next_token_ids, taking as input
    either the GPU tensor of sampled_token_ids with -1 for rejected tokens,
    or the CPU python list[list[int]] with the rejected tokens removed.
    """
    device = torch.device(current_platform.device_type)

    num_requests = 4
    num_speculative_tokens = 4
    tokens_per_request = num_speculative_tokens + 1
    batch_spec = BatchSpec(
        seq_lens=[tokens_per_request] * num_requests,
        query_lens=[tokens_per_request] * num_requests,
    )

    req_ids = [f"req_{i + 1}" for i in range(num_requests)]
    mock_input_batch = mock.MagicMock(spec=InputBatch)
    mock_input_batch.req_ids = req_ids
    mock_input_batch.num_reqs = num_requests
    mock_input_batch.vocab_size = 100

    mock_num_scheduled_tokens = dict.fromkeys(req_ids, 0)
    mock_requests = {}
    for req_number, req_id in enumerate(req_ids, start=1):
        mock_request = mock.MagicMock(spec=CachedRequestState)
        # Each request will have a backup next token id of 10, 20, 30, 40
        mock_request.get_token_id.return_value = req_number * 10
        mock_request.num_computed_tokens = 0
        mock_requests[req_id] = mock_request

    # Explicitly discard the last request. The flags are kept as a plain
    # Python list so that the CPU-side inputs below can be derived without
    # indexing the device tensor element by element.
    discarded_flags = [False, False, False, True]
    discarded_req_mask = torch.tensor(
        discarded_flags, dtype=torch.bool, device=device
    )
    sampled_token_ids = [
        [0, 1, -1, -1, -1],  # 1 accepted, 3 rejected, "1" sampled
        [0, 1, 2, 3, 4],  # all accepted, "4" sampled
        [-1, -1, -1, -1, -1],  # sampling skipped, use backup token "30"
        [0, 1, 2, -1, -1],  # explicitly discarded, sampling should be ignored
    ]
    sampled_token_ids_tensor = torch.tensor(
        sampled_token_ids, dtype=torch.int32, device=device
    )
    # CPU variant: rejected (-1) tokens are removed, and a discarded request
    # contributes an empty list.
    sampled_token_ids_cpu = [
        [] if discarded else [tok for tok in seq if tok != -1]
        for seq, discarded in zip(sampled_token_ids, discarded_flags)
    ]

    expected_next_token_ids_cpu = [1, 4, 30, 40]
    expected_next_token_ids_tensor = torch.tensor(
        expected_next_token_ids_cpu, dtype=torch.int32, device=device
    )

    proposer = _create_proposer("eagle", num_speculative_tokens)

    next_token_ids_from_cpu = proposer.prepare_next_token_ids_cpu(
        sampled_token_ids_cpu,
        mock_requests,
        mock_input_batch,
        mock_num_scheduled_tokens,
    )

    assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor)

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    expected_valid_sampled_tokens_count = torch.tensor(
        [2, 5, 0, 0], dtype=torch.int32, device=device
    )

    next_token_ids_from_padded, valid_sampled_tokens_count = (
        proposer.prepare_next_token_ids_padded(
            common_attn_metadata,
            sampled_token_ids_tensor,
            mock_requests,
            mock_input_batch,
            discarded_req_mask,
        )
    )

    assert torch.equal(next_token_ids_from_padded, expected_next_token_ids_tensor)
    assert torch.equal(valid_sampled_tokens_count, expected_valid_sampled_tokens_count)
171
172


173
174
175
176
177
178
179
180
181
182
def test_prepare_inputs():
    """
    cu_target_query_lens: [0, a, a + b, a + b + c]
    num_rejected_tokens: [n1, n2, n3]
    num_tokens_per_req: [a - n1, b - n2, c - n3]
    cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
    token_indices: [0, 1, ..., a - n1 - 1,
                    a, a + 1, ..., a + b - n2 - 1,
                    a + b, a + b + 1, ..., a + b + c - n3 - 1]
    """
    device = torch.device(current_platform.device_type)

    # q1 = 4, q2 = 7, q3 = 5
    # n1 = 1, n2 = 3, n3 = 2
    query_lens = [4, 7, 5]
    batch_spec = BatchSpec(
        seq_lens=[4, 7, 5],
        query_lens=query_lens,
    )
    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # If there are `k` sampled tokens, then `k-1` tokens are draft tokens
    # from the previous iteration, and the last token is the bonus token sampled
    # from the base model.
    num_draft_tokens = [3, 6, 4]  # one less than query_lens
    ACCEPT_TOKEN = 0
    BONUS_TOKEN = 1
    REJECT_TOKEN = -1
    # Per request: accepted drafts, then rejected drafts, then the bonus token.
    accepted_counts = [2, 3, 2]
    rejected_counts = [1, 3, 2]
    raw_samples = [
        [ACCEPT_TOKEN] * n_accept + [REJECT_TOKEN] * n_reject + [BONUS_TOKEN]
        for n_accept, n_reject in zip(accepted_counts, rejected_counts)
    ]
    # Only the non-rejected tokens are handed to prepare_inputs.
    sampled_token_ids = [
        [tok for tok in row if tok != REJECT_TOKEN] for row in raw_samples
    ]

    # Expected calculations:
    # query_len_per_req = [4, 7, 5]
    # num_tokens_per_req = [3, 4, 3]  (after subtracting rejected tokens)
    # Expected cumulative counts: [0, 3, 7, 10]
    expected_cu_num_tokens = torch.tensor(
        [0, 3, 7, 10], dtype=torch.int32, device=device
    )

    # Expected token indices (mapped from original positions):
    # First request: indices 0, 1, 2      (keeping first 3 from positions 0-3)
    # Second request: indices 4, 5, 6, 7  (keeping first 4 from positions 4-10)
    # Third request: indices 11, 12, 13   (keeping first 3 from positions 11-15)
    kept_first = [0, 1, 2]  # 3 tokens (4-1)
    kept_second = [4, 5, 6, 7]  # 4 tokens (7-3)
    kept_third = [11, 12, 13]  # 3 tokens (5-2)
    expected_token_indices = torch.tensor(
        kept_first + kept_second + kept_third,
        dtype=torch.int32,
        device=device,
    )

    proposer = _create_proposer("eagle", 1)

    updated_metadata, token_indices = proposer.prepare_inputs(
        common_attn_metadata, sampled_token_ids, num_draft_tokens
    )

    assert torch.equal(updated_metadata.query_start_loc, expected_cu_num_tokens)
    assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
    assert torch.equal(token_indices, expected_token_indices)


263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
def test_prepare_inputs_padded():
    """
    Input scenario is 3 requests with num_speculative_tokens == 2 and:
    - Request 1: query_len = 3, rejected = 1
    - Request 2: query_len = 3, rejected = 0
    - Request 3: query_len = 3, rejected = 2

    Expected outputs:
    token_indices_to_sample: [1, 5, 6]
    Reason: After accounting for rejections, these are the valid token positions
            from the original indices to sample from.
    """
    device = torch.device(current_platform.device_type)

    def int32_tensor(values):
        # All expected/input tensors in this test are int32 on the test device.
        return torch.tensor(values, dtype=torch.int32, device=device)

    expected_token_indices_to_sample = int32_tensor([1, 5, 6])

    num_speculative_tokens = 2
    batch_spec = BatchSpec(
        seq_lens=[3, 3, 3],
        query_lens=[3, 3, 3],
    )
    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # NOTE(review): the original comment here reads "Needed for
    # cu_num_draft_tokens, which is expected to be [3, 6, 9]" - confirm which
    # value it refers to.
    expected_query_start_loc = int32_tensor([0, 3, 6, 9])
    # Dummy metadata carrying `num_speculative_tokens` draft tokens per request.
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        draft_token_ids=[[0] * num_speculative_tokens] * 3,
        device=device,
    )

    # num_rejected_tokens = [1, 0, 2]
    # num_draft_tokens = [2, 2, 2]
    # valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens
    valid_sampled_tokens_count = int32_tensor([2, 3, 1])

    proposer = _create_proposer("eagle", num_speculative_tokens)

    padded_outputs = proposer.prepare_inputs_padded(
        common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
    )
    output_metadata, token_indices_to_sample, num_rejected_tokens_gpu = padded_outputs

    # Verify num_rejected_tokens_gpu is calculated correctly
    expected_num_rejected = int32_tensor([1, 0, 2])
    assert torch.equal(num_rejected_tokens_gpu, expected_num_rejected)

    assert output_metadata.max_query_len == 3
    assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc)
    assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
325
326


327
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("pp_size", [1, 2])
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
@pytest.mark.parametrize("use_distinct_lm_head", [True, False])
@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group")
@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config")
@mock.patch("vllm.v1.spec_decode.eagle.get_model")
def test_load_model(
    mock_get_model,
    mock_get_layers,
    mock_get_pp_group,
    method,
    attn_backend,
    pp_size,
    use_distinct_embed_tokens,
    use_distinct_lm_head,
    monkeypatch,
):
    """Check how ``EagleProposer.load_model`` wires lm_head and embed_tokens.

    ``get_model``, ``get_layers_from_vllm_config`` and ``get_pp_group`` are
    patched in the ``vllm.v1.spec_decode.eagle`` module, so no real weights
    are loaded. After ``load_model`` the test asserts that:

    - the draft ``lm_head`` is the target's object unless the draft model
      reports its own lm_head;
    - the draft ``embed_tokens`` is the target's object unless the draft
      reports its own embed_tokens or ``pp_size > 1``.
    """
    if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
        pytest.skip(
            "TRITON_ATTN does not support "
            "multi-token eagle spec decode on current platform"
        )

    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # Setup draft model mock
    mock_model = mock.MagicMock()
    mock_model.model = mock.MagicMock()
    mock_model.has_own_embed_tokens = use_distinct_embed_tokens
    if use_distinct_embed_tokens:
        mock_model.model.embed_tokens = mock.MagicMock()
    mock_model.has_own_lm_head = use_distinct_lm_head
    if use_distinct_lm_head:
        mock_model.lm_head = mock.MagicMock()

    mock_get_model.return_value = mock_model

    # Setup mocks for attention layers
    target_attn_layers = {
        "target_attn_1": mock.MagicMock(),
        "target_attn_2": mock.MagicMock(),
    }
    target_indx_layers: dict[str, mock.MagicMock] = {}
    # Draft model has one extra attention layer compared to target model
    all_attn_layers = {**target_attn_layers, "draft_extra_attn": mock.MagicMock()}
    all_indx_layers: dict[str, mock.MagicMock] = {}

    # Make mock_get_layers return different values for each call
    mock_get_layers.side_effect = [
        target_attn_layers,
        target_indx_layers,
        all_attn_layers,
        all_indx_layers,
    ]

    # Setup mock for pp group to return the appropriate value for world size
    mock_pp_group = mock.MagicMock()
    mock_pp_group.world_size = pp_size
    mock_get_pp_group.return_value = mock_pp_group

    # Set up the target model mock with a custom class so that
    # isinstance() checks match the expected type.
    class _TargetModelStub(LlamaForCausalLM):
        model: mock.MagicMock
        lm_head: mock.MagicMock

    target_model = mock.create_autospec(_TargetModelStub, instance=True)
    target_model.model = mock.MagicMock()
    target_model.lm_head = mock.MagicMock()
    target_model.model.embed_tokens = mock.MagicMock()

    from vllm.model_executor.models import SupportsMultiModal

    # Sanity check on the stub itself: it must not look multimodal.
    assert not isinstance(target_model, SupportsMultiModal)

    # Create proposer using the helper function
    proposer = _create_proposer(
        method, num_speculative_tokens=8, attention_backend=attn_backend
    )

    # Call the method under test
    proposer.load_model(target_model)

    # Verify common interactions
    mock_get_model.assert_called_once()

    # Verify that the lm head is set correctly
    if use_distinct_lm_head:
        assert proposer.model.lm_head is not target_model.lm_head
    else:
        assert proposer.model.lm_head is target_model.lm_head

    # Verify that the embed tokens are set correctly
    # If pp_size is > 1, the embed tokens should be distinct
    if pp_size > 1 or use_distinct_embed_tokens:
        assert proposer.model.model.embed_tokens is not target_model.model.embed_tokens
    else:
        assert proposer.model.model.embed_tokens is target_model.model.embed_tokens
431
432


433
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
    """Check that ``propose`` returns the tokens dictated by mocked logits.

    The draft model is replaced by a mock whose ``compute_logits`` puts all
    of the probability mass on ``42 + step`` for the first sequence and
    ``60 + step`` for the second, so the expected result is
    ``[[42, 43, ...], [60, 61, ...]]`` with ``num_speculative_tokens``
    columns.
    """
    if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
        pytest.skip(
            "TRITON_ATTN does not support "
            "multi-token eagle spec decode on current platform"
        )

    if attn_backend == "TREE_ATTN":
        pytest.skip(
            "TREE_ATTN is tested separately in test_propose_tree "
            "because it requires special input mocking."
        )

    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # Use GPU device
    device = torch.device(current_platform.device_type)

    # Setup test parameters
    batch_size = 2
    seq_len_1 = 5
    seq_len_2 = 3
    total_tokens = seq_len_1 + seq_len_2
    vocab_size = 100
    seq_lens = [seq_len_1, seq_len_2]

    # Create proposer first so we can use its actual hidden_size
    # NOTE(review): `method` is parametrized but the proposer is always built
    # with "eagle", so both parametrizations run the same configuration -
    # confirm whether this is intended.
    proposer = _create_proposer(
        "eagle", num_speculative_tokens, attention_backend=attn_backend
    )
    # Get the hidden_size from the proposer to ensure consistency
    hidden_size = proposer.hidden_size

    # Helper to create deterministic logits that will produce specific tokens
    def create_deterministic_logits(token_ids):
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        for i, token_id in enumerate(token_ids):
            logits[i, token_id] = 100.0
        return logits

    def zero_forward_output(num_rows):
        # One mocked forward result: a pair of (num_rows, hidden_size) zeros.
        return (
            torch.zeros(num_rows, hidden_size, device=device),
            torch.zeros(num_rows, hidden_size, device=device),
        )

    # We mock a model that returns deterministic logits
    # Sequence 1: 42, 43, 44, ...
    # Sequence 2: 60, 61, 62, ...
    base_token_ids = [42, 60]

    # Skip loading the model and replace it with a mock directly
    model_mock = mock.MagicMock()

    # Setup for model forward calls: the first call uses all tokens,
    # subsequent calls use batch_size tokens.
    forward_returns = [
        zero_forward_output(total_tokens if step == 0 else batch_size)
        for step in range(num_speculative_tokens)
    ]

    # Setup for compute_logits calls: each call increments the base token IDs.
    logits_returns = [
        create_deterministic_logits([base_id + step for base_id in base_token_ids])
        for step in range(num_speculative_tokens)
    ]

    # For single token case, we only need the first item;
    # for multi-token, we need the sequence
    if num_speculative_tokens == 1:
        model_mock.return_value = forward_returns[0]
        model_mock.compute_logits.return_value = logits_returns[0]
    else:
        model_mock.side_effect = forward_returns
        model_mock.compute_logits.side_effect = logits_returns

    # Assign the mock to the proposer
    proposer.model = model_mock

    # Assign draft attn_layer_names since load_model is not invoked
    proposer.attn_layer_names = ["layer.0"]

    # Create input tensors
    batch_spec = BatchSpec(
        seq_lens=seq_lens,
        query_lens=seq_lens,
    )

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)
    target_positions = torch.cat(
        [torch.arange(seq_len_1, device=device), torch.arange(seq_len_2, device=device)]
    )
    target_hidden_states = torch.randn(total_tokens, hidden_size, device=device)
    next_token_ids = torch.randint(
        0, vocab_size, (batch_size,), dtype=torch.int32, device=device
    )
    sampling_metadata = mock.MagicMock()

    # Resolve the metadata builder for the backend under test. TREE_ATTN is
    # skipped above, so it is not listed here.
    supported_backends = ("FLASH_ATTN", "TRITON_ATTN", "ROCM_AITER_FA")
    if attn_backend not in supported_backends:
        raise ValueError(f"Unsupported attention backend: {attn_backend}")
    attn_metadata_builder_cls, _ = try_get_attention_backend(
        getattr(AttentionBackendEnum, attn_backend)
    )

    attn_metadata_builder = attn_metadata_builder_cls(
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
        layer_names=proposer.attn_layer_names,
        vllm_config=proposer.vllm_config,
        device=device,
    )

    # Mock runner for attention metadata building
    proposer.runner = mock.MagicMock()
    proposer.runner.attn_groups.append([mock.MagicMock()])
    proposer.runner.attn_groups[0][
        0
    ].get_metadata_builder.return_value = attn_metadata_builder
    proposer._get_attention_metadata_builder = mock.MagicMock(
        return_value=attn_metadata_builder
    )

    result = proposer.propose(
        target_token_ids=target_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        next_token_ids=next_token_ids,
        last_token_indices=None,
        common_attn_metadata=common_attn_metadata,
        sampling_metadata=sampling_metadata,
    )

    assert result.shape == (batch_size, num_speculative_tokens)

    # Create expected tokens based on our token pattern, e.g.
    # num_speculative_tokens=1 -> [[42], [60]]
    # num_speculative_tokens=3 -> [[42, 43, 44], [60, 61, 62]]
    expected_tokens = torch.tensor(
        [
            [base_id + step for step in range(num_speculative_tokens)]
            for base_id in base_token_ids
        ],
        dtype=torch.int64,
        device=device,
    )

    # Verify all tokens match our expectations
    assert torch.equal(result, expected_tokens)
613
614
615
616
617


@pytest.mark.parametrize(
    "spec_token_tree",
    [
        [(0,)],  # A single token
        [(0,), (0, 0), (0, 0, 0)],  # Chain
        [(0,), (1,), (2,)],  # Parallel
        [(0,), (1,), (2,), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)],  # Tree
    ],
)
def test_propose_tree(spec_token_tree):
    """Check tree-based ``propose`` against mocked, deterministic logits.

    The mocked logits are arranged so that the proposed tokens are expected
    to be consecutive integers starting at 42 for the first sequence and 60
    for the second, one per entry of ``spec_token_tree``.
    """
    # Get GPU device.
    device = torch.device(current_platform.device_type)

    # Setup test parameters.
    batch_size = 2
    seq_lens = [5, 3]
    total_tokens = sum(seq_lens)
    vocab_size = 100
    num_speculative_tokens = len(spec_token_tree)

    # Create proposer first so we can use its actual hidden_size.
    proposer = _create_proposer(
        "eagle",
        num_speculative_tokens,
        speculative_token_tree=spec_token_tree,
    )
    hidden_size = proposer.hidden_size

    # Helper to create deterministic logits that will produce specific tokens
    def create_deterministic_logits(token_ids, k: int):
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        for row, first_token in enumerate(token_ids):
            # Assign decreasing values to the k, consecutive, tokens.
            for offset in range(k):
                logits[row, first_token + offset] = 100.0 - offset
        return logits

    def zero_forward_output(num_rows):
        # One mocked forward result: a pair of (num_rows, hidden_size) zeros.
        return (
            torch.zeros(num_rows, hidden_size, device=device),
            torch.zeros(num_rows, hidden_size, device=device),
        )

    # Mock a model that returns deterministic logits.
    base_token_ids = torch.tensor([42, 60], dtype=torch.int64, device=device)

    # Skip loading the model and replace it with a mock that returns
    # deterministic outputs.
    model_mock = mock.MagicMock()

    # Mock the model forward calls: one over all tokens, followed by one per
    # entry of cu_drafts_per_level.
    forward_returns = [zero_forward_output(total_tokens)]
    forward_returns.extend(
        zero_forward_output(batch_size * cu_num_drafts)
        for cu_num_drafts in proposer.cu_drafts_per_level
    )
    model_mock.side_effect = forward_returns

    # Mock the compute_logits calls.
    cu_num_drafts_tensor = torch.tensor(
        [0] + proposer.cu_drafts_per_level, dtype=torch.int32, device=device
    )
    logits_returns = []
    for level, num_children in enumerate(proposer.child_drafts_per_level):
        level_start = cu_num_drafts_tensor[level]
        level_num_drafts = cu_num_drafts_tensor[level + 1] - level_start
        token_ids = base_token_ids + level_start
        level_logits = [
            create_deterministic_logits(token_ids + i * num_children, num_children)
            for i in range(level_num_drafts // num_children)
        ]
        logits_returns.append(torch.stack(level_logits, dim=1))
    model_mock.compute_logits.side_effect = logits_returns

    # Assign the mock to the proposer
    proposer.model = model_mock

    # Assign draft attn_layer_names since load_model is not invoked
    proposer.attn_layer_names = ["layer.0"]

    # Get the tree attention metadata builder.
    attn_metadata_builder_cls, _ = try_get_attention_backend(
        AttentionBackendEnum.TREE_ATTN
    )
    attn_metadata_builder = attn_metadata_builder_cls(
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
        layer_names=proposer.attn_layer_names,
        vllm_config=proposer.vllm_config,
        device=device,
    )

    # Mock runner for attention metadata building.
    proposer.runner = mock.MagicMock()
    proposer.runner.attn_groups.append([mock.MagicMock()])
    attn_group = proposer.runner.attn_groups[0][0]
    attn_group.metadata_builders = [attn_metadata_builder]
    attn_group.get_metadata_builder.return_value = attn_metadata_builder
    proposer._get_attention_metadata_builder = mock.MagicMock(
        return_value=attn_metadata_builder
    )

    # Setup inputs for the proposer.
    target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)
    target_positions = torch.cat(
        [torch.arange(length, device=device) for length in seq_lens]
    )
    target_hidden_states = torch.randn(total_tokens, hidden_size, device=device)
    next_token_ids = torch.randint(
        0, vocab_size, (batch_size,), dtype=torch.int32, device=device
    )
    batch_spec = BatchSpec(
        seq_lens=seq_lens,
        query_lens=seq_lens,
    )
    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )
    sampling_metadata = mock.MagicMock()

    # Propose draft tokens.
    result = proposer.propose(
        target_token_ids=target_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        next_token_ids=next_token_ids,
        last_token_indices=None,
        common_attn_metadata=common_attn_metadata,
        sampling_metadata=sampling_metadata,
    )
    assert result.shape == (batch_size, num_speculative_tokens)

    # The tokens are expected to be consecutive integers starting
    # from the base token IDs.
    token_offsets = torch.arange(
        num_speculative_tokens, dtype=torch.int64, device=device
    )
    expected_tokens = base_token_ids[:, None] + token_offsets

    # Verify that the draft tokens match our expectations.
    assert torch.equal(result, expected_tokens)