# test_ngram_worker.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import torch
5
import os
6

7
from vllm.sequence import ExecuteModelRequest
8
9
10
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.top1_proposer import Top1Proposer

11
from .utils import create_seq_group_metadata_from_prompts, create_worker
12
from ..utils import models_path_prefix
13
14
15
16
17
18
19
20
21
22


def test_ngram_algo_correctness_for_single_no_match():
    """Verify the ngram proposer yields no proposal when nothing matches.

    A single decode sequence whose prompt contains no repeated n-gram is run
    through ``Top1Proposer``; its proposal length must be 0 and every proposed
    token id must be the -1 placeholder.
    """
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    seed = 100
    model_name = os.path.join(models_path_prefix, 'JackFram/llama-68m')
    vocab_size = 32_000
    device = 'cuda:0'

    ngram_worker = create_worker(
        NGramWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    proposer = Top1Proposer(
        worker=ngram_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=20,
    )

    # set ngram window [1, 3], which is window=1/2/3
    ngram_worker.set_ngram_window_size(1, 3)

    prompts = [
        # shall find no candidate
        [1, 2, 3, 4, 5, 6, 7],
    ]

    proposal_len = 5
    final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        final_prompt_lens=final_prompt_lens)

    # Treat the sequences as decode requests, as the other tests in this file
    # do. The drafter normally runs on decode requests only, so if these stay
    # marked as prompts the proposal length is likely 0 without the ngram
    # lookup ever running, and this test would pass vacuously.
    for sg in seq_group_metadata_list:
        sg.is_prompt = False

    proposals = proposer.get_spec_proposals(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=proposal_len),
        seq_ids_with_bonus_token_in_last_step=None)

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([1, proposal_len])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([1, proposal_len])
    assert proposals.proposal_lens.shape == torch.Size([1])
    assert proposals.proposal_lens.tolist() == [0]
    # A sequence without a match is filled with -1 placeholders (the batched
    # test checks the same for its no-match rows).
    assert proposals.proposal_token_ids[0].tolist() == [-1] * proposal_len


def test_ngram_algo_correctness_for_batches_not_match_all():
    """Verify the ngram proposer finds the right candidates in a mixed batch.

    Some sequences in the batch have a matching n-gram and some do not (or are
    too long for ``max_proposal_len``). Matching sequences must propose the
    tokens that followed the earlier occurrence. Non-matching ones must get a
    proposal length of 0 and -1 placeholder tokens.
    """
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    seed = 100
    model_name = os.path.join(models_path_prefix, 'JackFram/llama-68m')
    vocab_size = 32_000
    device = 'cuda:0'

    ngram_worker = create_worker(
        NGramWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    proposer = Top1Proposer(
        worker=ngram_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=20,
    )

    # set ngram window [1, 3], which is window=1/2/3
    ngram_worker.set_ngram_window_size(1, 3)

    prompts = [
        # shall find no candidate
        [1, 2, 3, 4, 5, 6, 7],
        # shall find candidate 12,13,14,15,16
        [11, 12, 13, 14, 15, 16, 11],
        # shall find candidate 23,24,25,26,21
        [21, 21, 22, 23, 24, 25, 26, 21, 22],
        # shall find candidate 34,35,36,37,38
        [31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
        # shall find no candidate: this 21-token prompt plus the proposal
        # exceeds max_proposal_len=20
        [
            31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 33, 34, 35, 36, 37,
            38, 31, 32, 33
        ],
    ]

    proposal_len = 5
    final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        final_prompt_lens=final_prompt_lens)

    # The drafter normally runs on decode requests only; mark the sequences
    # as decodes so the ngram lookup is exercised.
    for sg in seq_group_metadata_list:
        sg.is_prompt = False

    proposals = proposer.get_spec_proposals(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=proposal_len),
        seq_ids_with_bonus_token_in_last_step=None)

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([5, proposal_len])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len])
    assert proposals.proposal_lens.shape == torch.Size([5])

    # The first sequence has no match and the last one is too long, so their
    # proposal lengths are overwritten to 0; the middle three get the full
    # proposal_len.
    assert proposals.proposal_lens.tolist() == [0] + [proposal_len] * 3 + [0]

    token_ids = proposals.proposal_token_ids
    # Non-matching rows are filled with -1 placeholders.
    assert token_ids[0].tolist() == [-1] * proposal_len
    assert token_ids[4].tolist() == [-1] * proposal_len
    # Matching rows propose the tokens that followed the earlier occurrence
    # of the trailing n-gram (1-, 2- and 3-gram matches respectively).
    assert token_ids[1].tolist() == prompts[1][1:1 + proposal_len]
    assert token_ids[2].tolist() == prompts[2][3:3 + proposal_len]
    assert token_ids[3].tolist() == prompts[3][5:5 + proposal_len]


def test_ngram_algo_correctness_for_batches_match_all():
    """Verify the ngram proposer finds the right candidates for every sequence.

    Every sequence in the batch has a matching n-gram, so each must receive
    the full proposal length. The proposed tokens must be the ones that
    followed the earlier occurrence of its trailing n-gram.
    """

    block_size = 32
    num_gpu_blocks = 2048 // block_size
    seed = 100
    model_name = os.path.join(models_path_prefix, 'JackFram/llama-68m')
    vocab_size = 32_000
    device = 'cuda:0'

    ngram_worker = create_worker(
        NGramWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    proposer = Top1Proposer(
        worker=ngram_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=20,
    )

    # set ngram window [1, 3], which is window=1/2/3
    ngram_worker.set_ngram_window_size(1, 3)

    prompts = [
        # shall find candidate 12,13,14,15,16
        [11, 12, 13, 14, 15, 16, 11],
        # shall find candidate 23,24,25,26,21
        [21, 21, 22, 23, 24, 25, 26, 21, 22],
        # shall find candidate 34,35,36,37,38
        [31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
    ]

    proposal_len = 5
    final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        final_prompt_lens=final_prompt_lens)

    # Normally drafter is run on decode requests only; here we check the output
    # of the ngram worker as it is the sole proposer that has no forward.
    for sg in seq_group_metadata_list:
        sg.is_prompt = False

    proposals = proposer.get_spec_proposals(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=proposal_len),
        seq_ids_with_bonus_token_in_last_step=None)

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([3, proposal_len])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([3, proposal_len])
    assert proposals.proposal_lens.shape == torch.Size([3])

    assert proposals.proposal_lens.tolist() == [proposal_len] * 3

    token_ids = proposals.proposal_token_ids
    # Each row proposes the tokens that followed the earlier occurrence of the
    # trailing n-gram (1-, 2- and 3-gram matches respectively).
    assert token_ids[0].tolist() == prompts[0][1:1 + proposal_len]
    assert token_ids[1].tolist() == prompts[1][3:3 + proposal_len]
    assert token_ids[2].tolist() == prompts[2][5:5 + proposal_len]