"vllm/entrypoints/openai/completion/protocol.py" did not exist on "6f1e7f7226447f606a0731376a2d0bd080aa2767"
test_dynamic_spec_decode.py 3.73 KB
Newer Older
1
from unittest.mock import MagicMock, patch
2
3
4
5
6
7
8
9
10
11

import pytest
import torch

from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.spec_decode.top1_proposer import Top1Proposer

12
from .test_utils import mock_spec_decode_sampler
13
14
15
from .utils import create_batch, mock_worker


16
17
18
@pytest.mark.parametrize('queue_size', [4])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('k', [1])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
                             acceptance_sampler_method: str):
    """Verify that speculative tokens are disabled for a large running queue.

    A ``SpecDecodeWorker`` is built around mocked draft/target workers with
    ``disable_by_batch_size=3``. The test then checks two things:

    1. After ``execute_model`` runs with ``running_queue_size=queue_size``,
       the first sequence group's ``num_speculative_tokens`` is ``0`` when
       speculation is disabled (and is still ``None`` when ``execute_model``
       was never invoked).
    2. A ``Top1Proposer`` fed the same metadata either calls the (mocked,
       raising) draft model, or returns zero-length proposals for every
       sequence when speculation is disabled.

    Args:
        queue_size: Passed as ``running_queue_size`` of the request.
        batch_size: Number of sequences requested from ``create_batch``.
        k: Passed to ``create_batch`` and used as ``num_lookahead_slots``.
        acceptance_sampler_method: Name handed to
            ``mock_spec_decode_sampler`` to build the acceptance sampler.
    """
    disable_by_batch_size = 3
    exception_secret = 'artificial stop'

    # Single source of truth for "is speculation expected to be off?".
    # The boundary queue_size == disable_by_batch_size is treated as
    # disabled.
    # NOTE(review): confirm this matches SpecDecodeWorker's own threshold
    # comparison; only queue_size=4 is parametrized, so the boundary is not
    # exercised by this test today.
    spec_disabled = queue_size >= disable_by_batch_size

    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(proposer_worker=draft_worker,
                              scorer_worker=target_worker,
                              spec_decode_sampler=mock_spec_decode_sampler(
                                  acceptance_sampler_method),
                              disable_logprobs=False,
                              metrics_collector=metrics_collector,
                              disable_by_batch_size=disable_by_batch_size)

    # Any attempt to ask the draft worker for proposals aborts the step.
    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        running_queue_size=queue_size)

    if spec_disabled:
        # `_run_no_spec` is patched to raise so the step stops early; the
        # assertion below then inspects the metadata as the worker left it.
        with patch.object(worker,
                          '_run_no_spec',
                          side_effect=ValueError(exception_secret)):
            with pytest.raises(ValueError, match=exception_secret):
                worker.execute_model(execute_model_req=execute_model_req)

    # With speculation disabled we expect no speculative tokens (0);
    # otherwise execute_model was not called and the field stays None.
    expected_num_spec_tokens = 0 if spec_disabled else None
    assert seq_group_metadata_list[
        0].num_speculative_tokens == expected_num_spec_tokens

    # From here on, running the mocked draft model raises.
    draft_worker.sampler_output.side_effect = ValueError(exception_secret)

    proposer = Top1Proposer(
        worker=draft_worker,
        device='cpu',  # not used
        vocab_size=100,  # not used
        # Must be long enough to avoid being skipped due to length.
        max_proposal_len=1024,
    )
    proposal_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k)

    if spec_disabled:
        # Should not execute the draft model because spec decode is disabled
        # for all requests. Accordingly, the proposal length should be 0.
        proposals = proposer.get_spec_proposals(
            execute_model_req=proposal_req,
            seq_ids_with_bonus_token_in_last_step=set())
        assert proposals.proposal_lens.tolist() == [0] * batch_size
    else:
        # Should raise exception when executing the mocked draft model.
        with pytest.raises(ValueError, match=exception_secret):
            proposer.get_spec_proposals(
                execute_model_req=proposal_req,
                seq_ids_with_bonus_token_in_last_step=set())