test_eagle.py 25.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8

from unittest import mock

import pytest
import torch

9
from tests.utils import get_attn_backend_list_based_on_platform
10
11
12
13
from tests.v1.attention.utils import (
    BatchSpec,
    create_common_attn_metadata,
    create_standard_kv_cache_spec,
14
    try_get_attention_backend,
15
)
16
from vllm.attention.backends.registry import AttentionBackendEnum
17
18
19
20
21
22
23
24
25
from vllm.config import (
    CacheConfig,
    DeviceConfig,
    ModelConfig,
    ParallelConfig,
    SchedulerConfig,
    SpeculativeConfig,
    VllmConfig,
)
26
from vllm.config.load import LoadConfig
27
from vllm.model_executor.models.llama import LlamaForCausalLM
28
from vllm.platforms import current_platform
29
from vllm.v1.spec_decode.eagle import EagleProposer
30
31
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
32
33
34
35
36
37

model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"


38
39
40
def _create_proposer(
    method: str,
    num_speculative_tokens: int,
    speculative_token_tree: list[tuple[int, ...]] | None = None,
) -> EagleProposer:
    """Build an ``EagleProposer`` for the Llama-3.1-8B-Instruct target model.

    Args:
        method: Speculative method name. ``"eagle"`` selects the EAGLE draft
            checkpoint; any other value selects the EAGLE3 checkpoint.
        num_speculative_tokens: Number of draft tokens to propose per step.
        speculative_token_tree: Optional token tree. When given, its length
            must equal ``num_speculative_tokens`` and it is handed to
            ``SpeculativeConfig`` in its ``str()`` form.

    Returns:
        An ``EagleProposer`` placed on ``current_platform.device_type``.
    """
    target_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)

    if speculative_token_tree is None:
        tree_repr = None
    else:
        assert num_speculative_tokens == len(speculative_token_tree)
        tree_repr = str(speculative_token_tree)

    spec_config = SpeculativeConfig(
        target_model_config=target_config,
        target_parallel_config=ParallelConfig(),
        model=eagle_dir if method == "eagle" else eagle3_dir,
        method=method,
        num_speculative_tokens=num_speculative_tokens,
        speculative_token_tree=tree_repr,
    )

    full_config = VllmConfig(
        model_config=target_config,
        cache_config=CacheConfig(),
        speculative_config=spec_config,
        device_config=DeviceConfig(device=current_platform.device_type),
        parallel_config=ParallelConfig(),
        load_config=LoadConfig(),
        scheduler_config=SchedulerConfig(),
    )

    return EagleProposer(vllm_config=full_config, device=current_platform.device_type)
73
74


75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
def test_prepare_next_token_ids():
    """
    Test for prepare_next_token_ids_cpu and prepare_next_token_ids_padded.

    Each will produce a device tensor of next_token_ids, taking as input
    either the GPU tensor of sampled_token_ids with -1 for rejected tokens,
    or the CPU python list[list[int]] with the rejected tokens removed.
    Requests whose sampling was skipped or that are explicitly discarded are
    expected to fall back to the request's backup token (10, 20, 30, 40).
    """
    device = torch.device(current_platform.device_type)

    num_requests = 4
    num_speculative_tokens = 4
    batch_spec = BatchSpec(
        seq_lens=[num_speculative_tokens + 1] * num_requests,
        query_lens=[num_speculative_tokens + 1] * num_requests,
    )

    req_ids = [f"req_{i + 1}" for i in range(num_requests)]
    mock_input_batch = mock.MagicMock(spec=InputBatch)
    mock_input_batch.req_ids = req_ids
    mock_input_batch.num_reqs = num_requests
    mock_input_batch.vocab_size = 100

    mock_num_scheduled_tokens = dict.fromkeys(req_ids, 0)
    mock_requests = {}
    for req_id in req_ids:
        mock_request = mock.MagicMock(spec=CachedRequestState)
        # Each request will have a backup next token id of 10, 20, 30, 40
        mock_request.get_token_id.return_value = int(req_id.split("_")[1]) * 10
        mock_request.num_computed_tokens = 0
        mock_requests[req_id] = mock_request

    # Explicitly discard the last request. The host-side list drives the CPU
    # inputs below, so the device tensor is not read element by element.
    discarded_req_list = [False, False, False, True]
    discarded_req_mask = torch.tensor(
        discarded_req_list, dtype=torch.bool, device=device
    )

    sampled_token_ids = [
        [0, 1, -1, -1, -1],  # 1 accepted, 3 rejected, "1" sampled
        [0, 1, 2, 3, 4],  # all accepted, "4" sampled
        [-1, -1, -1, -1, -1],  # sampling skipped, use backup token "30"
        [0, 1, 2, -1, -1],  # explicitly discarded, sampling should be ignored
    ]
    sampled_token_ids_tensor = torch.tensor(
        sampled_token_ids, dtype=torch.int32, device=device
    )
    # CPU form: rejected (-1) entries removed; discarded requests emptied.
    sampled_token_ids_cpu = [
        [] if discarded else [tok for tok in seq if tok != -1]
        for seq, discarded in zip(sampled_token_ids, discarded_req_list)
    ]

    expected_next_token_ids_cpu = [1, 4, 30, 40]
    expected_next_token_ids_tensor = torch.tensor(
        expected_next_token_ids_cpu, dtype=torch.int32, device=device
    )

    proposer = _create_proposer("eagle", num_speculative_tokens)

    next_token_ids_from_cpu = proposer.prepare_next_token_ids_cpu(
        sampled_token_ids_cpu,
        mock_requests,
        mock_input_batch,
        mock_num_scheduled_tokens,
    )

    assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor)

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # Number of non-rejected entries per row; the discarded request counts as 0.
    expected_valid_sampled_tokens_count = torch.tensor(
        [2, 5, 0, 0], dtype=torch.int32, device=device
    )

    next_token_ids_from_padded, valid_sampled_tokens_count = (
        proposer.prepare_next_token_ids_padded(
            common_attn_metadata,
            sampled_token_ids_tensor,
            mock_requests,
            mock_input_batch,
            discarded_req_mask,
        )
    )

    assert torch.equal(next_token_ids_from_padded, expected_next_token_ids_tensor)
    assert torch.equal(valid_sampled_tokens_count, expected_valid_sampled_tokens_count)
162
163


164
165
166
167
168
169
170
171
172
173
def test_prepare_inputs():
    """
    Check that ``prepare_inputs`` drops rejected tokens from each request.

    With per-request query lengths a, b, c and rejected counts n1, n2, n3:
      cu_target_query_lens: [0, a, a + b, a + b + c]
      num_rejected_tokens:  [n1, n2, n3]
      num_tokens_per_req:   [a - n1, b - n2, c - n3]
      cu_num_tokens:        [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
      token_indices:        [0, 1, ..., a - n1 - 1,
                             a, a + 1, ..., a + b - n2 - 1,
                             a + b, a + b + 1, ..., a + b + c - n3 - 1]
    """
    device = torch.device(current_platform.device_type)

    # (a, b, c) = (4, 7, 5) and (n1, n2, n3) = (1, 3, 2).
    query_lens = [4, 7, 5]
    common_attn_metadata = create_common_attn_metadata(
        BatchSpec(seq_lens=[4, 7, 5], query_lens=query_lens),
        block_size=16,
        device=device,
    )

    # For k sampled tokens, the first k - 1 are draft tokens carried over from
    # the previous iteration and the last one is the bonus token sampled from
    # the base model.
    num_draft_tokens = [3, 6, 4]  # one less than query_lens
    accept, bonus, reject = 0, 1, -1
    raw_sampled = [
        [accept] * 2 + [reject] * 1 + [bonus],
        [accept] * 3 + [reject] * 3 + [bonus],
        [accept] * 2 + [reject] * 2 + [bonus],
    ]
    sampled_token_ids = [[tok for tok in row if tok != reject] for row in raw_sampled]

    # Surviving tokens per request: [3, 4, 3] -> cumulative [0, 3, 7, 10].
    expected_cu_num_tokens = torch.tensor(
        [0, 3, 7, 10], dtype=torch.int32, device=device
    )

    # Requests start at 0, 4 and 11 in the original layout; each keeps a prefix
    # of its tokens (3 of 4, 4 of 7, 3 of 5).
    expected_token_indices = torch.tensor(
        [*range(0, 3), *range(4, 8), *range(11, 14)],
        dtype=torch.int32,
        device=device,
    )

    proposer = _create_proposer("eagle", 1)
    updated_metadata, token_indices = proposer.prepare_inputs(
        common_attn_metadata, sampled_token_ids, num_draft_tokens
    )

    assert torch.equal(updated_metadata.query_start_loc, expected_cu_num_tokens)
    assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
    assert torch.equal(token_indices, expected_token_indices)


254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
def test_prepare_inputs_padded():
    """
    Input scenario is 3 requests with num_speculative_tokens == 2 and:
    - Request 1: query_len = 3, rejected = 1
    - Request 2: query_len = 3, rejected = 0
    - Request 3: query_len = 3, rejected = 2

    Expected outputs:
    token_indices_to_sample: [1, 5, 6]
    Reason: in the padded layout every request keeps its full query_len, so
            requests start at 0, 3 and 6. The expected indices are consistent
            with sampling the last valid token of each request, i.e.
            start + valid_sampled_tokens_count - 1
            = [0 + 2 - 1, 3 + 3 - 1, 6 + 1 - 1].
    """

    device = torch.device(current_platform.device_type)

    expected_token_indices_to_sample = torch.tensor(
        [1, 5, 6], dtype=torch.int32, device=device
    )

    num_speculative_tokens = 2
    batch_spec = BatchSpec(
        seq_lens=[3, 3, 3],
        query_lens=[3, 3, 3],
    )

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # Padding keeps each request at query_len 3 (2 draft tokens + 1 sampled
    # token), so query_start_loc is expected to stay [0, 3, 6, 9].
    expected_query_start_loc = torch.tensor(
        [0, 3, 6, 9], dtype=torch.int32, device=device
    )

    # Two draft tokens per request.
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        draft_token_ids=[[0] * num_speculative_tokens] * 3,
        device=device,
    )

    # num_rejected_tokens = [1, 0, 2]
    # num_draft_tokens = [2, 2, 2]
    # valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens
    valid_sampled_tokens_count = torch.tensor(
        [2, 3, 1], dtype=torch.int32, device=device
    )

    proposer = _create_proposer("eagle", num_speculative_tokens)

    output_metadata, token_indices_to_sample = proposer.prepare_inputs_padded(
        common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
    )

    assert output_metadata.max_query_len == 3
    assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc)
    assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
310
311


312
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("pp_size", [1, 2])
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
@pytest.mark.parametrize("use_distinct_lm_head", [True, False])
@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group")
@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config")
@mock.patch("vllm.v1.spec_decode.eagle.get_model")
def test_load_model(
    mock_get_model,
    mock_get_layers,
    mock_get_pp_group,
    method,
    attn_backend,
    pp_size,
    use_distinct_embed_tokens,
    use_distinct_lm_head,
    monkeypatch,
):
    """Check which modules ``load_model`` shares with the target model.

    The draft model's lm_head is expected to be the target's unless the draft
    has its own. Its embed_tokens is expected to be the target's only when
    pp_size == 1 and the draft has no embed_tokens of its own.
    """
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

    if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
        pytest.skip(
            "TRITON_ATTN does not support "
            "multi-token eagle spec decode on current platform"
        )
    if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # Draft model handed back by the patched get_model.
    draft_model = mock.MagicMock()
    draft_model.model = mock.MagicMock()
    draft_model.has_own_embed_tokens = use_distinct_embed_tokens
    if use_distinct_embed_tokens:
        draft_model.model.embed_tokens = mock.MagicMock()
    draft_model.has_own_lm_head = use_distinct_lm_head
    if use_distinct_lm_head:
        draft_model.lm_head = mock.MagicMock()
    mock_get_model.return_value = draft_model

    # Successive get_layers_from_vllm_config calls receive, in order: the
    # target attention layers, an empty dict, all attention layers (the target
    # ones plus one extra draft layer), and another empty dict.
    target_layers = {
        "target_attn_1": mock.MagicMock(),
        "target_attn_2": mock.MagicMock(),
    }
    target_indx: dict[str, mock.MagicMock] = {}
    every_layer = {**target_layers, "draft_extra_attn": mock.MagicMock()}
    every_indx: dict[str, mock.MagicMock] = {}
    mock_get_layers.side_effect = [target_layers, target_indx, every_layer, every_indx]

    # Pipeline-parallel group reporting the parametrized world size.
    mock_get_pp_group.return_value = mock.MagicMock(world_size=pp_size)

    # Subclassing the real model class keeps isinstance() checks meaningful
    # for the autospec'd target model.
    class _TargetModelStub(LlamaForCausalLM):
        model: mock.MagicMock
        lm_head: mock.MagicMock

    target_model = mock.create_autospec(_TargetModelStub, instance=True)
    target_model.model = mock.MagicMock()
    target_model.lm_head = mock.MagicMock()
    target_model.model.embed_tokens = mock.MagicMock()

    from vllm.model_executor.models import SupportsMultiModal

    assert not isinstance(target_model, SupportsMultiModal)

    proposer = _create_proposer(method, num_speculative_tokens=8)
    proposer.load_model(target_model)

    mock_get_model.assert_called_once()

    expect_shared_lm_head = not use_distinct_lm_head
    assert (proposer.model.lm_head is target_model.lm_head) == expect_shared_lm_head

    # With pp_size > 1 the embed tokens should be distinct.
    expect_shared_embed = pp_size <= 1 and not use_distinct_embed_tokens
    shares_embed = proposer.model.model.embed_tokens is target_model.model.embed_tokens
    assert shares_embed == expect_shared_embed
414
415


416
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
    """Check ``EagleProposer.propose`` against a mocked draft model.

    The mocked ``compute_logits`` makes token ``base + step`` the
    highest-scoring token at each drafting step. The expected drafts are
    therefore 42, 43, ... for the first request and 60, 61, ... for the
    second.
    """
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

    if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
        pytest.skip(
            "TRITON_ATTN does not support "
            "multi-token eagle spec decode on current platform"
        )

    if attn_backend == "TREE_ATTN":
        # Adjacent literals are concatenated; the trailing space is required.
        pytest.skip(
            "TREE_ATTN is tested separately in test_propose_tree "
            "because it requires special input mocking."
        )

    if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    device = torch.device(current_platform.device_type)

    # Setup test parameters
    batch_size = 2
    seq_len_1 = 5
    seq_len_2 = 3
    total_tokens = seq_len_1 + seq_len_2
    vocab_size = 100
    seq_lens = [seq_len_1, seq_len_2]

    # Create proposer first so we can use its actual hidden_size.
    # NOTE(review): `method` is parametrized, but the proposer is always built
    # with "eagle", so the eagle3 cases re-run the eagle path; confirm intent.
    proposer = _create_proposer("eagle", num_speculative_tokens)
    hidden_size = proposer.hidden_size

    def create_deterministic_logits(token_ids):
        """Logits that score token_ids[i] at 100 and everything else at -100."""
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        for i, token_id in enumerate(token_ids):
            logits[i, token_id] = 100.0
        return logits

    # Sequence 1: 42, 43, 44, ...
    # Sequence 2: 60, 61, 62, ...
    base_token_ids = [42, 60]

    # Replace the draft model with a mock instead of loading weights.
    model_mock = mock.MagicMock()

    # The first forward call covers every target token; each later call covers
    # one token per request.
    forward_returns = [
        (
            torch.zeros(num_rows, hidden_size, device=device),
            torch.zeros(num_rows, hidden_size, device=device),
        )
        for num_rows in [total_tokens] + [batch_size] * (num_speculative_tokens - 1)
    ]
    if num_speculative_tokens == 1:
        model_mock.return_value = forward_returns[0]
    else:
        model_mock.side_effect = forward_returns

    # Drafting step `step` favours base_token_ids[i] + step for request i.
    logits_returns = [
        create_deterministic_logits([base_id + step for base_id in base_token_ids])
        for step in range(num_speculative_tokens)
    ]
    if num_speculative_tokens == 1:
        model_mock.compute_logits.return_value = logits_returns[0]
    else:
        model_mock.compute_logits.side_effect = logits_returns

    proposer.model = model_mock

    # Assign draft attn_layer_names since load_model is not invoked
    proposer.attn_layer_names = ["layer.0"]

    batch_spec = BatchSpec(
        seq_lens=seq_lens,
        query_lens=seq_lens,
    )
    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)
    target_positions = torch.cat(
        [torch.arange(seq_len_1, device=device), torch.arange(seq_len_2, device=device)]
    )
    target_hidden_states = torch.randn(total_tokens, hidden_size, device=device)
    next_token_ids = torch.randint(
        0, vocab_size, (batch_size,), dtype=torch.int32, device=device
    )

    sampling_metadata = mock.MagicMock()

    # TREE_ATTN never reaches this point (skipped above).
    backend_enums = {
        "FLASH_ATTN": AttentionBackendEnum.FLASH_ATTN,
        "TRITON_ATTN": AttentionBackendEnum.TRITON_ATTN,
    }
    if attn_backend not in backend_enums:
        raise ValueError(f"Unsupported attention backend: {attn_backend}")
    attn_metadata_builder_cls, _ = try_get_attention_backend(
        backend_enums[attn_backend]
    )

    attn_metadata_builder = attn_metadata_builder_cls(
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
        layer_names=proposer.attn_layer_names,
        vllm_config=proposer.vllm_config,
        device=device,
    )

    # Mock runner for attention metadata building
    proposer.runner = mock.MagicMock()
    proposer.runner.attn_groups.append([mock.MagicMock()])
    proposer.runner.attn_groups[0][
        0
    ].get_metadata_builder.return_value = attn_metadata_builder
    proposer._get_attention_metadata_builder = mock.MagicMock(
        return_value=attn_metadata_builder
    )

    result = proposer.propose(
        target_token_ids=target_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        next_token_ids=next_token_ids,
        last_token_indices=None,
        common_attn_metadata=common_attn_metadata,
        sampling_metadata=sampling_metadata,
    )

    assert result.shape == (batch_size, num_speculative_tokens)

    # e.g. for num_speculative_tokens=3: [[42, 43, 44], [60, 61, 62]] (int64).
    expected_tokens = torch.tensor(base_token_ids, device=device)[
        :, None
    ] + torch.arange(num_speculative_tokens, device=device)

    # Verify all tokens match our expectations
    assert torch.equal(result, expected_tokens)
592
593
594
595
596


@pytest.mark.parametrize(
    "spec_token_tree",
    [
        [(0,)],  # A single token
        [(0,), (0, 0), (0, 0, 0)],  # Chain
        [(0,), (1,), (2,)],  # Parallel
        [(0,), (1,), (2,), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)],  # Tree
    ],
)
def test_propose_tree(spec_token_tree):
    """Check tree-shaped ``EagleProposer.propose`` with a mocked draft model.

    Mocked logits are arranged so that the drafts come out as consecutive
    integers starting at each request's base token id (42 and 60).
    """
    device = torch.device(current_platform.device_type)

    batch_size = 2
    seq_len_1 = 5
    seq_len_2 = 3
    total_tokens = seq_len_1 + seq_len_2
    vocab_size = 100
    seq_lens = [seq_len_1, seq_len_2]
    num_speculative_tokens = len(spec_token_tree)

    # The proposer is built first so its hidden_size can size the mocked outputs.
    proposer = _create_proposer(
        "eagle", num_speculative_tokens, speculative_token_tree=spec_token_tree
    )
    hidden_size = proposer.hidden_size

    def make_logits(token_ids, k: int):
        """Per row, score tokens token_id .. token_id + k - 1 as 100, 99, ..."""
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        for row, first_id in enumerate(token_ids):
            for offset in range(k):
                logits[row, first_id + offset] = 100.0 - offset
        return logits

    base_token_ids = torch.tensor([42, 60], dtype=torch.int64, device=device)

    # Mock draft model with deterministic forward and compute_logits outputs.
    model_mock = mock.MagicMock()

    # First forward call: every target token. Later calls: sized by
    # batch_size * cu_drafts_per_level for each level.
    forward_returns = [
        (
            torch.zeros(total_tokens, hidden_size, device=device),
            torch.zeros(total_tokens, hidden_size, device=device),
        )
    ]
    forward_returns += [
        (
            torch.zeros(batch_size * cu_num_drafts, hidden_size, device=device),
            torch.zeros(batch_size * cu_num_drafts, hidden_size, device=device),
        )
        for cu_num_drafts in proposer.cu_drafts_per_level
    ]
    model_mock.side_effect = forward_returns

    cu_num_drafts_tensor = torch.tensor(
        [0] + proposer.cu_drafts_per_level, dtype=torch.int32, device=device
    )
    logits_returns = []
    for level, num_children in enumerate(proposer.child_drafts_per_level):
        level_start = cu_num_drafts_tensor[level]
        level_base_ids = base_token_ids + level_start
        num_groups = (cu_num_drafts_tensor[level + 1] - level_start) // num_children
        group_logits = [
            make_logits(level_base_ids + group * num_children, num_children)
            for group in range(num_groups)
        ]
        logits_returns.append(torch.stack(group_logits, dim=1))
    model_mock.compute_logits.side_effect = logits_returns

    proposer.model = model_mock
    # load_model is not invoked, so the draft layer names are set by hand.
    proposer.attn_layer_names = ["layer.0"]

    attn_metadata_builder_cls, _ = try_get_attention_backend(
        AttentionBackendEnum.TREE_ATTN
    )
    attn_metadata_builder = attn_metadata_builder_cls(
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
        layer_names=proposer.attn_layer_names,
        vllm_config=proposer.vllm_config,
        device=device,
    )

    # Runner mock that serves the tree attention metadata builder.
    proposer.runner = mock.MagicMock()
    proposer.runner.attn_groups.append([mock.MagicMock()])
    first_group = proposer.runner.attn_groups[0][0]
    first_group.metadata_builders = [attn_metadata_builder]
    first_group.get_metadata_builder.return_value = attn_metadata_builder
    proposer._get_attention_metadata_builder = mock.MagicMock(
        return_value=attn_metadata_builder
    )

    target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)
    target_positions = torch.cat(
        [torch.arange(seq_len_1, device=device), torch.arange(seq_len_2, device=device)]
    )
    target_hidden_states = torch.randn(total_tokens, hidden_size, device=device)
    next_token_ids = torch.randint(
        0, vocab_size, (batch_size,), dtype=torch.int32, device=device
    )
    common_attn_metadata = create_common_attn_metadata(
        BatchSpec(seq_lens=seq_lens, query_lens=seq_lens),
        block_size=16,
        device=device,
    )
    sampling_metadata = mock.MagicMock()

    result = proposer.propose(
        target_token_ids=target_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        next_token_ids=next_token_ids,
        last_token_indices=None,
        common_attn_metadata=common_attn_metadata,
        sampling_metadata=sampling_metadata,
    )

    assert result.shape == (batch_size, num_speculative_tokens)

    # Drafts are consecutive integers starting from each base token id.
    offsets = torch.arange(num_speculative_tokens, dtype=torch.int64, device=device)
    expected_tokens = base_token_ids[:, None] + offsets
    assert torch.equal(result, expected_tokens)