# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest import mock

import pytest
import torch

9
from tests.utils import get_attn_backend_list_based_on_platform
10
11
12
13
from tests.v1.attention.utils import (
    BatchSpec,
    create_common_attn_metadata,
    create_standard_kv_cache_spec,
14
    try_get_attention_backend,
15
)
16
from vllm.attention.backends.registry import AttentionBackendEnum
17
18
19
20
21
22
23
24
25
from vllm.config import (
    CacheConfig,
    DeviceConfig,
    ModelConfig,
    ParallelConfig,
    SchedulerConfig,
    SpeculativeConfig,
    VllmConfig,
)
26
from vllm.config.load import LoadConfig
27
from vllm.model_executor.models.llama import LlamaForCausalLM
28
from vllm.platforms import current_platform
29
from vllm.v1.spec_decode.eagle import EagleProposer
30
31
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
32
33
34
35
36
37

# Target model name handed to ModelConfig by _create_proposer.
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
# Draft model name used by _create_proposer when method == "eagle".
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
# Draft model name used by _create_proposer for any other method.
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"


38
39
40
def _create_proposer(
    method: str,
    num_speculative_tokens: int,
    speculative_token_tree: list[tuple[int, ...]] | None = None,
) -> EagleProposer:
    """Build an EagleProposer backed by a minimal VllmConfig.

    Args:
        method: Speculative decoding method. ``"eagle"`` selects ``eagle_dir``
            as the draft model; any other value selects ``eagle3_dir``.
        num_speculative_tokens: Number of speculative tokens to configure.
        speculative_token_tree: Optional token tree. When given, it must
            contain exactly ``num_speculative_tokens`` entries and is passed
            to SpeculativeConfig as its ``str()`` representation.

    Returns:
        An EagleProposer created for ``current_platform.device_type``.
    """
    model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)

    # Choose the draft model directory based on the method.
    draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir

    spec_token_tree_str = None
    if speculative_token_tree is not None:
        assert num_speculative_tokens == len(speculative_token_tree)
        spec_token_tree_str = str(speculative_token_tree)

    speculative_config = SpeculativeConfig(
        target_model_config=model_config,
        target_parallel_config=ParallelConfig(),
        model=draft_model_dir,
        method=method,
        num_speculative_tokens=num_speculative_tokens,
        speculative_token_tree=spec_token_tree_str,
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(),
        speculative_config=speculative_config,
        device_config=DeviceConfig(device=current_platform.device_type),
        parallel_config=ParallelConfig(),
        load_config=LoadConfig(),
        scheduler_config=SchedulerConfig(
            max_model_len=model_config.max_model_len,
            is_encoder_decoder=model_config.is_encoder_decoder,
        ),
    )

    return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)
76
77


78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
def test_prepare_next_token_ids():
    """
    Test for prepare_next_token_ids_cpu and prepare_next_token_ids_padded.
    Each will produce a device tensor of next_token_ids, taking as input
    either the GPU tensor of sampled_token_ids with -1 for rejected tokens,
    or the CPU python list[list[int]] with the rejected tokens removed.
    """
    device = torch.device(current_platform.device_type)

    num_requests = 4
    num_speculative_tokens = 4
    batch_spec = BatchSpec(
        seq_lens=[num_speculative_tokens + 1] * num_requests,
        query_lens=[num_speculative_tokens + 1] * num_requests,
    )

    req_ids = [f"req_{i + 1}" for i in range(num_requests)]
    mock_input_batch = mock.MagicMock(spec=InputBatch)
    mock_input_batch.req_ids = req_ids
    mock_input_batch.num_reqs = num_requests
    mock_input_batch.vocab_size = 100

    mock_num_scheduled_tokens = dict.fromkeys(req_ids, 0)
    mock_requests = {}
    for req_id in req_ids:
        mock_request = mock.MagicMock(spec=CachedRequestState)
        # Each request will have a backup next token id of 10, 20, 30, 40
        mock_request.get_token_id.return_value = int(req_id.split("_")[1]) * 10
        mock_request.num_computed_tokens = 0
        mock_requests[req_id] = mock_request

    # Explicitly discard the last request. The flags are kept on the host as
    # well so the CPU-side inputs can be built without reading the device
    # tensor element by element.
    discarded_flags = [False, False, False, True]
    discarded_req_mask = torch.tensor(
        discarded_flags, dtype=torch.bool, device=device
    )
    sampled_token_ids = [
        [0, 1, -1, -1, -1],  # 1 accepted, 3 rejected, "1" sampled
        [0, 1, 2, 3, 4],  # all accepted, "4" sampled
        [-1, -1, -1, -1, -1],  # sampling skipped, use backup token "30"
        [0, 1, 2, -1, -1],  # explicitly discarded, sampling should be ignored
    ]
    sampled_token_ids_tensor = torch.tensor(
        sampled_token_ids, dtype=torch.int32, device=device
    )
    # CPU variant: rejected (-1) tokens are removed, and discarded requests
    # contribute an empty list.
    sampled_token_ids_cpu = [
        [] if is_discarded else [tok for tok in seq if tok != -1]
        for seq, is_discarded in zip(sampled_token_ids, discarded_flags)
    ]

    expected_next_token_ids_cpu = [1, 4, 30, 40]
    expected_next_token_ids_tensor = torch.tensor(
        expected_next_token_ids_cpu, dtype=torch.int32, device=device
    )

    proposer = _create_proposer("eagle", num_speculative_tokens)

    next_token_ids_from_cpu = proposer.prepare_next_token_ids_cpu(
        sampled_token_ids_cpu,
        mock_requests,
        mock_input_batch,
        mock_num_scheduled_tokens,
    )

    assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor)

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    expected_valid_sampled_tokens_count = torch.tensor(
        [2, 5, 0, 0], dtype=torch.int32, device=device
    )

    next_token_ids_from_padded, valid_sampled_tokens_count = (
        proposer.prepare_next_token_ids_padded(
            common_attn_metadata,
            sampled_token_ids_tensor,
            mock_requests,
            mock_input_batch,
            discarded_req_mask,
        )
    )

    assert torch.equal(next_token_ids_from_padded, expected_next_token_ids_tensor)
    assert torch.equal(valid_sampled_tokens_count, expected_valid_sampled_tokens_count)
165
166


167
168
169
170
171
172
173
174
175
176
def test_prepare_inputs():
    """
    cu_target_query_lens: [0, a, a + b, a + b + c]
    num_rejected_tokens: [n1, n2, n3]
    num_tokens_per_req: [a - n1, b - n2, c - n3]
    cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
    token_indices: [0, 1, ..., a - n1 - 1,
                    a, a + 1, ..., a + b - n2 - 1,
                    a + b, a + b + 1, ..., a + b + c - n3 - 1]
    """
    device = torch.device(current_platform.device_type)

    # q1 = 4, q2 = 7, q3 = 5
    # n1 = 1, n2 = 3, n3 = 2
    batch_spec = BatchSpec(
        seq_lens=[4, 7, 5],
        query_lens=[4, 7, 5],
    )

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # If there are `k` sampled tokens, then `k-1` tokens are draft tokens
    # from the previous iteration, and the last token is the bonus token sampled
    # from the base model.
    num_draft_tokens = [3, 6, 4]  # one less than query_lens
    # num rejected tokens is [1, 3, 2]
    ACCEPT_TOKEN = 0
    BONUS_TOKEN = 1
    REJECT_TOKEN = -1
    sampled_token_ids = [
        [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, BONUS_TOKEN],
        [
            ACCEPT_TOKEN,
            ACCEPT_TOKEN,
            ACCEPT_TOKEN,
            REJECT_TOKEN,
            REJECT_TOKEN,
            REJECT_TOKEN,
            BONUS_TOKEN,
        ],
        [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN],
    ]
    # Drop the rejected entries so only accepted and bonus tokens remain.
    sampled_token_ids = [
        [i for i in seq if i != REJECT_TOKEN] for seq in sampled_token_ids
    ]

    # Expected calculations:
    # query_len_per_req = [4, 7, 5]
    # num_tokens_per_req = [3, 4, 3]  (after subtracting rejected tokens)
    # Expected cumulative counts: [0, 3, 7, 10]
    expected_cu_num_tokens = torch.tensor(
        [0, 3, 7, 10], dtype=torch.int32, device=device
    )

    # Expected token indices (mapped from original positions):
    # First request: indices 0, 1, 2      (keeping first 3 from positions 0-3)
    # Second request: indices 4, 5, 6, 7  (keeping first 4 from positions 4-10)
    # Third request: indices 11, 12, 13   (keeping first 3 from positions 11-15)
    expected_token_indices = torch.tensor(
        [
            0,
            1,
            2,  # First request: 3 tokens (4-1)
            4,
            5,
            6,
            7,  # Second request: 4 tokens (7-3)
            11,
            12,
            13,  # Third request: 3 tokens (5-2)
        ],
        dtype=torch.int32,
        device=device,
    )
    proposer = _create_proposer("eagle", 1)

    updated_metadata, token_indices = proposer.prepare_inputs(
        common_attn_metadata, sampled_token_ids, num_draft_tokens
    )

    assert torch.equal(updated_metadata.query_start_loc, expected_cu_num_tokens)
    assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
    assert torch.equal(token_indices, expected_token_indices)


257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
def test_prepare_inputs_padded():
    """
    Input scenario is 3 requests with num_speculative_tokens == 2 and:
    - Request 1: query_len = 3, rejected = 1
    - Request 2: query_len = 3, rejected = 0
    - Request 3: query_len = 3, rejected = 2

    Expected outputs:
    token_indices_to_sample: [1, 5, 6]
    Reason: After accounting for rejections, these are the valid token positions
            from the original indices to sample from.
    """

    device = torch.device(current_platform.device_type)

    expected_token_indices_to_sample = torch.tensor(
        [1, 5, 6], dtype=torch.int32, device=device
    )

    num_speculative_tokens = 2
    batch_spec = BatchSpec(
        seq_lens=[3, 3, 3],
        query_lens=[3, 3, 3],
    )

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # The padded path keeps every query position, so query_start_loc is the
    # running sum of the unmodified query lengths.
    expected_query_start_loc = torch.tensor(
        [0, 3, 6, 9], dtype=torch.int32, device=device
    )
    # Dummy metadata with num_speculative_tokens draft tokens per request.
    # NOTE(review): the original comment here said this is "needed for
    # cu_num_draft_tokens, which is expected to be [3, 6, 9]" -- confirm
    # against SpecDecodeMetadata.make_dummy.
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        draft_token_ids=[[0] * num_speculative_tokens] * 3,
        device=device,
    )

    # num_rejected_tokens = [1, 0, 2]
    # num_draft_tokens = [2, 2, 2]
    # valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens
    valid_sampled_tokens_count = torch.tensor(
        [2, 3, 1], dtype=torch.int32, device=device
    )

    proposer = _create_proposer("eagle", num_speculative_tokens)

    output_metadata, token_indices_to_sample = proposer.prepare_inputs_padded(
        common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
    )

    assert output_metadata.max_query_len == 3
    assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc)
    assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
313
314


315
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("pp_size", [1, 2])
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
@pytest.mark.parametrize("use_distinct_lm_head", [True, False])
@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group")
@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config")
@mock.patch("vllm.v1.spec_decode.eagle.get_model")
def test_load_model(
    mock_get_model,
    mock_get_layers,
    mock_get_pp_group,
    method,
    attn_backend,
    pp_size,
    use_distinct_embed_tokens,
    use_distinct_lm_head,
    monkeypatch,
):
    """Check which lm_head / embed_tokens the proposer holds after load_model.

    The draft model, the layer lookup and the PP group are all mocked. After
    ``proposer.load_model(target_model)``:

    - ``lm_head`` must be the target's object unless the draft model reports
      ``has_own_lm_head``.
    - ``embed_tokens`` must be the target's object unless ``pp_size > 1`` or
      the draft model reports ``has_own_embed_tokens``.
    """
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

    if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
        pytest.skip(
            "TRITON_ATTN does not support "
            "multi-token eagle spec decode on current platform"
        )

    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # Setup draft model mock
    mock_model = mock.MagicMock()
    mock_model.model = mock.MagicMock()
    mock_model.has_own_embed_tokens = use_distinct_embed_tokens
    if use_distinct_embed_tokens:
        mock_model.model.embed_tokens = mock.MagicMock()
    mock_model.has_own_lm_head = use_distinct_lm_head
    if use_distinct_lm_head:
        mock_model.lm_head = mock.MagicMock()

    mock_get_model.return_value = mock_model

    # Setup mocks for attention layers
    target_attn_layers = {
        "target_attn_1": mock.MagicMock(),
        "target_attn_2": mock.MagicMock(),
    }
    target_indx_layers: dict[str, mock.MagicMock] = {}
    # Draft model has one extra attention layer compared to target model
    all_attn_layers = {**target_attn_layers, "draft_extra_attn": mock.MagicMock()}

    all_indx_layers: dict[str, mock.MagicMock] = {}

    # Make mock_get_layers return different values for each call, consumed in
    # this order.
    mock_get_layers.side_effect = [
        target_attn_layers,
        target_indx_layers,
        all_attn_layers,
        all_indx_layers,
    ]

    # Setup mock for pp group to return the appropriate value for world size
    mock_pp_group = mock.MagicMock()
    mock_pp_group.world_size = pp_size
    mock_get_pp_group.return_value = mock_pp_group

    # Set up the target model mock with a custom class so that
    # isinstance() checks match the expected type.
    class _TargetModelStub(LlamaForCausalLM):
        model: mock.MagicMock
        lm_head: mock.MagicMock

    target_model = mock.create_autospec(_TargetModelStub, instance=True)
    target_model.model = mock.MagicMock()
    target_model.lm_head = mock.MagicMock()
    target_model.model.embed_tokens = mock.MagicMock()

    from vllm.model_executor.models import SupportsMultiModal

    # Sanity check: the stub must not look like a multimodal model.
    assert not isinstance(target_model, SupportsMultiModal)

    # Create proposer using the helper function
    proposer = _create_proposer(method, num_speculative_tokens=8)

    # Call the method under test
    proposer.load_model(target_model)

    # Verify common interactions
    mock_get_model.assert_called_once()

    # Verify that the lm head is set correctly
    if use_distinct_lm_head:
        assert proposer.model.lm_head is not target_model.lm_head
    else:
        assert proposer.model.lm_head is target_model.lm_head

    # Verify that the embed tokens are set correctly
    # If pp_size is > 1, the embed tokens should be distinct
    if pp_size > 1 or use_distinct_embed_tokens:
        assert proposer.model.model.embed_tokens is not target_model.model.embed_tokens
    else:
        assert proposer.model.model.embed_tokens is target_model.model.embed_tokens
419
420


421
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
    """Check that propose() returns the tokens picked by the mocked logits.

    The draft model is replaced by a mock whose ``compute_logits`` results
    put all the mass on one token per sequence (42, 43, ... for the first
    sequence and 60, 61, ... for the second), so the proposed draft tokens
    are known in advance.
    """
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

    if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
        pytest.skip(
            "TRITON_ATTN does not support "
            "multi-token eagle spec decode on current platform"
        )

    if attn_backend == "TREE_ATTN":
        pytest.skip(
            "TREE_ATTN is tested separately in test_propose_tree"
            "because it requires special input mocking."
        )

    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # Use GPU device
    device = torch.device(current_platform.device_type)

    # Setup test parameters
    batch_size = 2
    seq_len_1 = 5
    seq_len_2 = 3
    total_tokens = seq_len_1 + seq_len_2
    vocab_size = 100
    seq_lens = [seq_len_1, seq_len_2]

    # Create proposer first so we can use its actual hidden_size
    # NOTE(review): the parametrized `method` is not forwarded here, so the
    # "eagle3" cases also build an "eagle" proposer -- confirm this is intended.
    proposer = _create_proposer("eagle", num_speculative_tokens)
    # Get the hidden_size from the proposer to ensure consistency
    hidden_size = proposer.hidden_size

    # Helper to create deterministic logits that will produce specific tokens
    def create_deterministic_logits(token_ids):
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        for i, token_id in enumerate(token_ids):
            logits[i, token_id] = 100.0
        return logits

    # We mock a model that returns deterministic logits
    # Sequence 1: 42, 43, 44, ...
    # Sequence 2: 60, 61, 62, ...
    base_token_ids = [42, 60]

    # Skip loading the model and replace it with a mock directly
    # Create the mock model with deterministic outputs
    model_mock = mock.MagicMock()

    # Setup for model forward calls: the first call sees all tokens, every
    # subsequent call sees one token per sequence.
    forward_returns = []
    for step in range(num_speculative_tokens):
        num_rows = total_tokens if step == 0 else batch_size
        h_logits = torch.zeros(num_rows, hidden_size, device=device)
        h_states = torch.zeros(num_rows, hidden_size, device=device)
        forward_returns.append((h_logits, h_states))

    # For single token case, we only need the first item;
    # for multi-token, we need the sequence
    if num_speculative_tokens == 1:
        model_mock.return_value = forward_returns[0]
    else:
        model_mock.side_effect = forward_returns

    # Setup for compute_logits calls: each call advances the base token IDs
    # by one.
    logits_returns = [
        create_deterministic_logits([base_id + step for base_id in base_token_ids])
        for step in range(num_speculative_tokens)
    ]

    if num_speculative_tokens == 1:
        model_mock.compute_logits.return_value = logits_returns[0]
    else:
        model_mock.compute_logits.side_effect = logits_returns

    # Assign the mock to the proposer
    proposer.model = model_mock

    # Assign draft attn_layer_names since load_model is not invoked
    proposer.attn_layer_names = ["layer.0"]

    # Create input tensors
    batch_spec = BatchSpec(
        seq_lens=seq_lens,
        query_lens=seq_lens,
    )

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)
    target_positions = torch.cat(
        [torch.arange(seq_len_1, device=device), torch.arange(seq_len_2, device=device)]
    )
    target_hidden_states = torch.randn(total_tokens, hidden_size, device=device)
    next_token_ids = torch.randint(
        0, vocab_size, (batch_size,), dtype=torch.int32, device=device
    )
    sampling_metadata = mock.MagicMock()

    # Resolve the backend by name; anything outside this set is rejected.
    supported_backends = ("FLASH_ATTN", "TRITON_ATTN", "TREE_ATTN", "ROCM_AITER_FA")
    if attn_backend not in supported_backends:
        raise ValueError(f"Unsupported attention backend: {attn_backend}")
    attn_metadata_builder_cls, _ = try_get_attention_backend(
        getattr(AttentionBackendEnum, attn_backend)
    )

    attn_metadata_builder = attn_metadata_builder_cls(
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
        layer_names=proposer.attn_layer_names,
        vllm_config=proposer.vllm_config,
        device=device,
    )

    # Mock runner for attention metadata building
    proposer.runner = mock.MagicMock()
    proposer.runner.attn_groups.append([mock.MagicMock()])
    proposer.runner.attn_groups[0][
        0
    ].get_metadata_builder.return_value = attn_metadata_builder
    proposer._get_attention_metadata_builder = mock.MagicMock(
        return_value=attn_metadata_builder
    )

    result = proposer.propose(
        target_token_ids=target_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        next_token_ids=next_token_ids,
        last_token_indices=None,
        common_attn_metadata=common_attn_metadata,
        sampling_metadata=sampling_metadata,
    )

    assert result.shape == (batch_size, num_speculative_tokens)

    # Create expected tokens based on our token pattern, e.g.
    #   num_speculative_tokens=1: [[42], [60]]
    #   num_speculative_tokens=3: [[42, 43, 44], [60, 61, 62]]
    expected_tokens = torch.tensor(
        [
            [base_id + step for step in range(num_speculative_tokens)]
            for base_id in base_token_ids
        ],
        dtype=torch.int64,
        device=device,
    )

    # Verify all tokens match our expectations
    assert torch.equal(result, expected_tokens)
601
602
603
604
605


@pytest.mark.parametrize(
    "spec_token_tree",
    [
        [(0,)],  # A single token
        [(0,), (0, 0), (0, 0, 0)],  # Chain
        [(0,), (1,), (2,)],  # Parallel
        [(0,), (1,), (2,), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)],  # Tree
    ],
)
def test_propose_tree(spec_token_tree):
    """Check propose() with a speculative token tree and the TREE_ATTN backend.

    The draft model is mocked so that the drafted tokens are expected to be
    consecutive integers starting at each sequence's base token ID (42 and
    60), one per entry of ``spec_token_tree``.
    """
    # Get GPU device.
    device = torch.device(current_platform.device_type)

    # Setup test parameters.
    batch_size = 2
    seq_len_1 = 5
    seq_len_2 = 3
    total_tokens = seq_len_1 + seq_len_2
    vocab_size = 100
    seq_lens = [seq_len_1, seq_len_2]
    num_speculative_tokens = len(spec_token_tree)

    # Create proposer first so we can use its actual hidden_size.
    proposer = _create_proposer(
        "eagle", num_speculative_tokens, speculative_token_tree=spec_token_tree
    )
    # Get the hidden_size from the proposer to ensure consistency.
    hidden_size = proposer.hidden_size

    # Helper to create deterministic logits that will produce specific tokens
    def create_deterministic_logits(token_ids, k: int):
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        for i, token_id in enumerate(token_ids):
            # Assign decreasing values to the k, consecutive, tokens.
            for j in range(k):
                logits[i, token_id + j] = 100.0 - j
        return logits

    # Mock a model that returns deterministic logits.
    base_token_ids = torch.tensor([42, 60], dtype=torch.int64, device=device)

    # Skip loading the model and replace it with a mock that returns
    # deterministic outputs.
    model_mock = mock.MagicMock()

    # Mock the model forward calls: the first return covers all input tokens,
    # followed by one return per tree level sized by its cumulative drafts.
    forward_returns = [
        (
            torch.zeros(total_tokens, hidden_size, device=device),
            torch.zeros(total_tokens, hidden_size, device=device),
        )
    ]
    for cu_num_drafts in proposer.cu_drafts_per_level:
        h_logits = torch.zeros(batch_size * cu_num_drafts, hidden_size, device=device)
        h_states = torch.zeros(batch_size * cu_num_drafts, hidden_size, device=device)
        forward_returns.append((h_logits, h_states))
    model_mock.side_effect = forward_returns

    # Mock the compute_logits calls.
    cu_num_drafts_tensor = torch.tensor(
        [0] + proposer.cu_drafts_per_level, dtype=torch.int32, device=device
    )
    logits_returns = []
    for level, num_children in enumerate(proposer.child_drafts_per_level):
        token_ids = base_token_ids + cu_num_drafts_tensor[level]
        level_num_drafts = cu_num_drafts_tensor[level + 1] - cu_num_drafts_tensor[level]
        level_logits = []
        for i in range(level_num_drafts // num_children):
            level_logits.append(
                create_deterministic_logits(token_ids + i * num_children, num_children)
            )
        logits_returns.append(torch.stack(level_logits, dim=1))
    model_mock.compute_logits.side_effect = logits_returns

    # Assign the mock to the proposer
    proposer.model = model_mock

    # Assign draft attn_layer_names since load_model is not invoked
    proposer.attn_layer_names = ["layer.0"]

    # Get the tree attention metadata builder.
    attn_metadata_builder_cls, _ = try_get_attention_backend(
        AttentionBackendEnum.TREE_ATTN
    )
    attn_metadata_builder = attn_metadata_builder_cls(
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
        layer_names=proposer.attn_layer_names,
        vllm_config=proposer.vllm_config,
        device=device,
    )

    # Mock runner for attention metadata building.
    proposer.runner = mock.MagicMock()
    proposer.runner.attn_groups.append([mock.MagicMock()])
    proposer.runner.attn_groups[0][0].metadata_builders = [attn_metadata_builder]
    proposer.runner.attn_groups[0][
        0
    ].get_metadata_builder.return_value = attn_metadata_builder
    proposer._get_attention_metadata_builder = mock.MagicMock(
        return_value=attn_metadata_builder
    )

    # Setup inputs for the proposer.
    target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)
    target_positions = torch.cat(
        [torch.arange(seq_len_1, device=device), torch.arange(seq_len_2, device=device)]
    )
    target_hidden_states = torch.randn(total_tokens, hidden_size, device=device)
    next_token_ids = torch.randint(
        0, vocab_size, (batch_size,), dtype=torch.int32, device=device
    )
    batch_spec = BatchSpec(
        seq_lens=seq_lens,
        query_lens=seq_lens,
    )
    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )
    sampling_metadata = mock.MagicMock()

    # Propose draft tokens.
    result = proposer.propose(
        target_token_ids=target_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        next_token_ids=next_token_ids,
        last_token_indices=None,
        common_attn_metadata=common_attn_metadata,
        sampling_metadata=sampling_metadata,
    )
    assert result.shape == (batch_size, num_speculative_tokens)

    # The tokens are expected to be consecutive integers starting
    # from the base token IDs.
    expected_tokens = base_token_ids[:, None] + torch.arange(
        num_speculative_tokens, dtype=torch.int64, device=device
    )

    # Verify that the draft tokens match our expectations.
    assert torch.equal(result, expected_tokens)